Update Wu architecture kernel implementations and runtime library
This commit is contained in:
@@ -1,14 +1,11 @@
|
||||
#define WU_START_BRANCH_TO_MAIN 1
|
||||
#include "common_wu_min.h"
|
||||
|
||||
extern "C" void scalar_worker() {
|
||||
wu_short_delay(wu_wid());
|
||||
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||
wu_stop_warp();
|
||||
}
|
||||
extern "C" void scalar_worker();
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
@@ -21,9 +18,37 @@ extern "C" int wu_main() {
|
||||
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
|
||||
wu_case_fail(0x01u);
|
||||
return 1;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
extern "C" void scalar_worker_body();
|
||||
|
||||
extern "C" void __attribute__((naked, used)) scalar_worker() {
|
||||
asm volatile(
|
||||
".option push\n\t"
|
||||
".option norelax\n\t"
|
||||
"la gp, __global_pointer\n\t"
|
||||
".option pop\n\t"
|
||||
"li sp, %[stack_base]\n\t"
|
||||
"csrr t0, %[csr_hart]\n\t"
|
||||
"slli t1, t0, %[stack_log2]\n\t"
|
||||
"slli t2, t0, 4\n\t"
|
||||
"add t1, t1, t2\n\t"
|
||||
"sub sp, sp, t1\n\t"
|
||||
"j scalar_worker_body\n\t"
|
||||
:
|
||||
: [csr_hart] "i"(VX_CSR_MHARTID),
|
||||
[stack_base] "i"(STACK_BASE_ADDR),
|
||||
[stack_log2] "i"(STACK_LOG2_SIZE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" void scalar_worker_body() {
|
||||
wu_short_delay(wu_wid());
|
||||
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#define WU_CASE_WAIT_SPIN 1024u
|
||||
#include "common_wu_min.h"
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
|
||||
@@ -6,13 +6,15 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_case_mem\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x8, %[tensor_lsu_base]\n\t"
|
||||
"or x8, x8, x5\n\t"
|
||||
"sw x8, 0(x7)\n\t"
|
||||
"lw x8, 0(x7)\n\t"
|
||||
"li x6, %[tensor_lsu_base]\n\t"
|
||||
"or x5, x6, x5\n\t"
|
||||
"sw x5, 0(x7)\n\t"
|
||||
"lw x5, 0(x7)\n\t"
|
||||
"sub x6, x5, x6\n\t"
|
||||
"slli x6, x6, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"sw x8, 0(x7)\n\t"
|
||||
"sw x5, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
:
|
||||
|
||||
@@ -5,8 +5,12 @@
|
||||
#include <vx_intrinsics.h>
|
||||
|
||||
#define WU_CASE_MAX_WARPS 8u
|
||||
#ifndef WU_CASE_WAIT_SPIN
|
||||
#define WU_CASE_WAIT_SPIN 1024u
|
||||
#endif
|
||||
#ifndef WU_CASE_SHORT_SPIN
|
||||
#define WU_CASE_SHORT_SPIN 8u
|
||||
#endif
|
||||
|
||||
#define WU_CASE_PASS 0x600du
|
||||
#define WU_CASE_FAIL_BASE 0xe000u
|
||||
@@ -15,6 +19,10 @@
|
||||
#define WU_CASE_TENSOR_CSR_BASE 0x7300u
|
||||
#define WU_CASE_TENSOR_LSU_BASE 0x7400u
|
||||
|
||||
#ifndef WU_START_BRANCH_TO_MAIN
|
||||
#define WU_START_BRANCH_TO_MAIN 0
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_status[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_seen[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
|
||||
@@ -35,8 +43,12 @@ extern "C" void __attribute__((naked, section(".init"), used)) _start() {
|
||||
"csrr t0, %[csr_core]\n\t"
|
||||
"bnez t0, 2f\n\t"
|
||||
"li sp, %[stack_base]\n\t"
|
||||
#if WU_START_BRANCH_TO_MAIN
|
||||
"beq zero, zero, wu_main\n\t"
|
||||
#else
|
||||
"call wu_main\n\t"
|
||||
"mv gp, a0\n\t"
|
||||
#endif
|
||||
"2:\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
@@ -113,7 +125,7 @@ static inline void wu_mark_seen(uint32_t base) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline void wu_stop_warp() {
|
||||
static inline void __attribute__((noreturn)) wu_stop_warp() {
|
||||
vx_tmc_zero();
|
||||
while (1) {}
|
||||
}
|
||||
|
||||
@@ -84,15 +84,32 @@
|
||||
#endif
|
||||
|
||||
#ifndef NUM_CORES
|
||||
#define NUM_CORES 8
|
||||
#define NUM_CORES 1
|
||||
#endif
|
||||
|
||||
#ifndef NUM_WARPS
|
||||
#define NUM_WARPS 8
|
||||
#define NUM_WARPS 4
|
||||
#endif
|
||||
|
||||
#ifndef NUM_TENSOR_WARPS
|
||||
#define NUM_TENSOR_WARPS 2
|
||||
#endif
|
||||
|
||||
#define NUM_SCALAR_WARPS (NUM_WARPS - NUM_TENSOR_WARPS)
|
||||
|
||||
#define IS_SCALAR_WARP(wid) ((wid) < NUM_SCALAR_WARPS)
|
||||
#define IS_TENSOR_WARP(wid) ((wid) >= NUM_SCALAR_WARPS)
|
||||
|
||||
#ifndef TENSOR_NUM_GPRS
|
||||
#define TENSOR_NUM_GPRS 8
|
||||
#endif
|
||||
|
||||
#ifndef TENSOR_NUM_FPRS
|
||||
#define TENSOR_NUM_FPRS 8
|
||||
#endif
|
||||
|
||||
#ifndef NUM_THREADS
|
||||
#define NUM_THREADS 8
|
||||
#define NUM_THREADS 4
|
||||
#endif
|
||||
|
||||
#ifndef NUM_BARRIERS
|
||||
@@ -682,4 +699,3 @@
|
||||
#define IMPLEMENTATION_ID 0
|
||||
|
||||
#endif // VX_CONFIG_VH
|
||||
|
||||
|
||||
@@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
|
||||
}
|
||||
|
||||
// Spawn an explicit warp mask. The current warp bit is ignored by hardware.
|
||||
inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr));
|
||||
}
|
||||
|
||||
inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
|
||||
vx_wspawn_mask(warp_mask & ((1u << NUM_SCALAR_WARPS) - 1u), func_ptr);
|
||||
}
|
||||
|
||||
inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
|
||||
vx_wspawn_mask(warp_mask & (((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS), func_ptr);
|
||||
}
|
||||
|
||||
// Split on a predicate
|
||||
inline unsigned vx_split(unsigned predicate) {
|
||||
unsigned ret;
|
||||
@@ -151,7 +164,34 @@ inline void vx_join(unsigned stack_ptr) {
|
||||
// Warp Barrier
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
|
||||
unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps;
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps));
|
||||
}
|
||||
|
||||
#define VX_BARRIER_DOMAIN_SHIFT 28
|
||||
#define VX_BARRIER_DOMAIN_ALL 0u
|
||||
#define VX_BARRIER_DOMAIN_SCALAR 1u
|
||||
#define VX_BARRIER_DOMAIN_TENSOR 2u
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) {
|
||||
unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT);
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps));
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) {
|
||||
asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask));
|
||||
}
|
||||
|
||||
// Return current thread identifier
|
||||
@@ -203,6 +243,22 @@ inline int vx_num_warps() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline int vx_num_scalar_warps() {
|
||||
return NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
inline int vx_num_tensor_warps() {
|
||||
return NUM_TENSOR_WARPS;
|
||||
}
|
||||
|
||||
inline unsigned vx_scalar_warp_mask() {
|
||||
return (1u << NUM_SCALAR_WARPS) - 1u;
|
||||
}
|
||||
|
||||
inline unsigned vx_tensor_warp_mask() {
|
||||
return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
// Return the number of cores per cluster
|
||||
inline int vx_num_cores() {
|
||||
int ret;
|
||||
|
||||
@@ -76,7 +76,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
@@ -96,7 +96,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
@@ -187,7 +187,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
|
||||
// device specs
|
||||
const int NC = vx_num_cores();
|
||||
const int NW = vx_num_warps();
|
||||
const int NW = NUM_SCALAR_WARPS;
|
||||
const int NT = vx_num_threads();
|
||||
// NOTE: assumes divisible
|
||||
const int num_cluster = NC / CORES_PER_CLUSTER;
|
||||
@@ -243,7 +243,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
||||
const int num_full_waves = num_warps_this_core / NW;
|
||||
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
|
||||
|
||||
const const int offset = cluster_id * num_tasks_this_cluster;
|
||||
const int offset = cluster_id * num_tasks_this_cluster;
|
||||
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
|
||||
rem_full_warps_in_last_wave};
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
@@ -289,7 +289,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
@@ -361,7 +361,7 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
@@ -515,7 +515,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
|
||||
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
|
||||
@@ -22,9 +22,9 @@
|
||||
_start:
|
||||
|
||||
# initialize per-thread registers
|
||||
csrr t0, VX_CSR_NUM_WARPS # get num warps
|
||||
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
|
||||
la t1, init_regs_all
|
||||
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
|
||||
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
|
||||
li t0, -1
|
||||
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
|
||||
jal init_regs
|
||||
@@ -35,9 +35,9 @@ _start:
|
||||
jal vx_wspawn_wait
|
||||
|
||||
# initialize TLS for all warps
|
||||
csrr t0, VX_CSR_NUM_WARPS # get num warps
|
||||
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
|
||||
la t1, init_tls_all
|
||||
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
|
||||
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
|
||||
li t0, -1
|
||||
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
|
||||
call __init_tls
|
||||
@@ -150,4 +150,3 @@ vx_wspawn_wait:
|
||||
.weak __dso_handle
|
||||
__dso_handle:
|
||||
.long 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user