Implement WU architecture support

This commit is contained in:
2026-05-25 19:25:05 +08:00
parent 323ed7d7e9
commit 0ad87bde81
35 changed files with 3303 additions and 472 deletions

View File

@@ -74,18 +74,9 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
}
}
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int task_id = p_wspawn_args->offset + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
@@ -103,6 +94,60 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
}
}
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
int NT = vx_num_threads();
int NW = NUM_SCALAR_WARPS;
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
const int core_id_in_cluster = cid % CORES_PER_CLUSTER;
// round-robin warp_id allocation across cores in cluster
const int wid_in_cluster = CORES_PER_CLUSTER * wid + core_id_in_cluster;
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
void* arg = p_wspawn_args->arg;
// sequential iterations
for (int wave_id = 0; wave_id < waves; ++wave_id) {
int task_id = offset + (wave_id * NT * NW * CORES_PER_CLUSTER);
callback(task_id, arg);
}
}
static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
int cid = vx_core_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int task_id = p_wspawn_args->offset + tid;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void __attribute__ ((noinline)) spawn_tasks_cluster_rem_stub() {
int NT = vx_num_threads();
int cid = vx_core_id();
int tid = vx_thread_id();
int wid = vx_warp_id();
const int core_id_in_cluster = cid % CORES_PER_CLUSTER;
// round-robin warp_id allocation across cores in cluster
const int wid_in_cluster = CORES_PER_CLUSTER * wid + core_id_in_cluster;
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
// FIXME: This assumes that all cores but the last one are working with full
// warps, and only the last core has a partially-filled warp.
int offset = p_wspawn_args->offset + (NT * wid_in_cluster + tid);
int task_id = offset;
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
// activate all threads
vx_tmc(-1);
@@ -111,11 +156,21 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
spawn_tasks_contiguous_all_stub();
// disable warp
// deadlock here on warps 1, 2, 3
vx_tmc_zero();
}
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_cluster_all_stub();
// disable warp
vx_tmc_zero();
}
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
// activate all threads
vx_tmc(-1);
@@ -126,10 +181,115 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
vx_tmc_zero();
}
// This function runs in every core, but with only 1 warp and 1 thread enabled.
// The logic in this function figures out how many warps/threads this particular
// core has to enable to fulfill an entire grid of computation.
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
// device specs
const int NC = vx_num_cores();
const int NW = NUM_SCALAR_WARPS;
const int NT = vx_num_threads();
// NOTE: assumes divisible
const int num_cluster = NC / CORES_PER_CLUSTER;
// current core id
int core_id = vx_core_id();
if (core_id >= NUM_CORES_MAX)
return;
const int cluster_id = core_id / CORES_PER_CLUSTER;
const int core_id_in_cluster = core_id % CORES_PER_CLUSTER;
// try to fill up full clusters first
const int num_threads_in_cluster = CORES_PER_CLUSTER * NW * NT;
const int num_used_clusters =
(num_tasks + (num_threads_in_cluster - 1)) / num_threads_in_cluster;
if (cluster_id >= num_used_clusters) {
return; // terminate extra clusters
}
// fill up the last cluster with remaining tasks
const int num_full_clusters = num_tasks / num_threads_in_cluster;
int num_tasks_this_cluster = num_threads_in_cluster;
if (cluster_id >= num_full_clusters) {
num_tasks_this_cluster = num_tasks % num_threads_in_cluster;
}
// Distribute threads equally across as many cores as possible, even if they
// don't fill up NW*NT in a single core. This makes sure the warps get evenly
// distributed in a single cluster
//
// TODO: Try to contain in a single cluster if possible?
const int num_active_cores = (num_tasks + (NT - 1)) / NT;
if (core_id >= num_active_cores)
return; // terminate extra cores
const int num_full_warps_this_cluster = num_tasks_this_cluster / NT;
const int rem_threads_in_last_warp = num_tasks_this_cluster % NT;
// const int num_warps = (num_tasks_this_cluster + (NT - 1)) / NT;
int num_warps_this_core = num_full_warps_this_cluster / CORES_PER_CLUSTER;
const int num_warps_in_last_row = num_full_warps_this_cluster % CORES_PER_CLUSTER;
if (core_id_in_cluster < num_warps_in_last_row) {
num_warps_this_core++;
}
// if 0, last warp is full-threads enabled
int rem_threads_in_last_warp_this_core = 0;
if (rem_threads_in_last_warp != 0) {
if (core_id_in_cluster == num_warps_in_last_row - 1) {
rem_threads_in_last_warp_this_core = rem_threads_in_last_warp;
}
}
// sequential iterations
const int num_full_waves = num_warps_this_core / NW;
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
const int offset = cluster_id * num_tasks_this_cluster;
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
rem_full_warps_in_last_wave};
g_wspawn_args[core_id] = &wspawn_args;
if (num_warps_this_core > 0) {
// execute callback on other warps
const int nw = MIN(num_warps_this_core, NW);
vx_wspawn(nw, spawn_tasks_cluster_all_cb);
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_cluster_all_stub();
// back to single-threaded
vx_tmc_one();
// wait for spawn warps to terminate
vx_wspawn_wait();
}
// TODO: this is incomplete
// TODO: Instead of launching an additional wave just to work on remaining
// threads, handle this in the last wave amongst other full warps.
if (rem_threads_in_last_warp != 0 && core_id_in_cluster == 0) {
// adjust offset
// FIXME: use rem_threads_in_last_warp_this_core
wspawn_args.offset += (num_tasks_this_cluster - rem_threads_in_last_warp);
// activate remaining threads
const int tmask = (1 << rem_threads_in_last_warp) - 1;
vx_tmc(tmask);
// call stub routine
spawn_tasks_cluster_rem_stub();
// back to single-threaded
vx_tmc_one();
}
}
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id
@@ -179,7 +339,6 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
vx_tmc_one();
// wait for spawn warps to terminate
// deadlock here on warp 0!
vx_wspawn_wait();
}
@@ -202,7 +361,7 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id
@@ -356,7 +515,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id

View File

@@ -22,9 +22,9 @@
_start:
# initialize per-thread registers
csrr t0, VX_CSR_NUM_WARPS # get num warps
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
la t1, init_regs_all
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
li t0, -1
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
jal init_regs
@@ -35,9 +35,9 @@ _start:
jal vx_wspawn_wait
# initialize TLS for all warps
csrr t0, VX_CSR_NUM_WARPS # get num warps
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
la t1, init_tls_all
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
li t0, -1
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
call __init_tls
@@ -102,6 +102,8 @@ init_regs:
#endif
csrr t0, VX_CSR_MHARTID
sll t1, t0, STACK_LOG2_SIZE
sll t2, t0, 4
add t1, t1, t2
sub sp, sp, t1
# set thread pointer register