diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index eb0bdb90..c4c00a06 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -15,6 +15,8 @@ #include #include +#define CORES_PER_CLUSTER 2 + #ifdef __cplusplus extern "C" { #endif @@ -95,6 +97,30 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { } } +static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() { + int NT = vx_num_threads(); + int NW = vx_num_warps(); + int cid = vx_core_id(); + int wid = vx_warp_id(); + int tid = vx_thread_id(); + + const int core_id_in_cluster = vx_core_id() % CORES_PER_CLUSTER; + const int cluster_wid = CORES_PER_CLUSTER * wid + core_id_in_cluster; + + wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; + + // FIXME: handle RW + int waves = p_wspawn_args->NWs; + int offset = p_wspawn_args->offset + (NT * cluster_wid + tid); + + vx_spawn_tasks_cb callback = p_wspawn_args->callback; + void* arg = p_wspawn_args->arg; + for (int wave_id = 0; wave_id < waves; ++wave_id) { + int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER); + callback(task_id, arg); + } +} + static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { int cid = vx_core_id(); int tid = vx_thread_id(); @@ -110,7 +136,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { // call stub routine // spawn_tasks_all_stub(); - spawn_tasks_contiguous_all_stub(); + spawn_tasks_cluster_all_stub(); // disable warp vx_tmc_zero(); @@ -151,7 +177,11 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { rW = TW - fW * NW; // remaining warps } - wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW }; + int cluster_id = core_id / CORES_PER_CLUSTER; + const int tasks_per_cluster = tasks_per_core * CORES_PER_CLUSTER; + const int offset = cluster_id * tasks_per_cluster; + wspawn_tasks_args_t wspawn_args = { callback, arg, offset, fW, rW }; + // wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW }; g_wspawn_args[core_id] = &wspawn_args; if (TW >= 1) { @@ -163,7 +193,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { vx_tmc(-1); // call stub routine - spawn_tasks_contiguous_all_stub(); + spawn_tasks_cluster_all_stub(); // back to single-threaded vx_tmc_one();