sgemm_wg: ifdef-guard cluster-specific code
This commit is contained in:
@@ -130,8 +130,13 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
|
||||
// across the threadblock
|
||||
|
||||
const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN);
|
||||
#ifdef RADIANCE
|
||||
const uint32_t threadblocks_per_core =
|
||||
vx_num_threads() * vx_num_warps() / (threads_per_threadblock / CORES_PER_CLUSTER);
|
||||
#else
|
||||
const uint32_t threadblocks_per_core =
|
||||
vx_num_threads() * vx_num_warps() / threads_per_threadblock;
|
||||
#endif
|
||||
const uint32_t threadblock_dim_x = vx_num_threads();
|
||||
const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core;
|
||||
const int threadblock_id = task_id / threads_per_threadblock;
|
||||
@@ -156,6 +161,12 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
|
||||
int main() {
|
||||
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
|
||||
const uint32_t grid_size = arg->dim_m * arg->dim_n / (TM * TN);
|
||||
vx_spawn_tasks(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
#ifdef RADIANCE
|
||||
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
#else
|
||||
// NOTE: This kernel assumes contiguous thread scheduling for threadblock
|
||||
// allocation, and therefore does not work with the original vx_spawn_tasks
|
||||
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user