diff --git a/kernel/src/vx_spawn.c b/kernel/src/vx_spawn.c index fd8258e1..eb0bdb90 100644 --- a/kernel/src/vx_spawn.c +++ b/kernel/src/vx_spawn.c @@ -74,6 +74,27 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() { } } +static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { + int NT = vx_num_threads(); + int NW = vx_num_warps(); + int cid = vx_core_id(); + int wid = vx_warp_id(); + int tid = vx_thread_id(); + + wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid]; + + // FIXME: handle RW + int waves = p_wspawn_args->NWs; + int offset = p_wspawn_args->offset + (NT * wid + tid); + + vx_spawn_tasks_cb callback = p_wspawn_args->callback; + void* arg = p_wspawn_args->arg; + for (int wave_id = 0; wave_id < waves; ++wave_id) { + int task_id = offset + (wave_id * NT * NW); + callback(task_id, arg); + } +} + static void __attribute__ ((noinline)) spawn_tasks_rem_stub() { int cid = vx_core_id(); int tid = vx_thread_id(); @@ -88,7 +109,8 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { vx_tmc(-1); // call stub routine - spawn_tasks_all_stub(); + // spawn_tasks_all_stub(); + spawn_tasks_contiguous_all_stub(); // disable warp vx_tmc_zero(); @@ -141,7 +163,7 @@ void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { vx_tmc(-1); // call stub routine - spawn_tasks_all_stub(); + spawn_tasks_contiguous_all_stub(); // back to single-threaded vx_tmc_one();