Implement WU architecture support

This commit is contained in:
2026-05-25 19:25:05 +08:00
parent 323ed7d7e9
commit 0ad87bde81
35 changed files with 3303 additions and 472 deletions

View File

@@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
}
// Spawn an explicit warp mask. The current warp bit is ignored by hardware.
inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr));
}
inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
vx_wspawn_mask(warp_mask & ((1u << NUM_SCALAR_WARPS) - 1u), func_ptr);
}
inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
vx_wspawn_mask(warp_mask & (((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS), func_ptr);
}
// Split on a predicate
inline unsigned vx_split(unsigned predicate) {
unsigned ret;
@@ -149,8 +162,36 @@ inline void vx_join(unsigned stack_ptr) {
}
// Warp Barrier
__attribute__((convergent))
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps;
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps));
}
#define VX_BARRIER_DOMAIN_SHIFT 28
#define VX_BARRIER_DOMAIN_ALL 0u
#define VX_BARRIER_DOMAIN_SCALAR 1u
#define VX_BARRIER_DOMAIN_TENSOR 2u
__attribute__((convergent))
inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) {
unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT);
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps));
}
__attribute__((convergent))
inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) {
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR);
}
__attribute__((convergent))
inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) {
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR);
}
__attribute__((convergent))
inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) {
asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask));
}
// Return current thread identifier
@@ -202,6 +243,22 @@ inline int vx_num_warps() {
return ret;
}
inline int vx_num_scalar_warps() {
return NUM_SCALAR_WARPS;
}
inline int vx_num_tensor_warps() {
return NUM_TENSOR_WARPS;
}
inline unsigned vx_scalar_warp_mask() {
return (1u << NUM_SCALAR_WARPS) - 1u;
}
inline unsigned vx_tensor_warp_mask() {
return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
}
// Return the number of cores per cluster
inline int vx_num_cores() {
int ret;