diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index fb36b0bc..8e5002f4 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -74,27 +74,6 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() { } } -static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { - int NT = vx_num_threads(); - int NW = vx_num_warps(); - int cid = vx_core_id(); - int wid = vx_warp_id(); - int tid = vx_thread_id(); - - wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; - - // FIXME: handle RW - int waves = p_wspawn_args->NWs; - int offset = p_wspawn_args->offset + (NT * wid + tid); - - vx_spawn_tasks_cb callback = p_wspawn_args->callback; - void* arg = p_wspawn_args->arg; - for (int wave_id = 0; wave_id < waves; ++wave_id) { - int task_id = offset + (wave_id * NT * NW); - callback(task_id, arg); - } -} - static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() { int NT = vx_num_threads(); int NW = vx_num_warps(); @@ -109,11 +88,13 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() { wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; // FIXME: handle RW - int waves = p_wspawn_args->NWs; + int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs); int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid); vx_spawn_tasks_cb callback = p_wspawn_args->callback; void* arg = p_wspawn_args->arg; + + // sequential iterations for (int wave_id = 0; wave_id < waves; ++wave_id) { int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER); callback(task_id, arg); @@ -171,6 +152,9 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { vx_tmc_zero(); } +// This function runs in every core, but with only 1 warp and 1 thread enabled. +// The logic in this function figures out how many warps/threads this particular +// core has to enable to fulfill an entire grid of computation. void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) { // device specs int NC = vx_num_cores(); @@ -181,45 +165,49 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg int core_id = vx_core_id(); if (core_id >= NUM_CORES_MAX) return; + const int cluster_id = core_id / CORES_PER_CLUSTER; + const int core_id_in_cluster = core_id % CORES_PER_CLUSTER; // Distribute threads equally across as many cores as possible, even if they // don't fill up NW*NT in a single core. This makes sure the warps get evenly // distributed in a single cluster // // TODO: Try to contain in a single cluster if possible? - int num_active_cores = (num_tasks > NT) ? (num_tasks / NT) : 1; - num_active_cores = MIN(num_active_cores, NC); + const int num_active_cores = (num_tasks + (NT - 1)) / NT; if (core_id >= num_active_cores) return; // terminate extra cores - int tasks_per_core = num_tasks / num_active_cores; - int tasks_per_core_last = tasks_per_core; - if (core_id == (num_active_cores - 1)) { - int rem = num_tasks % num_active_cores; - tasks_per_core_last += rem; // last core also executes remaining tasks + // FIXME: HARDCODES 1 CLUSTER! + const int num_tasks_this_cluster = num_tasks; + const int num_full_warps = num_tasks_this_cluster / NT; + const int rem_threads_in_last_warp = num_tasks_this_cluster % NT; + // const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT; + + int num_warps_this_core = num_full_warps / CORES_PER_CLUSTER; + const int num_warps_in_last_row = num_full_warps % CORES_PER_CLUSTER; + if (core_id_in_cluster < num_warps_in_last_row) { + num_warps_this_core++; + } + // if 0, last warp is full-threads enabled + int rem_threads_in_last_warp_this_core = 0; + if (rem_threads_in_last_warp != 0) { + if (core_id_in_cluster == num_warps_in_last_row - 1) { + rem_threads_in_last_warp_this_core = rem_threads_in_last_warp; + } } - int num_full_warps = tasks_per_core_last / NT; - int rem_threads_in_last_warp = tasks_per_core_last % NT; // sequential iterations - int num_full_waves = 1; - int rem_warps_in_last_wave = 0; - if (num_full_warps >= NW) { - // this division will result in the same value for both the last core and - // the rest - num_full_waves = num_full_warps / NW; - rem_warps_in_last_wave = num_full_warps % NW; - } + const int num_full_waves = num_warps_this_core / NW; + const int rem_full_warps_in_last_wave = num_warps_this_core % NW; - int cluster_id = core_id / CORES_PER_CLUSTER; - const int tasks_per_cluster = tasks_per_core * CORES_PER_CLUSTER; - const int offset = cluster_id * tasks_per_cluster; - wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves, rem_warps_in_last_wave}; + const const int offset = cluster_id * num_tasks_this_cluster; + wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves, + rem_full_warps_in_last_wave}; g_wspawn_args[core_id] = &wspawn_args; - if (num_full_warps >= 1) { + if (num_warps_this_core > 0) { // execute callback on other warps - int nw = MIN(num_full_warps, NW); + const int nw = MIN(num_warps_this_core, NW); vx_wspawn(nw, spawn_tasks_cluster_all_cb); // activate all threads @@ -235,12 +223,16 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg vx_wspawn_wait(); } - if (rem_threads_in_last_warp != 0) { + // TODO: Instead of launching an additional wave just to work on remaining + // threads, handle this in the last wave amongst other full warps. + if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) { // adjust offset - wspawn_args.offset += (tasks_per_core_last - rem_threads_in_last_warp); + // FIXME: consider cluster_id here + // FIXME: use rem_threads_in_last_warp_this_core + wspawn_args.offset += (num_tasks_this_cluster - rem_threads_in_last_warp); // activate remaining threads - int tmask = (1 << rem_threads_in_last_warp) - 1; + const int tmask = (1 << rem_threads_in_last_warp) - 1; vx_tmc(tmask); // call stub routine