diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index 278516a3..9ea45ded 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -148,7 +148,7 @@ static void __attribute__ ((noinline)) spawn_tasks_cluster_rem_stub() { (p_wspawn_args->callback)(task_id, p_wspawn_args->arg); } -static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() { +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() { // activate all threads vx_tmc(-1); @@ -159,7 +159,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() { vx_tmc_zero(); } -static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() { +static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() { // activate all threads vx_tmc(-1); @@ -186,9 +186,11 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { // core has to enable to fulfill an entire grid of computation. void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) { // device specs - int NC = vx_num_cores(); - int NW = vx_num_warps(); - int NT = vx_num_threads(); + const int NC = vx_num_cores(); + const int NW = vx_num_warps(); + const int NT = vx_num_threads(); + // NOTE: assumes NC is divisible by CORES_PER_CLUSTER + const int num_cluster = NC / CORES_PER_CLUSTER; // current core id int core_id = vx_core_id(); @@ -206,8 +208,8 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg if (core_id >= num_active_cores) return; // terminate extra cores - // FIXME: HARDCODES 1 CLUSTER! - const int num_tasks_this_cluster = num_tasks; + // FIXME: assumes num_tasks is divisible by num_cluster + const int num_tasks_this_cluster = num_tasks / num_cluster; const int num_full_warps = num_tasks_this_cluster / NT; const int rem_threads_in_last_warp = num_tasks_this_cluster % NT; // const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT;