diff --git a/kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp b/kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp index c29c12b5..d6ccc677 100644 --- a/kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp +++ b/kernels/wu_arch_cases/case01_scalar_spawn/kernel.cpp @@ -1,14 +1,11 @@ +#define WU_START_BRANCH_TO_MAIN 1 #include "common_wu_min.h" -extern "C" void scalar_worker() { - wu_short_delay(wu_wid()); - wu_mark_seen(WU_CASE_SCALAR_BASE); - wu_stop_warp(); -} +extern "C" void scalar_worker(); extern "C" int wu_main() { if (!wu_is_leader()) { - return 0; + wu_stop_warp(); } wu_case_reset(); @@ -21,9 +18,37 @@ extern "C" int wu_main() { wu_mark_seen(WU_CASE_SCALAR_BASE); if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) { wu_case_fail(0x01u); - return 1; + wu_stop_warp(); } wu_case_pass(); - return 0; + wu_stop_warp(); +} + +extern "C" void scalar_worker_body(); + +extern "C" void __attribute__((naked, used)) scalar_worker() { + asm volatile( + ".option push\n\t" + ".option norelax\n\t" + "la gp, __global_pointer\n\t" + ".option pop\n\t" + "li sp, %[stack_base]\n\t" + "csrr t0, %[csr_hart]\n\t" + "slli t1, t0, %[stack_log2]\n\t" + "slli t2, t0, 4\n\t" + "add t1, t1, t2\n\t" + "sub sp, sp, t1\n\t" + "j scalar_worker_body\n\t" + : + : [csr_hart] "i"(VX_CSR_MHARTID), + [stack_base] "i"(STACK_BASE_ADDR), + [stack_log2] "i"(STACK_LOG2_SIZE) + : "memory"); +} + +extern "C" void scalar_worker_body() { + wu_short_delay(wu_wid()); + wu_mark_seen(WU_CASE_SCALAR_BASE); + wu_stop_warp(); } diff --git a/kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp b/kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp index a65a22b8..2b0843d4 100644 --- a/kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp +++ b/kernels/wu_arch_cases/case02_tensor_spawn_stop/kernel.cpp @@ -1,3 +1,4 @@ +#define WU_CASE_WAIT_SPIN 1024u #include "common_wu_min.h" extern "C" void __attribute__((naked, noinline, used)) tensor_worker() { diff --git a/kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp b/kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp index 3746c857..1f84887c 100644 --- a/kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp +++ b/kernels/wu_arch_cases/case08_tensor_lsu_optional/kernel.cpp @@ -6,13 +6,15 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_worker() { "slli x6, x5, 2\n\t" "la x7, g_case_mem\n\t" "add x7, x7, x6\n\t" - "li x8, %[tensor_lsu_base]\n\t" - "or x8, x8, x5\n\t" - "sw x8, 0(x7)\n\t" - "lw x8, 0(x7)\n\t" + "li x6, %[tensor_lsu_base]\n\t" + "or x5, x6, x5\n\t" + "sw x5, 0(x7)\n\t" + "lw x5, 0(x7)\n\t" + "sub x6, x5, x6\n\t" + "slli x6, x6, 2\n\t" "la x7, g_seen\n\t" "add x7, x7, x6\n\t" - "sw x8, 0(x7)\n\t" + "sw x5, 0(x7)\n\t" ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" "1: j 1b\n\t" : diff --git a/kernels/wu_arch_cases/common_wu_min.h b/kernels/wu_arch_cases/common_wu_min.h index 81ab22d0..eb01b4f7 100644 --- a/kernels/wu_arch_cases/common_wu_min.h +++ b/kernels/wu_arch_cases/common_wu_min.h @@ -5,8 +5,12 @@ #include #define WU_CASE_MAX_WARPS 8u +#ifndef WU_CASE_WAIT_SPIN #define WU_CASE_WAIT_SPIN 1024u +#endif +#ifndef WU_CASE_SHORT_SPIN #define WU_CASE_SHORT_SPIN 8u +#endif #define WU_CASE_PASS 0x600du #define WU_CASE_FAIL_BASE 0xe000u @@ -15,6 +19,10 @@ #define WU_CASE_TENSOR_CSR_BASE 0x7300u #define WU_CASE_TENSOR_LSU_BASE 0x7400u +#ifndef WU_START_BRANCH_TO_MAIN +#define WU_START_BRANCH_TO_MAIN 0 +#endif + extern "C" { volatile uint32_t g_status[WU_CASE_MAX_WARPS] __attribute__((aligned(32))); volatile uint32_t g_seen[WU_CASE_MAX_WARPS] __attribute__((aligned(32))); @@ -35,8 +43,12 @@ extern "C" void __attribute__((naked, section(".init"), used)) _start() { "csrr t0, %[csr_core]\n\t" "bnez t0, 2f\n\t" "li sp, %[stack_base]\n\t" +#if WU_START_BRANCH_TO_MAIN + "beq zero, zero, wu_main\n\t" +#else "call wu_main\n\t" "mv gp, a0\n\t" +#endif "2:\n\t" ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" "1: j 1b\n\t" @@ -113,7 +125,7 @@ static inline void wu_mark_seen(uint32_t base) { } } -static inline void wu_stop_warp() { +static inline void __attribute__((noreturn)) wu_stop_warp() { vx_tmc_zero(); while (1) {} } diff --git a/lib/include/VX_config.h b/lib/include/VX_config.h index e7a6b559..b2c28f52 100644 --- a/lib/include/VX_config.h +++ b/lib/include/VX_config.h @@ -84,15 +84,32 @@ #endif #ifndef NUM_CORES -#define NUM_CORES 8 +#define NUM_CORES 1 #endif #ifndef NUM_WARPS -#define NUM_WARPS 8 +#define NUM_WARPS 4 +#endif + +#ifndef NUM_TENSOR_WARPS +#define NUM_TENSOR_WARPS 2 +#endif + +#define NUM_SCALAR_WARPS (NUM_WARPS - NUM_TENSOR_WARPS) + +#define IS_SCALAR_WARP(wid) ((wid) < NUM_SCALAR_WARPS) +#define IS_TENSOR_WARP(wid) ((wid) >= NUM_SCALAR_WARPS) + +#ifndef TENSOR_NUM_GPRS +#define TENSOR_NUM_GPRS 8 +#endif + +#ifndef TENSOR_NUM_FPRS +#define TENSOR_NUM_FPRS 8 #endif #ifndef NUM_THREADS -#define NUM_THREADS 8 +#define NUM_THREADS 4 #endif #ifndef NUM_BARRIERS @@ -682,4 +699,3 @@ #define IMPLEMENTATION_ID 0 #endif // VX_CONFIG_VH - diff --git a/lib/include/vx_intrinsics.h b/lib/include/vx_intrinsics.h index f51601f7..26bbe65f 100644 --- a/lib/include/vx_intrinsics.h +++ b/lib/include/vx_intrinsics.h @@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) { asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr)); } +// Spawn an explicit warp mask. The current warp bit is ignored by hardware. +inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) { + asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr)); +} + +inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) { + vx_wspawn_mask(warp_mask & ((1u << NUM_SCALAR_WARPS) - 1u), func_ptr); +} + +inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) { + vx_wspawn_mask(warp_mask & (((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS), func_ptr); +} + // Split on a predicate inline unsigned vx_split(unsigned predicate) { unsigned ret; @@ -151,7 +164,34 @@ inline void vx_join(unsigned stack_ptr) { // Warp Barrier __attribute__((convergent)) inline void vx_barrier(unsigned barried_id, unsigned num_warps) { - asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps)); + unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps; + asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps)); +} + +#define VX_BARRIER_DOMAIN_SHIFT 28 +#define VX_BARRIER_DOMAIN_ALL 0u +#define VX_BARRIER_DOMAIN_SCALAR 1u +#define VX_BARRIER_DOMAIN_TENSOR 2u + +__attribute__((convergent)) +inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) { + unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT); + asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps)); +} + +__attribute__((convergent)) +inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) { + vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR); +} + +__attribute__((convergent)) +inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) { + vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR); +} + +__attribute__((convergent)) +inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) { + asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask)); } // Return current thread identifier @@ -203,6 +243,22 @@ inline int vx_num_warps() { return ret; } +inline int vx_num_scalar_warps() { + return NUM_SCALAR_WARPS; +} + +inline int vx_num_tensor_warps() { + return NUM_TENSOR_WARPS; +} + +inline unsigned vx_scalar_warp_mask() { + return (1u << NUM_SCALAR_WARPS) - 1u; +} + +inline unsigned vx_tensor_warp_mask() { + return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS; +} + // Return the number of cores per cluster inline int vx_num_cores() { int ret; diff --git a/lib/src/vx_spawn.c b/lib/src/vx_spawn.c index b53538f0..1971ae55 100644 --- a/lib/src/vx_spawn.c +++ b/lib/src/vx_spawn.c @@ -76,7 +76,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() { static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { int NT = vx_num_threads(); - int NW = vx_num_warps(); + int NW = NUM_SCALAR_WARPS; int cid = vx_core_id(); int wid = vx_warp_id(); int tid = vx_thread_id(); @@ -96,7 +96,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() { static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() { int NT = vx_num_threads(); - int NW = vx_num_warps(); + int NW = NUM_SCALAR_WARPS; int cid = vx_core_id(); int wid = vx_warp_id(); int tid = vx_thread_id(); @@ -187,7 +187,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() { void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) { // device specs const int NC = vx_num_cores(); - const int NW = vx_num_warps(); + const int NW = NUM_SCALAR_WARPS; const int NT = vx_num_threads(); // NOTE: assumes divisible const int num_cluster = NC / CORES_PER_CLUSTER; @@ -243,7 +243,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg const int num_full_waves = num_warps_this_core / NW; const int rem_full_warps_in_last_wave = num_warps_this_core % NW; - const const int offset = cluster_id * num_tasks_this_cluster; + const int offset = cluster_id * num_tasks_this_cluster; wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves, rem_full_warps_in_last_wave}; g_wspawn_args[core_id] = &wspawn_args; @@ -289,7 +289,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { // device specs int NC = vx_num_cores(); - int NW = vx_num_warps(); + int NW = NUM_SCALAR_WARPS; int NT = vx_num_threads(); // current core id @@ -361,7 +361,7 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) { // device specs int NC = vx_num_cores(); - int NW = vx_num_warps(); + int NW = NUM_SCALAR_WARPS; int NT = vx_num_threads(); // current core id @@ -515,7 +515,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) { // device specs int NC = vx_num_cores(); - int NW = vx_num_warps(); + int NW = NUM_SCALAR_WARPS; int NT = vx_num_threads(); // current core id diff --git a/lib/src/vx_start.S b/lib/src/vx_start.S index 65dbb9a6..41c42f60 100644 --- a/lib/src/vx_start.S +++ b/lib/src/vx_start.S @@ -22,9 +22,9 @@ _start: # initialize per-thread registers - csrr t0, VX_CSR_NUM_WARPS # get num warps + li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask la t1, init_regs_all - .insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1 + .insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1 li t0, -1 .insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0 jal init_regs @@ -35,9 +35,9 @@ _start: jal vx_wspawn_wait # initialize TLS for all warps - csrr t0, VX_CSR_NUM_WARPS # get num warps + li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask la t1, init_tls_all - .insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1 + .insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1 li t0, -1 .insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0 call __init_tls @@ -150,4 +150,3 @@ vx_wspawn_wait: .weak __dso_handle __dso_handle: .long 0 -