vx_spawn.c: Rewrite cluster-based vx_spawn_tasks variant
Implements round-robin allocation of warps to cores & maintains contiguous thread ID allocation to neighboring threads. Also handles partially-enabled remainder warp logic. TODO: Hardcodes only 1 cluster in the system.
This commit is contained in:
@@ -74,27 +74,6 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
|
||||||
int NT = vx_num_threads();
|
|
||||||
int NW = vx_num_warps();
|
|
||||||
int cid = vx_core_id();
|
|
||||||
int wid = vx_warp_id();
|
|
||||||
int tid = vx_thread_id();
|
|
||||||
|
|
||||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
|
||||||
|
|
||||||
// FIXME: handle RW
|
|
||||||
int waves = p_wspawn_args->NWs;
|
|
||||||
int offset = p_wspawn_args->offset + (NT * wid + tid);
|
|
||||||
|
|
||||||
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
|
|
||||||
void* arg = p_wspawn_args->arg;
|
|
||||||
for (int wave_id = 0; wave_id < waves; ++wave_id) {
|
|
||||||
int task_id = offset + (wave_id * NT * NW);
|
|
||||||
callback(task_id, arg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
||||||
int NT = vx_num_threads();
|
int NT = vx_num_threads();
|
||||||
int NW = vx_num_warps();
|
int NW = vx_num_warps();
|
||||||
@@ -109,11 +88,13 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
|||||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||||
|
|
||||||
// FIXME: handle RW
|
// FIXME: handle RW
|
||||||
int waves = p_wspawn_args->NWs;
|
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
|
||||||
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
|
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
|
||||||
|
|
||||||
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
|
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
|
||||||
void* arg = p_wspawn_args->arg;
|
void* arg = p_wspawn_args->arg;
|
||||||
|
|
||||||
|
// sequential iterations
|
||||||
for (int wave_id = 0; wave_id < waves; ++wave_id) {
|
for (int wave_id = 0; wave_id < waves; ++wave_id) {
|
||||||
int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER);
|
int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER);
|
||||||
callback(task_id, arg);
|
callback(task_id, arg);
|
||||||
@@ -171,6 +152,9 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
|||||||
vx_tmc_zero();
|
vx_tmc_zero();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This function runs in every core, but with only 1 warp and 1 thread enabled.
|
||||||
|
// The logic in this function figures out how many warps/threads this particular
|
||||||
|
// core has to enable to fulfill an entire grid of computation.
|
||||||
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
|
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
|
||||||
// device specs
|
// device specs
|
||||||
int NC = vx_num_cores();
|
int NC = vx_num_cores();
|
||||||
@@ -181,45 +165,49 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
|||||||
int core_id = vx_core_id();
|
int core_id = vx_core_id();
|
||||||
if (core_id >= NUM_CORES_MAX)
|
if (core_id >= NUM_CORES_MAX)
|
||||||
return;
|
return;
|
||||||
|
const int cluster_id = core_id / CORES_PER_CLUSTER;
|
||||||
|
const int core_id_in_cluster = core_id % CORES_PER_CLUSTER;
|
||||||
|
|
||||||
// Distribute threads equally across as many cores as possible, even if they
|
// Distribute threads equally across as many cores as possible, even if they
|
||||||
// don't fill up NW*NT in a single core. This makes sure the warps get evenly
|
// don't fill up NW*NT in a single core. This makes sure the warps get evenly
|
||||||
// distributed in a single cluster
|
// distributed in a single cluster
|
||||||
//
|
//
|
||||||
// TODO: Try to contain in a single cluster if possible?
|
// TODO: Try to contain in a single cluster if possible?
|
||||||
int num_active_cores = (num_tasks > NT) ? (num_tasks / NT) : 1;
|
const int num_active_cores = (num_tasks + (NT - 1)) / NT;
|
||||||
num_active_cores = MIN(num_active_cores, NC);
|
|
||||||
if (core_id >= num_active_cores)
|
if (core_id >= num_active_cores)
|
||||||
return; // terminate extra cores
|
return; // terminate extra cores
|
||||||
|
|
||||||
int tasks_per_core = num_tasks / num_active_cores;
|
// FIXME: HARDCODES 1 CLUSTER!
|
||||||
int tasks_per_core_last = tasks_per_core;
|
const int num_tasks_this_cluster = num_tasks;
|
||||||
if (core_id == (num_active_cores - 1)) {
|
const int num_full_warps = num_tasks_this_cluster / NT;
|
||||||
int rem = num_tasks % num_active_cores;
|
const int rem_threads_in_last_warp = num_tasks_this_cluster % NT;
|
||||||
tasks_per_core_last += rem; // last core also executes remaining tasks
|
// const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT;
|
||||||
|
|
||||||
|
int num_warps_this_core = num_full_warps / CORES_PER_CLUSTER;
|
||||||
|
const int num_warps_in_last_row = num_full_warps % CORES_PER_CLUSTER;
|
||||||
|
if (core_id_in_cluster < num_warps_in_last_row) {
|
||||||
|
num_warps_this_core++;
|
||||||
|
}
|
||||||
|
// if 0, last warp is full-threads enabled
|
||||||
|
int rem_threads_in_last_warp_this_core = 0;
|
||||||
|
if (rem_threads_in_last_warp != 0) {
|
||||||
|
if (core_id_in_cluster == num_warps_in_last_row - 1) {
|
||||||
|
rem_threads_in_last_warp_this_core = rem_threads_in_last_warp;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int num_full_warps = tasks_per_core_last / NT;
|
|
||||||
int rem_threads_in_last_warp = tasks_per_core_last % NT;
|
|
||||||
// sequential iterations
|
// sequential iterations
|
||||||
int num_full_waves = 1;
|
const int num_full_waves = num_warps_this_core / NW;
|
||||||
int rem_warps_in_last_wave = 0;
|
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
|
||||||
if (num_full_warps >= NW) {
|
|
||||||
// this division will result in the same value for both the last core and
|
|
||||||
// the rest
|
|
||||||
num_full_waves = num_full_warps / NW;
|
|
||||||
rem_warps_in_last_wave = num_full_warps % NW;
|
|
||||||
}
|
|
||||||
|
|
||||||
int cluster_id = core_id / CORES_PER_CLUSTER;
|
const const int offset = cluster_id * num_tasks_this_cluster;
|
||||||
const int tasks_per_cluster = tasks_per_core * CORES_PER_CLUSTER;
|
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
|
||||||
const int offset = cluster_id * tasks_per_cluster;
|
rem_full_warps_in_last_wave};
|
||||||
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves, rem_warps_in_last_wave};
|
|
||||||
g_wspawn_args[core_id] = &wspawn_args;
|
g_wspawn_args[core_id] = &wspawn_args;
|
||||||
|
|
||||||
if (num_full_warps >= 1) {
|
if (num_warps_this_core > 0) {
|
||||||
// execute callback on other warps
|
// execute callback on other warps
|
||||||
int nw = MIN(num_full_warps, NW);
|
const int nw = MIN(num_warps_this_core, NW);
|
||||||
vx_wspawn(nw, spawn_tasks_cluster_all_cb);
|
vx_wspawn(nw, spawn_tasks_cluster_all_cb);
|
||||||
|
|
||||||
// activate all threads
|
// activate all threads
|
||||||
@@ -235,12 +223,16 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
|||||||
vx_wspawn_wait();
|
vx_wspawn_wait();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (rem_threads_in_last_warp != 0) {
|
// TODO: Instead of launching an additional wave just to work on remaining
|
||||||
|
// threads, handle this in the last wave amongst other full warps.
|
||||||
|
if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) {
|
||||||
// adjust offset
|
// adjust offset
|
||||||
wspawn_args.offset += (tasks_per_core_last - rem_threads_in_last_warp);
|
// FIXME: consider cluster_id here
|
||||||
|
// FIXME: use rem_threads_in_last_warp_this_core
|
||||||
|
wspawn_args.offset += (num_tasks_this_cluster - rem_threads_in_last_warp);
|
||||||
|
|
||||||
// activate remaining threads
|
// activate remaining threads
|
||||||
int tmask = (1 << rem_threads_in_last_warp) - 1;
|
const int tmask = (1 << rem_threads_in_last_warp) - 1;
|
||||||
vx_tmc(tmask);
|
vx_tmc(tmask);
|
||||||
|
|
||||||
// call stub routine
|
// call stub routine
|
||||||
|
|||||||
Reference in New Issue
Block a user