sgemm_tcore: Hardcode threadblock id 0

this is fine since we're statically dispatching only one threadblock to
the whole cluster.
This commit is contained in:
Hansung Kim
2024-06-07 16:08:40 -07:00
parent 856596cbb3
commit 3a6427a491

View File

@@ -547,7 +547,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t threadblock_dim_y,
/*const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y,*/
const uint32_t threadblock_id_in_cluster,
// const uint32_t threadblock_id_in_cluster,
float *sharedmem_per_threadblock) {
const float *A = (const float *)arg->addr_a;
const float *B = (const float *)arg->addr_b;
@@ -602,7 +602,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b,
tid_in_warpgroup, block_n, block_m);
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
}
// NOTE: this *should* be signed integer to trigger arithmetic
@@ -633,11 +633,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
local_a_produce, local_b_produce, tid_in_warpgroup,
block_n, block_m);
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
}
// sync with final consumer stage in the k-loop
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
}
}
} else {
@@ -650,7 +650,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
initialize_C(1);
// sync with initial producer stage in the k-loop
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
// NOTE: this *should* be signed integer to trigger arithmetic
// right-shift
@@ -718,7 +718,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
}
}
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
threadblock_barrier(0/*threadblock_id_in_cluster*/, threadblock_dim_y);
#else
@@ -819,7 +819,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const int warp_id = vx_warp_id();
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
threadblock_dim_x, threadblock_dim_y, /*threadblock_id_x,
threadblock_id_y,*/ threadblock_id_in_cluster,
threadblock_id_y,*/ /*threadblock_id_in_cluster, */
sharedmem_per_threadblock);
}