2 Commits

Author SHA1 Message Date
122a048ea6 Add WU architecture kernel cases 2026-05-27 09:08:30 +08:00
Zhongdi LUO
9f4be1b8f7 Update Wu architecture kernel implementations and runtime library 2026-05-26 12:59:35 +00:00
21 changed files with 453 additions and 31 deletions

26
kernels/wu_arch/Makefile Normal file
View File

@@ -0,0 +1,26 @@
PROJECT = wu_arch
VX_SRCS = kernel.cpp
OPTS ?= -n1
WU_VARIANT_DUMPS = \
kernel.radiance.barriers.dump
all: kernel.radiance.dump $(WU_VARIANT_DUMPS)
include ../common.mk
kernel.radiance.barriers.dump: kernel.radiance.barriers.elf
$(VX_DP) -D $< > $@
kernel.radiance.barriers.elf: $(VX_SRCS) $(VX_INCLUDES) $(BINFILES)
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -DWU_RUN_DOMAIN_BARRIERS -o $@
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
$(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true
$(OBJCOPY) --update-section .args=args.bin $@ || true

1
kernels/wu_arch/args.bin Normal file
View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

173
kernels/wu_arch/kernel.cpp Normal file
View File

@@ -0,0 +1,173 @@
#include <stdint.h>
#include <vx_intrinsics.h>
#define DEV_SMEM_START_ADDR 0xff000000u
#define MAX_WARPS 8
#define MINIMAL_INIT_WORDS 4
#define WU_SCALAR_SPIN 32
#define WU_TENSOR_SPIN 32
#define WU_WAIT_SPIN 8192
#define WU_STATUS_DONE 0x600du
#define WU_STATUS_SCALAR_BASE 0x5100u
#define WU_STATUS_TENSOR_BASE 0x7100u
#define WU_BARRIER_SCALAR 0u
#define WU_BARRIER_MASKED 1u
#define WU_BARRIER_TENSOR 2u
extern "C" {
volatile uint32_t g_status[MAX_WARPS] __attribute__((aligned(32)));
volatile uint32_t g_scalar_seen[MAX_WARPS] __attribute__((aligned(32)));
volatile uint32_t g_tensor_seen[MAX_WARPS] __attribute__((aligned(32)));
volatile uint32_t g_spin_sink[MAX_WARPS] __attribute__((aligned(32)));
extern volatile uint64_t tohost;
}
extern "C" void vx_perf_dump() {}
static inline void wu_report_tohost(uint32_t exit_code) {
asm volatile("fence rw, rw" ::: "memory");
tohost = (static_cast<uint64_t>(exit_code) << 1) | 1u;
asm volatile("fence rw, rw" ::: "memory");
}
extern "C" int wu_main();
extern "C" void __attribute__((naked, section(".init"), used)) _start() {
asm volatile(
".option push\n\t"
".option norelax\n\t"
"la gp, __global_pointer\n\t"
".option pop\n\t"
"csrr t0, %[csr_core]\n\t"
"bnez t0, 2f\n\t"
"li sp, %[stack_base]\n\t"
"call wu_main\n\t"
"mv gp, a0\n\t"
"2:\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"1: j 1b\n\t"
:
: [csr_core] "i"(VX_CSR_CORE_ID),
[stack_base] "i"(STACK_BASE_ADDR),
[custom0] "i"(RISCV_CUSTOM0)
: "memory");
}
extern "C" void scalar_worker() {
const uint32_t wid = static_cast<uint32_t>(vx_warp_id());
const uint32_t tid = static_cast<uint32_t>(vx_thread_id());
volatile uint32_t mix = wid + 1u;
#ifdef WU_RUN_DOMAIN_BARRIERS
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
#endif
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
mix = (mix << 1) ^ (i + wid);
if (tid == 0 && wid < MAX_WARPS) {
g_spin_sink[wid] = mix;
g_scalar_seen[wid] = WU_STATUS_SCALAR_BASE | wid;
}
vx_tmc_zero();
while (1) {}
}
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
asm volatile(
"csrr x5, %[csr_wid]\n\t"
#ifdef WU_RUN_DOMAIN_BARRIERS
"li x1, (%[bar_tensor] | (%[domain_tensor] << %[domain_shift]))\n\t"
"li x2, %[num_tensor]\n\t"
".insn r %[custom0], 4, 0, x0, x1, x2\n\t"
#endif
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x7, %[tensor_spin]\n\t"
"1:\n\t"
"addi x7, x7, -1\n\t"
"bnez x7, 1b\n\t"
"slli x5, x5, 2\n\t"
"la x6, g_tensor_seen\n\t"
"add x6, x6, x5\n\t"
"li x7, %[tensor_base]\n\t"
"srli x5, x5, 2\n\t"
"or x7, x7, x5\n\t"
"sw x7, 0(x6)\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[bar_tensor] "i"(WU_BARRIER_TENSOR),
[domain_tensor] "i"(VX_BARRIER_DOMAIN_TENSOR),
[domain_shift] "i"(VX_BARRIER_DOMAIN_SHIFT),
[num_tensor] "i"(NUM_TENSOR_WARPS),
[tensor_spin] "i"(WU_TENSOR_SPIN),
[tensor_base] "i"(WU_STATUS_TENSOR_BASE)
: "memory");
}
static void init_state() {
g_status[0] = 0;
for (uint32_t i = 0; i < MAX_WARPS; ++i) {
g_scalar_seen[i] = 0;
g_tensor_seen[i] = 0;
}
volatile uint32_t *smem =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t i = 0; i < MINIMAL_INIT_WORDS; ++i)
smem[i] = 0x100u + i;
}
static int wait_for_wu_completion() {
for (uint32_t spin = 0; spin < WU_WAIT_SPIN; ++spin) {
uint32_t done = 1;
for (uint32_t wid = 0; wid < NUM_SCALAR_WARPS; ++wid)
done &= (g_scalar_seen[wid] == (WU_STATUS_SCALAR_BASE | wid));
for (uint32_t wid = NUM_SCALAR_WARPS; wid < NUM_WARPS; ++wid)
done &= (g_tensor_seen[wid] == (WU_STATUS_TENSOR_BASE | wid));
if (done)
return 0;
}
return 1;
}
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0)
return 0;
if (vx_thread_id() != 0)
return 0;
init_state();
const uint32_t other_scalar_warps = vx_scalar_warp_mask() & ~1u;
if (other_scalar_warps != 0)
vx_spawn_scalar(other_scalar_warps, scalar_worker);
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
#ifdef WU_RUN_DOMAIN_BARRIERS
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
#endif
volatile uint32_t mix = 1;
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
mix = (mix << 1) ^ (i + 3u);
g_spin_sink[0] = mix;
g_scalar_seen[0] = WU_STATUS_SCALAR_BASE;
if (wait_for_wu_completion() != 0) {
g_status[0] = 0xe001u;
wu_report_tohost(1);
return 1;
}
g_status[0] = WU_STATUS_DONE;
wu_report_tohost(0);
return 0;
}

View File

@@ -1,14 +1,11 @@
#define WU_START_BRANCH_TO_MAIN 1
#include "common_wu_min.h"
extern "C" void scalar_worker() {
wu_short_delay(wu_wid());
wu_mark_seen(WU_CASE_SCALAR_BASE);
wu_stop_warp();
}
extern "C" void scalar_worker();
extern "C" int wu_main() {
if (!wu_is_leader()) {
return 0;
wu_stop_warp();
}
wu_case_reset();
@@ -21,9 +18,37 @@ extern "C" int wu_main() {
wu_mark_seen(WU_CASE_SCALAR_BASE);
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
wu_case_fail(0x01u);
return 1;
wu_stop_warp();
}
wu_case_pass();
return 0;
wu_stop_warp();
}
extern "C" void scalar_worker_body();
extern "C" void __attribute__((naked, used)) scalar_worker() {
asm volatile(
".option push\n\t"
".option norelax\n\t"
"la gp, __global_pointer\n\t"
".option pop\n\t"
"li sp, %[stack_base]\n\t"
"csrr t0, %[csr_hart]\n\t"
"slli t1, t0, %[stack_log2]\n\t"
"slli t2, t0, 4\n\t"
"add t1, t1, t2\n\t"
"sub sp, sp, t1\n\t"
"j scalar_worker_body\n\t"
:
: [csr_hart] "i"(VX_CSR_MHARTID),
[stack_base] "i"(STACK_BASE_ADDR),
[stack_log2] "i"(STACK_LOG2_SIZE)
: "memory");
}
extern "C" void scalar_worker_body() {
wu_short_delay(wu_wid());
wu_mark_seen(WU_CASE_SCALAR_BASE);
wu_stop_warp();
}

View File

@@ -1,3 +1,4 @@
#define WU_CASE_WAIT_SPIN 1024u
#include "common_wu_min.h"
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {

View File

@@ -6,13 +6,15 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
"slli x6, x5, 2\n\t"
"la x7, g_case_mem\n\t"
"add x7, x7, x6\n\t"
"li x8, %[tensor_lsu_base]\n\t"
"or x8, x8, x5\n\t"
"sw x8, 0(x7)\n\t"
"lw x8, 0(x7)\n\t"
"li x6, %[tensor_lsu_base]\n\t"
"or x5, x6, x5\n\t"
"sw x5, 0(x7)\n\t"
"lw x5, 0(x7)\n\t"
"sub x6, x5, x6\n\t"
"slli x6, x6, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"sw x8, 0(x7)\n\t"
"sw x5, 0(x7)\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"1: j 1b\n\t"
:

View File

@@ -5,8 +5,12 @@
#include <vx_intrinsics.h>
#define WU_CASE_MAX_WARPS 8u
#ifndef WU_CASE_WAIT_SPIN
#define WU_CASE_WAIT_SPIN 1024u
#endif
#ifndef WU_CASE_SHORT_SPIN
#define WU_CASE_SHORT_SPIN 8u
#endif
#define WU_CASE_PASS 0x600du
#define WU_CASE_FAIL_BASE 0xe000u
@@ -15,6 +19,10 @@
#define WU_CASE_TENSOR_CSR_BASE 0x7300u
#define WU_CASE_TENSOR_LSU_BASE 0x7400u
#ifndef WU_START_BRANCH_TO_MAIN
#define WU_START_BRANCH_TO_MAIN 0
#endif
extern "C" {
volatile uint32_t g_status[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
volatile uint32_t g_seen[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
@@ -35,8 +43,12 @@ extern "C" void __attribute__((naked, section(".init"), used)) _start() {
"csrr t0, %[csr_core]\n\t"
"bnez t0, 2f\n\t"
"li sp, %[stack_base]\n\t"
#if WU_START_BRANCH_TO_MAIN
"beq zero, zero, wu_main\n\t"
#else
"call wu_main\n\t"
"mv gp, a0\n\t"
#endif
"2:\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"1: j 1b\n\t"
@@ -113,7 +125,7 @@ static inline void wu_mark_seen(uint32_t base) {
}
}
static inline void wu_stop_warp() {
static inline void __attribute__((noreturn)) wu_stop_warp() {
vx_tmc_zero();
while (1) {}
}

View File

@@ -0,0 +1,9 @@
PROJECT = wu_arch_hgemm
VX_SRCS = kernel.cpp
OPTS ?= -n1
include ../common.mk
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
cp $< $@

View File

@@ -0,0 +1,8 @@
# wu_arch_hgemm
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration.
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
instruction sequence and then stop themselves.

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1,87 @@
#include "../wu_arch_cases/common_wu_min.h"
#define DEV_SMEM_START_ADDR 0xff000000u
#define WU_CASE_TENSOR_HGEMM_BASE 0x7500u
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
extern "C" {
volatile uint32_t g_hgemm_a_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3c003c00u)};
volatile uint32_t g_hgemm_b_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x40004000u)};
volatile uint32_t g_hgemm_c_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3f800000u)};
}
#undef BW_REP2
#undef BW_REP4
#undef BW_REP8
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
asm volatile(
"csrr x5, %[csr_wid]\n\t"
"slli x1, x5, 11\n\t"
"addi x2, x1, 1024\n\t"
"la x6, g_hgemm_a_row\n\t"
"la x3, g_hgemm_c_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 32\n\t"
"li x4, 1024\n\t"
"blt x7, x4, 1b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[hgemm_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE)
: "memory");
}
extern "C" int wu_main() {
if (!wu_is_leader()) {
return 0;
}
wu_case_reset();
volatile uint32_t *smem_b =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t frag = 0; frag < 32u; ++frag) {
const uint32_t row = frag * 8u;
for (uint32_t i = 0; i < 8u; ++i) {
smem_b[row + i] = g_hgemm_b_row[i];
}
}
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
WU_CASE_TENSOR_HGEMM_BASE) != 0) {
wu_case_fail(0x09u);
return 1;
}
wu_case_pass();
return 0;
}

View File

@@ -84,15 +84,32 @@
#endif
#ifndef NUM_CORES
#define NUM_CORES 8
#define NUM_CORES 1
#endif
#ifndef NUM_WARPS
#define NUM_WARPS 8
#define NUM_WARPS 4
#endif
#ifndef NUM_TENSOR_WARPS
#define NUM_TENSOR_WARPS 2
#endif
#define NUM_SCALAR_WARPS (NUM_WARPS - NUM_TENSOR_WARPS)
#define IS_SCALAR_WARP(wid) ((wid) < NUM_SCALAR_WARPS)
#define IS_TENSOR_WARP(wid) ((wid) >= NUM_SCALAR_WARPS)
#ifndef TENSOR_NUM_GPRS
#define TENSOR_NUM_GPRS 8
#endif
#ifndef TENSOR_NUM_FPRS
#define TENSOR_NUM_FPRS 8
#endif
#ifndef NUM_THREADS
#define NUM_THREADS 8
#define NUM_THREADS 4
#endif
#ifndef NUM_BARRIERS
@@ -682,4 +699,3 @@
#define IMPLEMENTATION_ID 0
#endif // VX_CONFIG_VH

View File

@@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
}
// Spawn an explicit warp mask. The current warp bit is ignored by hardware.
inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr));
}
inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
vx_wspawn_mask(warp_mask & ((1u << NUM_SCALAR_WARPS) - 1u), func_ptr);
}
inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
vx_wspawn_mask(warp_mask & (((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS), func_ptr);
}
// Split on a predicate
inline unsigned vx_split(unsigned predicate) {
unsigned ret;
@@ -151,7 +164,34 @@ inline void vx_join(unsigned stack_ptr) {
// Warp Barrier
__attribute__((convergent))
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps;
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps));
}
#define VX_BARRIER_DOMAIN_SHIFT 28
#define VX_BARRIER_DOMAIN_ALL 0u
#define VX_BARRIER_DOMAIN_SCALAR 1u
#define VX_BARRIER_DOMAIN_TENSOR 2u
__attribute__((convergent))
inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) {
unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT);
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps));
}
__attribute__((convergent))
inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) {
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR);
}
__attribute__((convergent))
inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) {
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR);
}
__attribute__((convergent))
inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) {
asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask));
}
// Return current thread identifier
@@ -203,6 +243,22 @@ inline int vx_num_warps() {
return ret;
}
inline int vx_num_scalar_warps() {
return NUM_SCALAR_WARPS;
}
inline int vx_num_tensor_warps() {
return NUM_TENSOR_WARPS;
}
inline unsigned vx_scalar_warp_mask() {
return (1u << NUM_SCALAR_WARPS) - 1u;
}
inline unsigned vx_tensor_warp_mask() {
return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
}
// Return the number of cores per cluster
inline int vx_num_cores() {
int ret;

View File

@@ -76,7 +76,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
@@ -96,7 +96,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
@@ -187,7 +187,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
// device specs
const int NC = vx_num_cores();
const int NW = vx_num_warps();
const int NW = NUM_SCALAR_WARPS;
const int NT = vx_num_threads();
// NOTE: assumes divisible
const int num_cluster = NC / CORES_PER_CLUSTER;
@@ -243,7 +243,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
const int num_full_waves = num_warps_this_core / NW;
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
const const int offset = cluster_id * num_tasks_this_cluster;
const int offset = cluster_id * num_tasks_this_cluster;
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
rem_full_warps_in_last_wave};
g_wspawn_args[core_id] = &wspawn_args;
@@ -289,7 +289,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id
@@ -361,7 +361,7 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id
@@ -515,7 +515,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NW = NUM_SCALAR_WARPS;
int NT = vx_num_threads();
// current core id

View File

@@ -22,9 +22,9 @@
_start:
# initialize per-thread registers
csrr t0, VX_CSR_NUM_WARPS # get num warps
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
la t1, init_regs_all
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
li t0, -1
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
jal init_regs
@@ -35,9 +35,9 @@ _start:
jal vx_wspawn_wait
# initialize TLS for all warps
csrr t0, VX_CSR_NUM_WARPS # get num warps
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
la t1, init_tls_all
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
li t0, -1
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
call __init_tls
@@ -150,4 +150,3 @@ vx_wspawn_wait:
.weak __dso_handle
__dso_handle:
.long 0