sgemm_tcore: Hardcode threadblock id 0
this is fine since we're statically dispatching only one threadblock to the whole cluster.
This commit is contained in:
@@ -547,7 +547,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t threadblock_dim_y,
|
||||
/*const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y,*/
|
||||
const uint32_t threadblock_id_in_cluster,
|
||||
// const uint32_t threadblock_id_in_cluster,
|
||||
float *sharedmem_per_threadblock) {
|
||||
const float *A = (const float *)arg->addr_a;
|
||||
const float *B = (const float *)arg->addr_b;
|
||||
@@ -602,7 +602,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b,
|
||||
tid_in_warpgroup, block_n, block_m);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||
}
|
||||
|
||||
// NOTE: this *should* be signed integer to trigger arithmetic
|
||||
@@ -633,11 +633,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
local_a_produce, local_b_produce, tid_in_warpgroup,
|
||||
block_n, block_m);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||
}
|
||||
|
||||
// sync with final consumer stage in the k-loop
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -650,7 +650,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
initialize_C(1);
|
||||
|
||||
// sync with initial producer stage in the k-loop
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||
|
||||
// NOTE: this *should* be signed integer to trigger arithmetic
|
||||
// right-shift
|
||||
@@ -718,7 +718,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
}
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
|
||||
|
||||
#else
|
||||
|
||||
@@ -819,7 +819,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
const int warp_id = vx_warp_id();
|
||||
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x,
|
||||
threadblock_id_y,*/ threadblock_id_in_cluster,
|
||||
threadblock_id_y,*/ /*threadblock_id_in_cluster, */
|
||||
sharedmem_per_threadblock);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user