vx_spawn.c: Handle num_clusters > 1

WIP: still assumes num_tasks is divisible by num_cluster
This commit is contained in:
Hansung Kim
2024-03-28 20:16:44 -07:00
parent a9b0814211
commit e4eec8ab4d

View File

@@ -148,7 +148,7 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_rem_stub() {
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
// activate all threads
vx_tmc(-1);
@@ -159,7 +159,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
vx_tmc_zero();
}
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() {
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() {
// activate all threads
vx_tmc(-1);
@@ -186,9 +186,11 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
// core has to enable to fulfill an entire grid of computation.
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NT = vx_num_threads();
const int NC = vx_num_cores();
const int NW = vx_num_warps();
const int NT = vx_num_threads();
// NOTE: assumes divisible
const int num_cluster = NC / CORES_PER_CLUSTER;
// current core id
int core_id = vx_core_id();
@@ -206,8 +208,8 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
if (core_id >= num_active_cores)
return; // terminate extra cores
// FIXME: HARDCODES 1 CLUSTER!
const int num_tasks_this_cluster = num_tasks;
// FIXME: assumes num_tasks is divisible by num_cluster
const int num_tasks_this_cluster = num_tasks / num_cluster;
const int num_full_warps = num_tasks_this_cluster / NT;
const int rem_threads_in_last_warp = num_tasks_this_cluster % NT;
// const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT;