Compare commits
7 Commits
e7229dae27
...
wu-blackwe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8a15e5251e | ||
|
|
3f7ce1f1c9 | ||
|
|
f1aa1303d2 | ||
|
|
d6fbd447c3 | ||
|
|
ed16541c8e | ||
| 122a048ea6 | |||
|
|
9f4be1b8f7 |
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
@@ -0,0 +1,10 @@
|
||||
PROJECT = blackwell_fp8_e4m3
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
VX_INCLUDES = fp8_common.hpp
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
|
||||
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||
cp $< $@
|
||||
21
kernels/blackwell_fp8_e4m3/README.md
Normal file
21
kernels/blackwell_fp8_e4m3/README.md
Normal file
@@ -0,0 +1,21 @@
|
||||
# blackwell_fp8_e4m3
|
||||
|
||||
Standalone FP8 E4M3 validation kernel for the Wu Blackwell BWGMMA branch.
|
||||
|
||||
This directory is the only kernel area used by the FP8 branch work. Existing
|
||||
FP16 HGEMM, `wu_arch_cases`, and flash kernels are intentionally left unchanged.
|
||||
|
||||
The validation runs one tensor warp on a 16x16x32 tile:
|
||||
|
||||
- A is FP8 E4M3 1.0 (`0x38`)
|
||||
- B is FP8 E4M3 2.0 (`0x40`)
|
||||
- C is FP32 1.0 (`0x3f800000`)
|
||||
- Expected output is FP32 65.0 (`0x42820000`)
|
||||
- `VirgoBlackwellConfig` currently uses 4 core/memory lanes, so one
|
||||
`tcgen05_cp/cb` fragment is 16 bytes.
|
||||
|
||||
Build:
|
||||
|
||||
```bash
|
||||
make -C kernels/blackwell_fp8_e4m3  # run from the repository root
|
||||
```
|
||||
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
@@ -0,0 +1,53 @@
|
||||
#ifndef BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||
#define BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
|
||||
#define WU_FP8_E4M3_ZERO 0x00u
|
||||
#define WU_FP8_E4M3_HALF 0x30u
|
||||
#define WU_FP8_E4M3_ONE 0x38u
|
||||
#define WU_FP8_E4M3_TWO 0x40u
|
||||
|
||||
#define WU_FP8_PACK4(a, b, c, d) \
|
||||
((((uint32_t)(a) & 0xffu) << 0) | (((uint32_t)(b) & 0xffu) << 8) | \
|
||||
(((uint32_t)(c) & 0xffu) << 16) | (((uint32_t)(d) & 0xffu) << 24))
|
||||
|
||||
#define WU_FP8_REP2(x) x, x
|
||||
#define WU_FP8_REP4(x) WU_FP8_REP2(x), WU_FP8_REP2(x)
|
||||
#define WU_FP8_REP8(x) WU_FP8_REP4(x), WU_FP8_REP4(x)
|
||||
|
||||
static inline void wu_tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_tcgen05_cp_wait() {
|
||||
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||
asm volatile(".insn r %0, 6, 0, x0, %1, %2"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_bwgmma_fp8(uint32_t addr_tmem_c, uint32_t addr_tmem_a,
|
||||
uint32_t addr_smem_b) {
|
||||
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||
"r"(addr_smem_b)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_bwgmma_wait() {
|
||||
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
#endif
|
||||
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
@@ -0,0 +1,121 @@
|
||||
#include "fp8_common.hpp"
|
||||
#include "../wu_arch_cases/common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define FP8_VALIDATION_BASE 0x7800u
|
||||
|
||||
#define FP8_M 16u
|
||||
#define FP8_N 16u
|
||||
#define FP8_K 32u
|
||||
#define FP8_TILE_BYTES 1024u
|
||||
#define FP8_FRAGMENT_BYTES 16u
|
||||
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
|
||||
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
|
||||
#define FP8_OUT_WORDS (FP8_M * FP8_N)
|
||||
#define FP8_EXPECTED 0x42820000u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
|
||||
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
|
||||
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
|
||||
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
|
||||
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
|
||||
WU_FP8_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef WU_FP8_REP2
|
||||
#undef WU_FP8_REP4
|
||||
#undef WU_FP8_REP8
|
||||
|
||||
// Tensor-warp entry point for the FP8 validation case; wu_main passes it to
// vx_spawn_tensor. The function is naked and never returns: its body is a
// single asm block that ends in a self-branch, and it uses x1/x2 (ra/sp in the
// standard ABI) as ordinary scratch registers.
//
// Sequence, as written below:
//   1. x1 = 0 and x2 = FP8_TILE_BYTES are two TMEM offsets. Loop 1 steps x7
//      from 0 to FP8_TILE_BYTES in FP8_FRAGMENT_BYTES increments and issues
//      custom-3 funct3=2 twice per step: (x1 + x7) sourced from g_fp8_a_frag
//      and (x2 + x7) sourced from g_fp8_c_frag. The source address is the same
//      on every iteration, so one fragment is replicated across each tile.
//   2. custom-3 funct3=3, then funct3=0 with rd=x2, rs1=x1,
//      rs2=DEV_SMEM_START_ADDR, then funct3=1.
//   3. Loop 2 issues custom-3 funct3=6 per fragment with rs1 = x2 + x7 and
//      rs2 = g_fp8_out + x7, followed by custom-3 funct3=3.
//   4. Stores (FP8_VALIDATION_BASE | warp_id) into g_seen[warp_id], issues
//      custom-0 funct3=0 with all-x0 operands, then spins.
//
// NOTE(review): the operation names (cp, cp_wait, bwgmma, bwgmma_wait, cb) are
// taken from the wrappers in fp8_common.hpp; the final custom-0 is presumably
// a "TMC with zero mask" that parks the warp -- confirm against the ISA spec.
// NOTE(review): the funct3=0 encoding here is identical to the one in the FP16
// hgemm_validation worker; nothing in this block selects FP8, so the element
// type is presumably fixed by the hardware configuration -- confirm.
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
|
||||
asm volatile(
|
||||
"li x1, 0\n\t"  // x1: first TMEM offset (rs1 of the funct3=0 op below)
|
||||
"li x2, %[tile_bytes]\n\t"  // x2: second TMEM offset, one tile after x1
|
||||
"la x6, g_fp8_a_frag\n\t"
|
||||
"la x3, g_fp8_c_frag\n\t"
|
||||
"li x7, 0\n\t"  // x7: byte offset within the tile
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"  // funct3=2: rs1 = x1 + x7, rs2 = g_fp8_a_frag
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"  // funct3=2: rs1 = x2 + x7, rs2 = g_fp8_c_frag
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"  // repeat until x7 reaches FP8_TILE_BYTES
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"  // funct3=3 (wait form, no operands)
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"  // funct3=0: rd = x2, rs1 = x1, rs2 = shared-memory base
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"  // funct3=1 (wait form, no operands)
|
||||
"la x3, g_fp8_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x1, x3, x7\n\t"  // x1 is reused here as the g_fp8_out cursor
|
||||
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"  // funct3=6: rs1 = x2 + x7, rs2 = g_fp8_out + x7
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"  // x5 = warp id
|
||||
"slli x6, x5, 2\n\t"  // byte offset of this warp's 32-bit g_seen slot
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"  // completion marker = FP8_VALIDATION_BASE | warp id
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"  // custom-0 funct3=0, all operands x0
|
||||
"3: j 3b\n\t"  // never return
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[done_base] "i"(FP8_VALIDATION_BASE),
|
||||
[tile_bytes] "i"(FP8_TILE_BYTES),
|
||||
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||
g_fp8_out[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
|
||||
const uint32_t row = frag * FP8_FRAGMENT_WORDS;
|
||||
for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
|
||||
smem_b[row + i] = g_fp8_b_frag[i];
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||
vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||
if (g_fp8_out[i] != FP8_EXPECTED) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_fp8_out[i];
|
||||
wu_case_fail(0x20u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
7
kernels/blackwell_multi_tc/Makefile
Normal file
7
kernels/blackwell_multi_tc/Makefile
Normal file
@@ -0,0 +1,7 @@
|
||||
PROJECT = blackwell_multi_tc
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
@@ -0,0 +1,188 @@
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
#include <vx_spawn.h>
|
||||
|
||||
#define RISCV_CUSTOM3 0x7B
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define MAX_CORES 4
|
||||
#define MAX_WARPS 8
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_a_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3c003c00u)}; // two fp16 1.0 values per word
|
||||
volatile uint32_t g_b_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x40004000u)}; // two fp16 2.0 values per word
|
||||
volatile uint32_t g_c_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3f800000u)}; // one fp32 1.0 value per word
|
||||
volatile uint32_t g_result[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_status[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
#undef BW_REP8
|
||||
|
||||
static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void tcgen05_cp_wait() {
|
||||
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void bwgmma(uint32_t addr_tmem_c,
|
||||
uint32_t addr_tmem_a,
|
||||
uint32_t addr_smem_b) {
|
||||
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||
"r"(addr_smem_b)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void bwgmma_wait() {
|
||||
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline float tcgen05_ld_f32(uint32_t addr_tmem) {
|
||||
float value;
|
||||
asm volatile(".insn r %1, 4, 0, %0, %2, x0"
|
||||
: "=f"(value)
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
// Returns the raw 32-bit object representation of `value` (for an IEEE-754
// binary32 float, 1.0f yields 0x3f800000).
//
// The bytes are copied rather than read through a union: in C++ reading a
// union member other than the one last written is undefined behaviour, whereas
// a byte copy is well defined. __builtin_memcpy is used so that no extra
// header is needed; a fixed 4-byte copy is normally lowered to a register move.
static inline uint32_t f32_bits(float value) {
  uint32_t bits;
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
// Empty C-linkage definition of vx_perf_dump; it does nothing when called.
// NOTE(review): presumably overrides the runtime's performance-counter dump
// hook so this bare-metal test emits no perf output -- confirm.
extern "C" void vx_perf_dump() {}
|
||||
|
||||
// Tensor-warp entry point; main passes it to vx_spawn_tensor for every warp in
// vx_tensor_warp_mask(). Naked and non-returning: the body is one asm block
// that ends in a self-branch and uses x1/x2 (ra/sp in the standard ABI) as
// scratch registers.
//
// Sequence, as written below:
//   1. x1 = warp_id * 0x800 and x2 = x1 + 0x400 are this warp's two TMEM
//      offsets. Loop 1 steps x7 from 0 to 1024 in 32-byte increments and
//      issues custom-3 funct3=2 twice per step, sourcing from g_a_row and
//      g_c_row (the same 32-byte row every time).
//   2. custom-3 funct3=3, then funct3=0 with rd=x2, rs1=x1,
//      rs2=DEV_SMEM_START_ADDR, then funct3=1.
//   3. Stores 0x600d into g_status[warp_id], issues custom-0 funct3=0 with
//      all-x0 operands, then spins.
//
// NOTE(review): 0x600d is stored unconditionally -- this path never reads the
// result back, so it signals completion, not numerical correctness.
// NOTE(review): the slot is indexed by warp id alone, while main polls
// g_status[core_id * MAX_WARPS + i]; the two agree only because main returns
// early on every core except core 0.
// NOTE(review): operation names (cp, cp_wait, bwgmma, bwgmma_wait) come from
// the C++ wrappers earlier in this file; the final custom-0 is presumably a
// "TMC with zero mask" that parks the warp -- confirm against the ISA spec.
extern "C" void __attribute__((naked, noinline, used)) tensor_kernel_entry() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"  // x5 = warp id
|
||||
"slli x1, x5, 11\n\t" // tmem_a = wid * 0x800
|
||||
"addi x2, x1, 1024\n\t" // tmem_c = tmem_a + 0x400
|
||||
"la x6, g_a_row\n\t"
|
||||
"la x3, g_c_row\n\t"
|
||||
"li x7, 0\n\t"  // x7: byte offset within the 1024-byte tile
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"  // funct3=2: rs1 = tmem_a + x7, rs2 = g_a_row
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"  // funct3=2: rs1 = tmem_c + x7, rs2 = g_c_row
|
||||
"addi x7, x7, 32\n\t"  // 32 bytes = one 8-word row
|
||||
"li x4, 1024\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"  // funct3=3 (wait form, no operands)
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"  // funct3=0: rd = tmem_c, rs1 = tmem_a, rs2 = shared-memory base
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"  // funct3=1 (wait form, no operands)
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x5, x5, 2\n\t"  // byte offset of this warp's 32-bit g_status slot
|
||||
"la x6, g_status\n\t"
|
||||
"add x6, x6, x5\n\t"
|
||||
"li x7, 0x600d\n\t"  // same "done" value that main polls for
|
||||
"sw x7, 0(x6)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"  // custom-0 funct3=0, all operands x0
|
||||
"2: j 2b\n\t"  // never return
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Argument bundle read by kernel_body. Every member is a pointer to volatile
// 32-bit words.
// NOTE(review): nothing in this file constructs a kernel_arg_t; main spawns
// tensor_kernel_entry directly -- confirm whether this struct is still needed.
struct kernel_arg_t {
|
||||
volatile uint32_t *a_row;  // source row passed to tcgen05_cp for the A tile
|
||||
||||
volatile uint32_t *b_row;  // not read by kernel_body
|
||||
||||
volatile uint32_t *c_row;  // source row passed to tcgen05_cp for the C tile
|
||||
||||
volatile uint32_t *result;  // per-(core, warp) slot receiving the loaded fp32 bit pattern
|
||||
||||
volatile uint32_t *status;  // per-(core, warp) progress / pass-fail word
|
||||
||||
};
|
||||
|
||||
static void __attribute__((convergent)) kernel_body(int task_id,
|
||||
kernel_arg_t *__UNIFORM__ arg) {
|
||||
const int warp_id = task_id / vx_num_threads();
|
||||
const int lane_id = task_id % vx_num_threads();
|
||||
if (lane_id != 0)
|
||||
return;
|
||||
|
||||
const int num_warps = vx_num_warps();
|
||||
if (warp_id >= num_warps)
|
||||
return;
|
||||
|
||||
volatile uint32_t *a_row = arg->a_row;
|
||||
volatile uint32_t *c_row = arg->c_row;
|
||||
volatile uint32_t *result = arg->result;
|
||||
volatile uint32_t *status = arg->status;
|
||||
const int core_id = vx_core_id();
|
||||
const int status_idx = core_id * MAX_WARPS + warp_id;
|
||||
|
||||
const uint32_t smem_b = DEV_SMEM_START_ADDR;
|
||||
const uint32_t base = static_cast<uint32_t>(warp_id) * 0x800u;
|
||||
const uint32_t tmem_a = base + 0x000u;
|
||||
const uint32_t tmem_c = base + 0x400u;
|
||||
|
||||
status[status_idx] = 0x100u;
|
||||
|
||||
for (int frag = 0; frag < 32; ++frag) {
|
||||
const uint32_t offset = static_cast<uint32_t>(frag * 32);
|
||||
tcgen05_cp(tmem_a + offset, reinterpret_cast<uint32_t>(a_row));
|
||||
tcgen05_cp(tmem_c + offset, reinterpret_cast<uint32_t>(c_row));
|
||||
}
|
||||
tcgen05_cp_wait();
|
||||
|
||||
status[status_idx] = 0x200u;
|
||||
bwgmma(tmem_c, tmem_a, smem_b);
|
||||
bwgmma_wait();
|
||||
|
||||
result[status_idx] = f32_bits(tcgen05_ld_f32(tmem_c));
|
||||
status[status_idx] = (result[status_idx] == 0x42820000u) ? 0x600du : 0xe000u;
|
||||
}
|
||||
|
||||
int main() {
|
||||
const int core_id = vx_core_id();
|
||||
if (core_id != 0 || vx_warp_id() != 0 || vx_thread_id() != 0)
|
||||
return 0;
|
||||
|
||||
volatile uint32_t *smem_b_ptr =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (int frag = 0; frag < 32; ++frag) {
|
||||
const int row = frag * 8;
|
||||
for (int i = 0; i < 8; ++i)
|
||||
smem_b_ptr[row + i] = g_b_row[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_WARPS; ++i)
|
||||
g_status[core_id * MAX_WARPS + i] = 0;
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_kernel_entry);
|
||||
|
||||
for (int spin = 0; spin < 100000; ++spin) {
|
||||
int done = 1;
|
||||
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i)
|
||||
done &= (g_status[core_id * MAX_WARPS + i] == 0x600du);
|
||||
if (done)
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) {
|
||||
if (g_status[core_id * MAX_WARPS + i] != 0x600du)
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
# Get the directory where this common.mk file is located
|
||||
COMMON_MK_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
|
||||
|
||||
XLEN ?= 32
|
||||
|
||||
TOOLDIR ?= /opt
|
||||
@@ -7,7 +10,7 @@ RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv64-gnu-toolchain
|
||||
VX_CFLAGS += -march=rv64imafd -mabi=lp64d
|
||||
STARTUP_ADDR ?= 0x180000000
|
||||
else
|
||||
RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv-gnu-toolchain
|
||||
RISCV_TOOLCHAIN_PATH ?= $(realpath $(COMMON_MK_DIR)../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
|
||||
VX_CFLAGS += -march=rv32imaf -mabi=ilp32f
|
||||
STARTUP_ADDR ?= 0x80000000
|
||||
endif
|
||||
@@ -18,7 +21,7 @@ RISCV_SYSROOT ?= $(RISCV_TOOLCHAIN_PATH)/$(RISCV_PREFIX)
|
||||
VORTEX_KN_PATH ?= $(realpath ../../lib)
|
||||
GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini)
|
||||
|
||||
LLVM_VORTEX ?= $(TOOLDIR)/llvm-vortex
|
||||
LLVM_VORTEX ?= $(realpath $(COMMON_MK_DIR)../../toolchain/llvm-r8)
|
||||
|
||||
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
|
||||
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)
|
||||
|
||||
9
kernels/hgemm_validation/Makefile
Normal file
9
kernels/hgemm_validation/Makefile
Normal file
@@ -0,0 +1,9 @@
|
||||
PROJECT = hgemm_validation
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
|
||||
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||
cp $< $@
|
||||
14
kernels/hgemm_validation/README.md
Normal file
14
kernels/hgemm_validation/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# hgemm_validation
|
||||
|
||||
Small HGEMM correctness validation for the 4-lane Blackwell tensor-core path.
|
||||
|
||||
The test runs one tensor warp on a single 16x16x32 HGEMM tile:
|
||||
|
||||
- A is fp16 1.0
|
||||
- B is fp16 2.0
|
||||
- C starts as fp32 1.0
|
||||
- expected output is fp32 65.0 for all 16x16 C elements
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B tile, spawns only the first tensor
|
||||
warp, waits for completion, and checks all 256 output words copied back from
|
||||
TMEM. Success reports `WU_CASE_PASS` through `tohost`.
|
||||
119
kernels/hgemm_validation/kernel.cpp
Normal file
119
kernels/hgemm_validation/kernel.cpp
Normal file
@@ -0,0 +1,119 @@
|
||||
#include "../wu_arch_cases/common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define HGEMM_VALIDATION_BASE 0x7700u
|
||||
|
||||
#define HGEMM_M 16u
|
||||
#define HGEMM_N 16u
|
||||
#define HGEMM_K 32u
|
||||
#define HGEMM_TILE_BYTES 1024u
|
||||
#define HGEMM_FRAGMENT_BYTES 16u
|
||||
#define HGEMM_FRAGMENTS (HGEMM_TILE_BYTES / HGEMM_FRAGMENT_BYTES)
|
||||
#define HGEMM_OUT_WORDS (HGEMM_M * HGEMM_N)
|
||||
#define HGEMM_EXPECTED 0x42820000u
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_hgemm_a_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_hgemm_out[HGEMM_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
// Tensor-warp entry point for the HGEMM validation case; wu_main passes it to
// vx_spawn_tensor. Naked and non-returning: the body is one asm block ending
// in a self-branch, using x1/x2 (ra/sp in the standard ABI) as scratch.
//
// Sequence, as written below:
//   1. x1 = 0 and x2 = HGEMM_TILE_BYTES are two TMEM offsets. Loop 1 steps x7
//      from 0 to HGEMM_TILE_BYTES in HGEMM_FRAGMENT_BYTES increments and
//      issues custom-3 funct3=2 twice per step, sourcing from g_hgemm_a_frag
//      and g_hgemm_c_frag (the same fragment on every iteration).
//   2. custom-3 funct3=3, then funct3=0 with rd=x2, rs1=x1,
//      rs2=DEV_SMEM_START_ADDR, then funct3=1.
//   3. Loop 2 issues custom-3 funct3=6 per fragment with rs1 = x2 + x7 and
//      rs2 = g_hgemm_out + x7, followed by custom-3 funct3=3.
//   4. Stores (HGEMM_VALIDATION_BASE | warp_id) into g_seen[warp_id], issues
//      custom-0 funct3=0 with all-x0 operands, then spins.
//
// NOTE(review): presumably funct3 2/3/0/1/6 are tcgen05 copy-in, copy wait,
// BWGMMA, BWGMMA wait and copy-back, and the final custom-0 parks the warp
// ("TMC with zero mask") -- confirm against the ISA spec.
extern "C" void __attribute__((naked, noinline, used)) hgemm_validation_worker() {
|
||||
asm volatile(
|
||||
"li x1, 0\n\t"  // x1: first TMEM offset (rs1 of the funct3=0 op below)
|
||||
"li x2, %[tile_bytes]\n\t"  // x2: second TMEM offset, one tile after x1
|
||||
"la x6, g_hgemm_a_frag\n\t"
|
||||
"la x3, g_hgemm_c_frag\n\t"
|
||||
"li x7, 0\n\t"  // x7: byte offset within the tile
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"  // funct3=2: rs1 = x1 + x7, rs2 = g_hgemm_a_frag
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"  // funct3=2: rs1 = x2 + x7, rs2 = g_hgemm_c_frag
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"  // repeat until x7 reaches HGEMM_TILE_BYTES
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"  // funct3=3 (wait form, no operands)
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"  // funct3=0: rd = x2, rs1 = x1, rs2 = shared-memory base
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"  // funct3=1 (wait form, no operands)
|
||||
"la x3, g_hgemm_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x1, x3, x7\n\t"  // x1 is reused here as the g_hgemm_out cursor
|
||||
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"  // funct3=6: rs1 = x2 + x7, rs2 = g_hgemm_out + x7
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"  // x5 = warp id
|
||||
"slli x6, x5, 2\n\t"  // byte offset of this warp's 32-bit g_seen slot
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"  // completion marker = HGEMM_VALIDATION_BASE | warp id
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"  // custom-0 funct3=0, all operands x0
|
||||
"3: j 3b\n\t"  // never return
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[done_base] "i"(HGEMM_VALIDATION_BASE),
|
||||
[tile_bytes] "i"(HGEMM_TILE_BYTES),
|
||||
[frag_bytes] "i"(HGEMM_FRAGMENT_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||
g_hgemm_out[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < HGEMM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row = frag * 4u;
|
||||
for (uint32_t i = 0; i < 4u; ++i) {
|
||||
smem_b[row + i] = g_hgemm_b_frag[i];
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||
vx_spawn_tensor(1u << tensor_wid, hgemm_validation_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << tensor_wid, HGEMM_VALIDATION_BASE) != 0) {
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||
if (g_hgemm_out[i] != HGEMM_EXPECTED) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_hgemm_out[i];
|
||||
wu_case_fail(0x20u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
26
kernels/wu_arch/Makefile
Normal file
26
kernels/wu_arch/Makefile
Normal file
@@ -0,0 +1,26 @@
|
||||
PROJECT = wu_arch
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
|
||||
OPTS ?= -n1
|
||||
|
||||
WU_VARIANT_DUMPS = \
|
||||
kernel.radiance.barriers.dump
|
||||
|
||||
all: kernel.radiance.dump $(WU_VARIANT_DUMPS)
|
||||
|
||||
include ../common.mk
|
||||
|
||||
kernel.radiance.barriers.dump: kernel.radiance.barriers.elf
|
||||
$(VX_DP) -D $< > $@
|
||||
|
||||
kernel.radiance.barriers.elf: $(VX_SRCS) $(VX_INCLUDES) $(BINFILES)
|
||||
$(VX_CXX) $(VX_CFLAGS) $(VX_SRCS) $(VX_LDFLAGS) -DRADIANCE -DWU_RUN_DOMAIN_BARRIERS -o $@
|
||||
$(OBJCOPY) --set-section-flags .operand.a=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.b=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .operand.c=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --set-section-flags .args=$(OBJCOPY_FLAGS) $@
|
||||
$(OBJCOPY) --update-section .operand.a=input.a.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.b=input.b.bin $@ || true
|
||||
$(OBJCOPY) --update-section .operand.c=input.c.bin $@ || true
|
||||
$(OBJCOPY) --update-section .args=args.bin $@ || true
|
||||
1
kernels/wu_arch/args.bin
Normal file
1
kernels/wu_arch/args.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.a.bin
Normal file
1
kernels/wu_arch/input.a.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.b.bin
Normal file
1
kernels/wu_arch/input.b.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch/input.c.bin
Normal file
1
kernels/wu_arch/input.c.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
173
kernels/wu_arch/kernel.cpp
Normal file
173
kernels/wu_arch/kernel.cpp
Normal file
@@ -0,0 +1,173 @@
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define MAX_WARPS 8
|
||||
#define MINIMAL_INIT_WORDS 4
|
||||
#define WU_SCALAR_SPIN 32
|
||||
#define WU_TENSOR_SPIN 32
|
||||
#define WU_WAIT_SPIN 8192
|
||||
#define WU_STATUS_DONE 0x600du
|
||||
#define WU_STATUS_SCALAR_BASE 0x5100u
|
||||
#define WU_STATUS_TENSOR_BASE 0x7100u
|
||||
#define WU_BARRIER_SCALAR 0u
|
||||
#define WU_BARRIER_MASKED 1u
|
||||
#define WU_BARRIER_TENSOR 2u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_status[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_scalar_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_tensor_seen[MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_spin_sink[MAX_WARPS] __attribute__((aligned(32)));
|
||||
extern volatile uint64_t tohost;
|
||||
}
|
||||
|
||||
// Empty C-linkage definition of vx_perf_dump; it does nothing when called.
// NOTE(review): presumably overrides the runtime's performance-counter dump
// hook so this bare-metal test emits no perf output -- confirm.
extern "C" void vx_perf_dump() {}
|
||||
|
||||
// Publishes exit_code through the `tohost` word, encoded as
// (exit_code << 1) | 1, with a full `fence rw, rw` before and after the store.
// NOTE(review): the encoding looks like the HTIF/riscv-tests exit convention
// (low bit set = "done", upper bits = code) -- confirm against the harness.
static inline void wu_report_tohost(uint32_t exit_code) {
  const uint64_t encoded =
      (static_cast<uint64_t>(exit_code) << 1) | static_cast<uint64_t>(1u);

  asm volatile("fence rw, rw" ::: "memory");
  tohost = encoded;
  asm volatile("fence rw, rw" ::: "memory");
}
|
||||
|
||||
extern "C" int wu_main();
|
||||
|
||||
// Boot entry point, placed in the .init section. Naked and non-returning.
//
// Sequence, as written below:
//   1. Load gp with __global_pointer, with linker relaxation disabled around
//      the `la` so that the load is not itself relaxed into a gp-relative form.
//   2. Read the core id; any core other than 0 jumps straight to step 4.
//   3. On core 0: set sp to STACK_BASE_ADDR, call wu_main, and copy its return
//      value (a0) into gp.
//   4. Issue custom-0 funct3=0 with all-x0 operands, then spin.
//
// NOTE(review): sp is the same constant for every caller on core 0; wu_main
// returns early for everything but warp 0 / thread 0 -- confirm that is enough
// to avoid stack sharing problems.
// NOTE(review): copying the result into gp presumably follows the riscv-tests
// convention of reporting status in gp, and the custom-0 is presumably a "TMC
// with zero mask" that parks the warp -- confirm against the harness / ISA.
extern "C" void __attribute__((naked, section(".init"), used)) _start() {
|
||||
asm volatile(
|
||||
".option push\n\t"
|
||||
".option norelax\n\t"  // gp is not valid yet, so gp-relative relaxation must be off
|
||||
"la gp, __global_pointer\n\t"
|
||||
".option pop\n\t"
|
||||
"csrr t0, %[csr_core]\n\t"  // t0 = core id
|
||||
"bnez t0, 2f\n\t"  // only core 0 runs wu_main
|
||||
"li sp, %[stack_base]\n\t"
|
||||
"call wu_main\n\t"
|
||||
"mv gp, a0\n\t"  // keep wu_main's return value in gp
|
||||
"2:\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"  // custom-0 funct3=0, all operands x0
|
||||
"1: j 1b\n\t"  // never return
|
||||
:
|
||||
: [csr_core] "i"(VX_CSR_CORE_ID),
|
||||
[stack_base] "i"(STACK_BASE_ADDR),
|
||||
[custom0] "i"(RISCV_CUSTOM0)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" void scalar_worker() {
|
||||
const uint32_t wid = static_cast<uint32_t>(vx_warp_id());
|
||||
const uint32_t tid = static_cast<uint32_t>(vx_thread_id());
|
||||
volatile uint32_t mix = wid + 1u;
|
||||
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||
#endif
|
||||
|
||||
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||
mix = (mix << 1) ^ (i + wid);
|
||||
|
||||
if (tid == 0 && wid < MAX_WARPS) {
|
||||
g_spin_sink[wid] = mix;
|
||||
g_scalar_seen[wid] = WU_STATUS_SCALAR_BASE | wid;
|
||||
}
|
||||
|
||||
vx_tmc_zero();
|
||||
while (1) {}
|
||||
}
|
||||
|
||||
// Entry point for the tensor warps spawned by wu_main. Naked and
// non-returning: the body is one asm block ending in a self-branch, using
// x1/x2 (ra/sp in the standard ABI) as scratch registers.
//
// Sequence, as written below:
//   1. Read the warp id into x5.
//   2. Only when WU_RUN_DOMAIN_BARRIERS is defined: issue custom-0 funct3=4
//      with rs1 = WU_BARRIER_TENSOR | (VX_BARRIER_DOMAIN_TENSOR <<
//      VX_BARRIER_DOMAIN_SHIFT) and rs2 = NUM_TENSOR_WARPS.
//   3. Issue custom-3 funct3=3 with all-x0 operands.
//   4. Count x7 down from WU_TENSOR_SPIN to zero.
//   5. Store (WU_STATUS_TENSOR_BASE | warp_id) into g_tensor_seen[warp_id],
//      issue custom-0 funct3=0 with all-x0 operands, then spin.
//
// NOTE(review): custom-0 funct3=4 is presumably the barrier instruction
// (id/domain in rs1, participant count in rs2), custom-3 funct3=3 presumably
// the tcgen05 copy wait, and the final custom-0 presumably a "TMC with zero
// mask" that parks the warp -- confirm against the ISA spec.
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"  // x5 = warp id
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
"li x1, (%[bar_tensor] | (%[domain_tensor] << %[domain_shift]))\n\t"  // barrier id with the tensor-domain bits merged in
|
||||
"li x2, %[num_tensor]\n\t"
|
||||
".insn r %[custom0], 4, 0, x0, x1, x2\n\t"  // custom-0 funct3=4: rs1 = x1, rs2 = NUM_TENSOR_WARPS
|
||||
#endif
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"  // custom-3 funct3=3, no operands
|
||||
"li x7, %[tensor_spin]\n\t"
|
||||
"1:\n\t"
|
||||
"addi x7, x7, -1\n\t"
|
||||
"bnez x7, 1b\n\t"  // busy-wait for WU_TENSOR_SPIN iterations
|
||||
"slli x5, x5, 2\n\t"  // warp id -> byte offset of the 32-bit slot
|
||||
"la x6, g_tensor_seen\n\t"
|
||||
"add x6, x6, x5\n\t"
|
||||
"li x7, %[tensor_base]\n\t"
|
||||
"srli x5, x5, 2\n\t"  // back to the plain warp id
|
||||
"or x7, x7, x5\n\t"  // marker = WU_STATUS_TENSOR_BASE | warp id
|
||||
"sw x7, 0(x6)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"  // custom-0 funct3=0, all operands x0
|
||||
"2: j 2b\n\t"  // never return
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[bar_tensor] "i"(WU_BARRIER_TENSOR),
|
||||
[domain_tensor] "i"(VX_BARRIER_DOMAIN_TENSOR),
|
||||
[domain_shift] "i"(VX_BARRIER_DOMAIN_SHIFT),
|
||||
[num_tensor] "i"(NUM_TENSOR_WARPS),
|
||||
[tensor_spin] "i"(WU_TENSOR_SPIN),
|
||||
[tensor_base] "i"(WU_STATUS_TENSOR_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static void init_state() {
|
||||
g_status[0] = 0;
|
||||
for (uint32_t i = 0; i < MAX_WARPS; ++i) {
|
||||
g_scalar_seen[i] = 0;
|
||||
g_tensor_seen[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < MINIMAL_INIT_WORDS; ++i)
|
||||
smem[i] = 0x100u + i;
|
||||
}
|
||||
|
||||
static int wait_for_wu_completion() {
|
||||
for (uint32_t spin = 0; spin < WU_WAIT_SPIN; ++spin) {
|
||||
uint32_t done = 1;
|
||||
for (uint32_t wid = 0; wid < NUM_SCALAR_WARPS; ++wid)
|
||||
done &= (g_scalar_seen[wid] == (WU_STATUS_SCALAR_BASE | wid));
|
||||
for (uint32_t wid = NUM_SCALAR_WARPS; wid < NUM_WARPS; ++wid)
|
||||
done &= (g_tensor_seen[wid] == (WU_STATUS_TENSOR_BASE | wid));
|
||||
if (done)
|
||||
return 0;
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0)
|
||||
return 0;
|
||||
if (vx_thread_id() != 0)
|
||||
return 0;
|
||||
|
||||
init_state();
|
||||
|
||||
const uint32_t other_scalar_warps = vx_scalar_warp_mask() & ~1u;
|
||||
if (other_scalar_warps != 0)
|
||||
vx_spawn_scalar(other_scalar_warps, scalar_worker);
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||
|
||||
#ifdef WU_RUN_DOMAIN_BARRIERS
|
||||
vx_barrier_scalar(WU_BARRIER_SCALAR, NUM_SCALAR_WARPS);
|
||||
vx_barrier_mask(WU_BARRIER_MASKED, vx_scalar_warp_mask());
|
||||
#endif
|
||||
|
||||
volatile uint32_t mix = 1;
|
||||
for (uint32_t i = 0; i < WU_SCALAR_SPIN; ++i)
|
||||
mix = (mix << 1) ^ (i + 3u);
|
||||
g_spin_sink[0] = mix;
|
||||
g_scalar_seen[0] = WU_STATUS_SCALAR_BASE;
|
||||
|
||||
if (wait_for_wu_completion() != 0) {
|
||||
g_status[0] = 0xe001u;
|
||||
wu_report_tohost(1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
g_status[0] = WU_STATUS_DONE;
|
||||
wu_report_tohost(0);
|
||||
return 0;
|
||||
}
|
||||
@@ -7,7 +7,25 @@ CASES := \
|
||||
case05_tensor_barrier \
|
||||
case06_masked_barrier \
|
||||
case07_tensor_csr_tmc \
|
||||
case08_tensor_lsu_optional
|
||||
case08_tensor_lsu_optional \
|
||||
case09_scalar_tmem_ldst \
|
||||
case10_tensor_scalar_tmem_handoff \
|
||||
case11_scalar_tmem_softmax_stage \
|
||||
case12_flash_pv_accum \
|
||||
case12_1_scalar_tmem_cb_probe \
|
||||
case12_2_flash_pv_p_probe \
|
||||
case12_3_scalar_tmem_lane_store \
|
||||
case13_flash_pv_two_warps \
|
||||
case14_flash_pv_k64 \
|
||||
case15_flash_softmax_pv_stage \
|
||||
case16_flash_full_pipeline \
|
||||
case17_flash_exp_softmax_probe \
|
||||
case18_scalar_fexp \
|
||||
case20_flash_bwd_fused \
|
||||
case21_moe_gating \
|
||||
case22_gemm_silu \
|
||||
case23_softmax_only \
|
||||
case24_flash_sw_pipeline
|
||||
|
||||
SMOKE_CASES := \
|
||||
case00_boot_scalar \
|
||||
|
||||
@@ -13,9 +13,33 @@ This directory contains small bare-metal kernels for incremental Wu architecture
|
||||
- `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps.
|
||||
- `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior.
|
||||
- `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower.
|
||||
- `case09_scalar_tmem_ldst`: scalar warp direct TMEM store/load path for the banked TMEM softmax mechanism.
|
||||
- `case10_tensor_scalar_tmem_handoff`: tensor BWGMMA result in TMEM C observed by scalar TMEM loads.
|
||||
- `case11_scalar_tmem_softmax_stage`: scalar TMEM transform written back for tensor-side copy-out.
|
||||
- `case12_flash_pv_accum`: one tensor warp consumes scalar-written `P` in TMEM A for `O = O + P @ V`.
|
||||
- `case12_1_scalar_tmem_cb_probe`: scalar-written TMEM A rows copied back by tensor `tcgen05_cb`.
|
||||
- `case12_2_flash_pv_p_probe`: case12 P-write diagnostic using the same scalar fill path.
|
||||
- `case12_3_scalar_tmem_lane_store`: scalar TMEM store lane-coalesced fragment write diagnostic.
|
||||
- `case13_flash_pv_two_warps`: both tensor warps consume scalar-written `P` tiles for two row blocks.
|
||||
- `case14_flash_pv_k64`: two consecutive `K=32` BWGMMA steps accumulate into one PV output.
|
||||
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
|
||||
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
|
||||
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
|
||||
- `case18_scalar_fexp`: scalar `FEXP.S` numerical probe.
|
||||
- `case20_flash_bwd_fused`: FlashAttention backward-style fused 5xMMA plus scalar softmax/dsoftmax handoff.
|
||||
- `case21_moe_gating`: scalar `softmax -> Top-K -> scatter` MoE gating pipeline.
|
||||
- `case22_gemm_silu`: tensor GEMM followed by scalar SiLU activation.
|
||||
- `case23_softmax_only`: scalar-only stable softmax probe.
|
||||
- `case24_flash_sw_pipeline`: four-iteration ping-pong FlashAttention-style software pipeline.
|
||||
|
||||
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.
|
||||
|
||||
## Debug Notes
|
||||
|
||||
- `TMC_DEBUG_NOTES.md`: reusable notes for diagnosing lane-mask/TMC operand
|
||||
bugs, including the `case12_3_scalar_tmem_lane_store` failure where a
|
||||
lane0-only register value was consumed under an all-lane mask.
|
||||
|
||||
## Build
|
||||
|
||||
Use the suite Makefile from this directory:
|
||||
|
||||
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# TMC Operand Debug Notes
|
||||
|
||||
This note records a failure pattern found while debugging
|
||||
`case12_3_scalar_tmem_lane_store`.
|
||||
|
||||
## Symptom
|
||||
|
||||
The simulator keeps producing trace output, but the kernel makes no forward
|
||||
progress:
|
||||
|
||||
- a tensor worker repeatedly polls a ready flag such as `g_case_mem[0]`;
|
||||
- the poll load always returns the old value;
|
||||
- the scalar warp has already committed the work before the ready flag store;
|
||||
- the expected ready flag store never appears in the LSU trace.
|
||||
|
||||
For `case12_3_scalar_tmem_lane_store`, the tensor worker was looping at
|
||||
`PC=0x80000034..0x80000048`, repeatedly reading `g_case_mem[0]` at
|
||||
`0x20000430` as `0`. The scalar warp committed the scalar TMEM store and the
|
||||
following `TMC`, but never issued the next ready flag store at `PC=0x8000015c`.
|
||||
|
||||
## Root Cause Pattern
|
||||
|
||||
Vortex scalar registers are lane-local. A value computed while only lane 0 is
|
||||
active is only known to be valid in lane 0. If that register is later consumed
|
||||
by an instruction while multiple lanes are active, the inactive lanes may hold
|
||||
stale or zero values.
|
||||
|
||||
This is especially easy to miss around `TMC`:
|
||||
|
||||
```asm
|
||||
# Bad when xN was defined while only lane 0 was active.
|
||||
vx_tmc all_lanes
|
||||
...
|
||||
vx_tmc xN
|
||||
```
|
||||
|
||||
The failed `case12_3` sequence used a C operand for `1u` that the compiler
|
||||
materialized before switching to all lanes. Runtime trace showed the source
|
||||
register for the final `TMC(1)` as:
|
||||
|
||||
```text
|
||||
rs1_data={0x0, 0x0, 0x0, 0x1}
|
||||
```
|
||||
|
||||
After that `TMC`, the scalar warp did not fetch the ready flag store.
|
||||
|
||||
## Correct Pattern
|
||||
|
||||
When a value is consumed under an all-lane mask, define that value under the
|
||||
same all-lane mask unless the instruction semantics explicitly use only lane 0.
|
||||
|
||||
For switching back to lane 0 from an all-lane region, keep the immediate
|
||||
materialization and the `TMC` adjacent inside the all-lane region:
|
||||
|
||||
```asm
|
||||
vx_tmc all_lanes
|
||||
...
|
||||
fence rw, rw
|
||||
li t2, 1
|
||||
vx_tmc t2
|
||||
```
|
||||
|
||||
The expected trace at the final `TMC` is:
|
||||
|
||||
```text
|
||||
rs1_data={0x1, 0x1, 0x1, 0x1}
|
||||
```
|
||||
|
||||
The library helper `vx_tmc_one()` follows this pattern because it emits
|
||||
`li a0, 1` and `vx_tmc a0` in the same volatile asm block. It is safe when
|
||||
called while all lanes are active.
|
||||
|
||||
## Fast Log Checks
|
||||
|
||||
Use the simulation log to distinguish a simulator stall from a kernel-level
|
||||
wait loop:
|
||||
|
||||
```sh
|
||||
stat -c '%s %y' chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
tail -n 80 chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
```
|
||||
|
||||
If the file is still growing but the tail repeatedly shows the same PC range,
|
||||
the simulator is alive and the kernel is probably spinning.
|
||||
|
||||
For a ready flag handoff, check both the polling load and the producer store:
|
||||
|
||||
```sh
|
||||
rg -n "0x20000430|0x9900|wid=0, PC=0x8000014|wid=0, PC=0x8000015" \
|
||||
chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
```
|
||||
|
||||
Interpretation:
|
||||
|
||||
- load repeatedly returns `0` and no later `0x9900` store exists: producer did
|
||||
not reach the ready flag store;
|
||||
- ready flag store exists but consumer still reads old data: investigate memory
|
||||
ordering or address aliasing;
|
||||
- producer stops immediately after a `TMC`: inspect that `TMC` source register
|
||||
across all lanes.
|
||||
|
||||
## Dump Checks
|
||||
|
||||
Check that the final lane-mask narrowing defines `1` immediately before `TMC`
|
||||
inside the all-lane region:
|
||||
|
||||
```sh
|
||||
rg -n "li\\s+.*1|vx_tmc" \
|
||||
kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.radiance.dump
|
||||
```
|
||||
|
||||
The fixed `case12_3` dump has:
|
||||
|
||||
```text
|
||||
li t2, 1
|
||||
vx_tmc t2
|
||||
auipc a3, 1
|
||||
sw a5, ...
|
||||
```
|
||||
|
||||
## Case12-Case15 Audit
|
||||
|
||||
The following cases were checked after fixing `case12_3`:
|
||||
|
||||
- `case12_flash_pv_accum`
|
||||
- `case12_2_flash_pv_p_probe`
|
||||
- `case13_flash_pv_two_warps`
|
||||
- `case14_flash_pv_k64`
|
||||
- `case15_flash_softmax_pv_stage`
|
||||
|
||||
They switch to all lanes with `vx_tmc(wu_bw_all_lanes_mask())` and switch back
|
||||
with `vx_tmc_one()`. Their dumps show the safe adjacent sequence:
|
||||
|
||||
```text
|
||||
li a0, 1
|
||||
vx_tmc a0
|
||||
```
|
||||
|
||||
No analogous source change is needed for those cases.
|
||||
|
||||
## Scalar TMEM Fill Lane Coverage
|
||||
|
||||
`wu_bw_fill_tmem_tile()` must run with all fragment lanes active. Scalar TMEM
|
||||
store writes each active lane's source word into the matching TMEM fragment word:
|
||||
|
||||
```text
|
||||
TMEM[addr].word[lane] = rs2_data[lane]
|
||||
```
|
||||
|
||||
Therefore a fill loop executed with `tmask=0001` only initializes word 0 of
|
||||
each 16-byte fragment. Word 1 and later can retain old TMEM data. In
|
||||
`case12_1_scalar_tmem_cb_probe`, this showed up as a normal completion path
|
||||
followed by verification failure `0x14`; `g_aux[0]` was `1`, and `g_aux[1]`
|
||||
held the stale copied-back word.
|
||||
|
||||
The correct pattern is:
|
||||
|
||||
```c++
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
```
|
||||
|
||||
The dump should show the fill loop bracketed by all-lane and lane-0 masks:
|
||||
|
||||
```text
|
||||
li a5, 15
|
||||
vx_tmc a5
|
||||
...
|
||||
vx_cmov zero, a0, a5, a2
|
||||
...
|
||||
li a0, 1
|
||||
vx_tmc a0
|
||||
```
|
||||
@@ -3,6 +3,8 @@ VX_SRCS = kernel.cpp
|
||||
VX_CFLAGS += -I..
|
||||
VORTEX_KN_PATH ?= $(realpath ../../../lib)
|
||||
GEMMINI_SW_PATH ?= $(realpath ../../../lib/gemmini)
|
||||
LLVM_VORTEX ?= $(realpath ../../../../toolchain/llvm-r8)
|
||||
RISCV_TOOLCHAIN_PATH ?= $(realpath ../../../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../../common.mk
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
#define WU_START_BRANCH_TO_MAIN 1
|
||||
#include "common_wu_min.h"
|
||||
|
||||
extern "C" void scalar_worker() {
|
||||
wu_short_delay(wu_wid());
|
||||
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||
wu_stop_warp();
|
||||
}
|
||||
extern "C" void scalar_worker();
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
@@ -21,9 +18,37 @@ extern "C" int wu_main() {
|
||||
wu_mark_seen(WU_CASE_SCALAR_BASE);
|
||||
if (wu_wait_seen_range(0, NUM_SCALAR_WARPS, WU_CASE_SCALAR_BASE) != 0) {
|
||||
wu_case_fail(0x01u);
|
||||
return 1;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
wu_stop_warp();
|
||||
}
|
||||
|
||||
extern "C" void scalar_worker_body();
|
||||
|
||||
extern "C" void __attribute__((naked, used)) scalar_worker() {
|
||||
asm volatile(
|
||||
".option push\n\t"
|
||||
".option norelax\n\t"
|
||||
"la gp, __global_pointer\n\t"
|
||||
".option pop\n\t"
|
||||
"li sp, %[stack_base]\n\t"
|
||||
"csrr t0, %[csr_hart]\n\t"
|
||||
"slli t1, t0, %[stack_log2]\n\t"
|
||||
"slli t2, t0, 4\n\t"
|
||||
"add t1, t1, t2\n\t"
|
||||
"sub sp, sp, t1\n\t"
|
||||
"j scalar_worker_body\n\t"
|
||||
:
|
||||
: [csr_hart] "i"(VX_CSR_MHARTID),
|
||||
[stack_base] "i"(STACK_BASE_ADDR),
|
||||
[stack_log2] "i"(STACK_LOG2_SIZE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// C++ body of a spawned scalar warp, entered from the scalar_worker stub:
// delay by an amount derived from the warp id, publish the "seen" marker for
// this warp, then stop the warp.
extern "C" void scalar_worker_body() {
  const auto warp = wu_wid();
  wu_short_delay(warp);
  wu_mark_seen(WU_CASE_SCALAR_BASE);
  wu_stop_warp();
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#define WU_CASE_WAIT_SPIN 1024u
|
||||
#include "common_wu_min.h"
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
|
||||
@@ -6,13 +6,15 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_case_mem\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x8, %[tensor_lsu_base]\n\t"
|
||||
"or x8, x8, x5\n\t"
|
||||
"sw x8, 0(x7)\n\t"
|
||||
"lw x8, 0(x7)\n\t"
|
||||
"li x6, %[tensor_lsu_base]\n\t"
|
||||
"or x5, x6, x5\n\t"
|
||||
"sw x5, 0(x7)\n\t"
|
||||
"lw x5, 0(x7)\n\t"
|
||||
"sub x6, x5, x6\n\t"
|
||||
"slli x6, x6, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"sw x8, 0(x7)\n\t"
|
||||
"sw x5, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
:
|
||||
|
||||
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case09_scalar_tmem_ldst
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case09_scalar_tmem_ldst
|
||||
|
||||
Smoke test for scalar warp direct TMEM access.
|
||||
|
||||
The leader scalar lane writes a lane-distinct word vector to a TMEM row with
|
||||
`CUSTOM1 func7=0x30 func3=1`, reads it back with `CUSTOM1 func7=0x30 func3=0`,
|
||||
and reports pass only if the lane 0 value matches.
|
||||
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
#include "common_wu_min.h"
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_scalar_tmem_st(uint32_t addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr), [value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
const uint32_t addr = 0x24u;
|
||||
const uint32_t expected = 0x5a090000u | wu_tid();
|
||||
wu_scalar_tmem_st(addr, expected);
|
||||
const uint32_t observed = wu_scalar_tmem_ld(addr);
|
||||
|
||||
if (observed != expected) {
|
||||
g_aux[0] = observed;
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case10_tensor_scalar_tmem_handoff
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case10_tensor_scalar_tmem_handoff
|
||||
|
||||
Validates the tensor-to-scalar TMEM handoff needed by FlashAttention.
|
||||
|
||||
The tensor warp initializes TMEM A/C, runs one BWGMMA, and leaves the result in
|
||||
TMEM C. The scalar leader waits for the tensor warp and then reads several C
|
||||
fragments through the scalar TMEM load instruction. Each lane is expected to
|
||||
observe the HGEMM value `0x42820000` (FP32 65.0).
|
||||
@@ -0,0 +1,111 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TMEM_HANDOFF_BASE 0x7700u
|
||||
#define WU_TMEM_TILE_BYTES 1024u
|
||||
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS >= 1, "case10 requires at least one tensor warp");
|
||||
static_assert(NUM_THREADS == 4, "case10 expects the 4-lane Blackwell tensor core");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case10_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_case10_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_case10_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case10_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_case10_a_row\n\t"
|
||||
"la x3, g_case10_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[handoff_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[handoff_base] "i"(WU_CASE_TMEM_HANDOFF_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_TMEM_TILE_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||
smem_b[i] = g_case10_b_row[i & 3u];
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case10_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE_TMEM_HANDOFF_BASE) != 0) {
|
||||
wu_case_fail(0x10u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base + i);
|
||||
if (observed != WU_TMEM_EXPECTED_HGEMM) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = observed;
|
||||
wu_case_fail(0x11u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case11_scalar_tmem_softmax_stage
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,10 @@
|
||||
# case11_scalar_tmem_softmax_stage
|
||||
|
||||
Validates the FlashAttention softmax-stage TMEM path.
|
||||
|
||||
The tensor warp writes HGEMM results into TMEM C. The scalar warp reads TMEM C,
|
||||
performs a deterministic lane-wise transform, and writes all four lanes back
|
||||
into one TMEM A-region fragment. The tensor warp then uses TCGEN05_CB to copy
|
||||
that TMEM fragment to global memory so scalar code can verify that tensor-side
|
||||
TMEM reads observe the scalar writeback with the correct lane ordering and
|
||||
write mask.
|
||||
@@ -0,0 +1,200 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TMEM_STAGE_HGEMM_BASE 0x7800u
|
||||
#define WU_CASE_TMEM_STAGE_VERIFY_BASE 0x7900u
|
||||
#define WU_TMEM_TILE_BYTES 1024u
|
||||
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||
#define WU_TMEM_SCALAR_WRITE_BYTE_ADDR 512u
|
||||
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||
#define WU_TMEM_STAGE_VALUE_BASE 0x42820001u
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS >= 1, "case11 requires at least one tensor warp");
|
||||
static_assert(NUM_THREADS == 4, "case11 expects the 4-lane Blackwell tensor core");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case11_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_case11_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_case11_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_case11_out[4] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case11_scalar_seen[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
|
||||
[value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case11_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_case11_a_row\n\t"
|
||||
"la x3, g_case11_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[hgemm_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[verify_base]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
"addi x4, x1, %[write_byte_addr]\n\t"
|
||||
"la x6, g_case11_out\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[verify_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[hgemm_base] "i"(WU_CASE_TMEM_STAGE_HGEMM_BASE),
|
||||
[verify_base] "i"(WU_CASE_TMEM_STAGE_VERIFY_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_TMEM_TILE_BYTES),
|
||||
[write_byte_addr] "i"(WU_TMEM_SCALAR_WRITE_BYTE_ADDR)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case11_out[i] = 0;
|
||||
g_case11_scalar_seen[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||
smem_b[i] = g_case11_b_row[i & 3u];
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case11_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||
WU_CASE_TMEM_STAGE_HGEMM_BASE) != 0) {
|
||||
g_case_mem[1] = 0x20u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base);
|
||||
g_case11_scalar_seen[tid] = observed;
|
||||
|
||||
const uint32_t write_frag =
|
||||
WU_TMEM_SCALAR_WRITE_BYTE_ADDR / WU_TMEM_FRAGMENT_BYTES;
|
||||
wu_scalar_tmem_st(write_frag, observed + 1u + tid);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case11_scalar_seen[0] != WU_TMEM_EXPECTED_HGEMM) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case11_scalar_seen[0];
|
||||
g_case_mem[1] = 0x21u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE_TMEM_STAGE_VERIFY_BASE;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||
WU_CASE_TMEM_STAGE_VERIFY_BASE) != 0) {
|
||||
g_case_mem[1] = 0x22u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case11_out[0] != WU_TMEM_STAGE_VALUE_BASE) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case11_out[0];
|
||||
wu_case_fail(0x23u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
# Project name matches this case's directory (case12_1_scalar_tmem_cb_probe),
# not the sibling case12_flash_pv_accum it was apparently copied from.
PROJECT = case12_1_scalar_tmem_cb_probe
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,10 @@
|
||||
# case12_1_scalar_tmem_cb_probe
|
||||
|
||||
Validates the scalar-to-tensor TMEM handoff without BWGMMA. Scalar warp 0 fills
|
||||
tensor block 0 TMEM A rows `0..63` with packed fp16 `1.0`. Tensor warp
|
||||
`NUM_SCALAR_WARPS` waits for that fill, then uses `tcgen05_cb` to copy the same
|
||||
TMEM A rows back to global memory.
|
||||
|
||||
The scalar leader verifies all 256 copied words are `0x3c003c00`. A mismatch
|
||||
means scalar TMEM stores are not visible to the tensor TMEM read/copy path at
|
||||
the expected row addresses.
|
||||
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE12_1_INIT_BASE 0x8d00u
|
||||
#define WU_CASE12_1_DONE_BASE 0x8e00u
|
||||
#define WU_CASE12_1_COPY_READY 0x8f00u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case12_1_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_1_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x3, g_case12_1_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"3: j 3b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[init_base] "i"(WU_CASE12_1_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_1_DONE_BASE),
|
||||
[copy_ready] "i"(WU_CASE12_1_COPY_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_1_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_1_worker);
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_INIT_BASE) !=
|
||||
0) {
|
||||
g_case_mem[1] = 0x12u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_1_COPY_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_DONE_BASE) !=
|
||||
0) {
|
||||
g_case_mem[1] = 0x13u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_1_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x14u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case12_2_flash_pv_p_probe
|
||||
|
||||
include ../case.mk
|
||||
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# case12_2_flash_pv_p_probe
|
||||
|
||||
Diagnostic for `case12_flash_pv_accum`.
|
||||
|
||||
Scalar warp 0 fills tensor block 0 TMEM A with packed fp16 `1.0`, using the
|
||||
same `wu_bw_fill_tmem_tile()` path as case12. Tensor warp `NUM_SCALAR_WARPS`
|
||||
then copies the full TMEM A tile back to global memory with `tcgen05_cb`.
|
||||
|
||||
The scalar leader verifies all copied words are `0x3c003c00`. A mismatch means
|
||||
case12's scalar-written P tile is not reaching the tensor TMEM read/copy path as
|
||||
expected.
|
||||
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE12_2_INIT_BASE 0x9600u
|
||||
#define WU_CASE12_2_DONE_BASE 0x9700u
|
||||
#define WU_CASE12_2_COPY_READY 0x9800u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case12_2_out[WU_BW_OUT_WORDS]
|
||||
__attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used))
|
||||
tensor_case12_2_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x3, g_case12_2_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"3: j 3b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[init_base] "i"(WU_CASE12_2_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_2_DONE_BASE),
|
||||
[copy_ready] "i"(WU_CASE12_2_COPY_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_2_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(tensor_mask, tensor_case12_2_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_2_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x51u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_2_COPY_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE12_2_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x52u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_2_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x53u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case12_3_scalar_tmem_lane_store
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,14 @@
|
||||
# case12_3_scalar_tmem_lane_store
|
||||
|
||||
Validates the lane-coalesced semantics of scalar TMEM stores.
|
||||
|
||||
One scalar warp writes a single 16-byte TMEM fragment. Each active scalar lane supplies one 32-bit word:
|
||||
|
||||
```text
|
||||
word0 = lane0.rs2
|
||||
word1 = lane1.rs2
|
||||
word2 = lane2.rs2
|
||||
word3 = lane3.rs2
|
||||
```
|
||||
|
||||
The tensor copy-back path then copies that fragment to memory. This catches implementations that only write lane0 data or broadcast one scalar value.
|
||||
@@ -0,0 +1,91 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Value wu_main stores into g_case_mem[0]; the tensor worker polls for it
// before starting the copy-back.
#define WU_CASE12_3_COPY_READY 0x9900u
|
||||
// Base of the completion token: the worker ORs its warp id into this value
// and stores the result into its g_seen slot.
#define WU_CASE12_3_DONE_BASE 0x9a00u
|
||||
|
||||
extern "C" {
|
||||
// Memory destination of the copy-back: four 32-bit words, 16-byte aligned.
// wu_main zeroes it and later compares each word against 0x5a120000 | lane.
volatile uint32_t g_case12_3_out[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used))
|
||||
tensor_case12_3_copy_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x6, g_case12_3_out\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x1, x6\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[copy_ready] "i"(WU_CASE12_3_COPY_READY),
|
||||
[done_base] "i"(WU_CASE12_3_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case12_3_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(tensor_mask, tensor_case12_3_copy_worker);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t frag_addr =
|
||||
wu_bw_tmem_a_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
asm volatile(
|
||||
".insn r %[custom0], 0, 0, x0, %[all_lanes], x0\n\t"
|
||||
"csrr t0, %[csr_tid]\n\t"
|
||||
"lui t1, 0x5a120\n\t"
|
||||
"or t1, t1, t0\n\t"
|
||||
".insn r %[custom1], 1, 0x30, x0, %[addr], t1\n\t"
|
||||
"fence rw, rw\n\t"
|
||||
"li t2, 1\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, t2, x0"
|
||||
:
|
||||
: [custom0] "i"(RISCV_CUSTOM0), [custom1] "i"(RISCV_CUSTOM1),
|
||||
[csr_tid] "i"(VX_CSR_THREAD_ID), [addr] "r"(frag_addr),
|
||||
[all_lanes] "r"(wu_bw_all_lanes_mask())
|
||||
: "t0", "t1", "t2", "memory");
|
||||
|
||||
g_case_mem[0] = WU_CASE12_3_COPY_READY;
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_3_DONE_BASE) != 0) {
|
||||
wu_case_fail(0x51u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (uint32_t lane = 0; lane < NUM_THREADS; ++lane) {
|
||||
const uint32_t expected = 0x5a120000u | lane;
|
||||
if (g_case12_3_out[lane] != expected) {
|
||||
g_aux[0] = lane;
|
||||
g_aux[1] = g_case12_3_out[lane];
|
||||
wu_case_fail(0x52u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
# Build fragment for the case12 FlashAttention PV-accumulate test.
# Only the project name is set here; the rest comes from the included
# ../case.mk (presumably it consumes PROJECT — confirm against case.mk).
PROJECT = case12_flash_pv_accum
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case12_flash_pv_accum
|
||||
|
||||
Validates the first Wu Blackwell FlashAttention PV substage.
|
||||
|
||||
Scalar warp 0 writes a full `16x32` fp16 `P` tile into TMEM A. One tensor warp
|
||||
uses BWGMMA with a `32x16` fp16 `V` tile in SMEM and a preloaded fp32 TMEM C
|
||||
tile. The expected result is `O = 3.0 + P @ V = 67.0` for every output element.
|
||||
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
@@ -0,0 +1,127 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Token base the worker publishes (ORed with its warp id) into g_seen once
// its initial TMEM fill loop has finished.
#define WU_CASE12_INIT_BASE 0x8a00u
|
||||
// Token base the worker publishes into g_seen after the output copy loop.
#define WU_CASE12_DONE_BASE 0x8b00u
|
||||
// Value wu_main writes to g_case_mem[0] to release the worker.
#define WU_CASE12_P_READY 0x8c00u
|
||||
|
||||
extern "C" {
|
||||
// One 16-byte row of fp32 3.0; the worker uses it as the source of every
// 16-byte step of its initial fill loop.
volatile uint32_t g_case12_o_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
|
||||
WU_BW_FP32_THREE};
|
||||
// Result buffer the worker copies into; wu_main zeroes it and verifies it.
volatile uint32_t g_case12_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case12_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case12_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE12_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE12_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_worker);
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x12u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x13u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x14u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
# Build fragment for the case13 two-tensor-warp PV test.
# Only the project name is set here; the rest comes from the included
# ../case.mk (presumably it consumes PROJECT — confirm against case.mk).
PROJECT = case13_flash_pv_two_warps
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case13_flash_pv_two_warps
|
||||
|
||||
Validates the Wu Blackwell FlashAttention PV substage across both tensor warps.
|
||||
|
||||
Scalar warp 0 writes one `P` tile into each tensor warp's TMEM A partition.
|
||||
Both tensor warps consume the same SMEM `V` tile, accumulate into their own TMEM
|
||||
C tiles, copy out contiguous row blocks, and stop. Every fp32 output is expected
|
||||
to be `67.0`.
|
||||
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Token base each worker publishes (ORed with its warp id) into g_seen once
// its initial TMEM fill loop has finished.
#define WU_CASE13_INIT_BASE 0x8d00u
|
||||
// Token base each worker publishes into g_seen after the output copy loop.
#define WU_CASE13_DONE_BASE 0x8e00u
|
||||
// Value wu_main writes to g_case_mem[0] to release the workers.
#define WU_CASE13_P_READY 0x8f00u
|
||||
// Total output size: one WU_BW_OUT_WORDS block per tensor warp.
#define WU_CASE13_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
|
||||
|
||||
extern "C" {
|
||||
// One 16-byte row of fp32 3.0 used as the source of the initial fill loop.
volatile uint32_t g_case13_o_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
|
||||
WU_BW_FP32_THREE};
|
||||
// Result buffer; each worker writes at byte offset (tensor index << 10).
volatile uint32_t g_case13_out[WU_CASE13_OUT_WORDS]
|
||||
__attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case13_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case13_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_case13_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE13_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE13_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE13_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE13_OUT_WORDS; ++i) {
|
||||
g_case13_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case13_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE13_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x21u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE13_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE13_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x22u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case13_out, WU_CASE13_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||
if (bad != WU_CASE13_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x23u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
# Build fragment for the case14 K=64 PV accumulation test.
# Only the project name is set here; the rest comes from the included
# ../case.mk (presumably it consumes PROJECT — confirm against case.mk).
PROJECT = case14_flash_pv_k64
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case14_flash_pv_k64
|
||||
|
||||
Validates PV accumulation over two Blackwell `K=32` BWGMMA steps.
|
||||
|
||||
Both tensor warps run two consecutive BWGMMA operations against the same TMEM C
|
||||
tile, modeling a `K=64` FlashAttention PV accumulation. With `P=1.0`,
|
||||
`V=2.0`, and `O_init=5.0`, every fp32 output is expected to be `133.0`.
|
||||
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Token base each worker publishes (ORed with its warp id) into g_seen once
// its initial TMEM fill loop has finished.
#define WU_CASE14_INIT_BASE 0x9000u
|
||||
// Token base each worker publishes into g_seen after the output copy loop.
#define WU_CASE14_DONE_BASE 0x9100u
|
||||
// Value wu_main writes to g_case_mem[0] to release the workers.
#define WU_CASE14_P_READY 0x9200u
|
||||
// Total output size: one WU_BW_OUT_WORDS block per tensor warp.
#define WU_CASE14_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
|
||||
|
||||
extern "C" {
|
||||
// One 16-byte row of fp32 5.0 used as the source of the initial fill loop.
volatile uint32_t g_case14_o_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE};
|
||||
// Result buffer; each worker writes at byte offset (tensor index << 10).
volatile uint32_t g_case14_out[WU_CASE14_OUT_WORDS]
|
||||
__attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case14_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case14_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_case14_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE14_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE14_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE14_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE14_OUT_WORDS; ++i) {
|
||||
g_case14_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case14_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE14_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x31u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE14_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE14_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x32u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case14_out, WU_CASE14_OUT_WORDS,
|
||||
WU_BW_FP32_ONE_THIRTY_THREE, &bad_actual);
|
||||
if (bad != WU_CASE14_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x33u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
# Build fragment for the case15 softmax-to-PV handoff test.
# Only the project name is set here; the rest comes from the included
# ../case.mk (presumably it consumes PROJECT — confirm against case.mk).
PROJECT = case15_flash_softmax_pv_stage
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case15_flash_softmax_pv_stage
|
||||
|
||||
Validates the scalar softmax-stage handoff into the Blackwell PV GEMM.
|
||||
|
||||
The tensor warp initializes TMEM C with a score/O value of `1.0`. Scalar warp 0
|
||||
reads that value through the scalar TMEM load path, writes a full fp16 `P=1.0`
|
||||
tile into TMEM A, and releases the tensor warp. The tensor warp then runs PV
|
||||
against `V=2.0`; every fp32 output is expected to be `65.0`.
|
||||
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
@@ -0,0 +1,145 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Token base the worker publishes (ORed with its warp id) into g_seen once
// its initial TMEM fill loop has finished.
#define WU_CASE15_INIT_BASE 0x9300u
|
||||
// Token base the worker publishes into g_seen after the output copy loop.
#define WU_CASE15_DONE_BASE 0x9400u
|
||||
// Value wu_main writes to g_case_mem[0] to release the worker.
#define WU_CASE15_P_READY 0x9500u
|
||||
|
||||
extern "C" {
|
||||
// One 16-byte row of fp32 1.0 used as the source of the initial fill loop.
volatile uint32_t g_case15_score_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE};
|
||||
// Result buffer the worker copies into; wu_main zeroes it and verifies it.
volatile uint32_t g_case15_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
// Per-lane value returned by the scalar TMEM load, indexed by thread id.
// NOTE(review): sized for 4 lanes — confirm NUM_THREADS never exceeds 4.
volatile uint32_t g_case15_scalar_seen[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case15_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case15_score_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case15_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE15_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE15_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE15_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case15_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case15_scalar_seen[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case15_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE15_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x41u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case15_scalar_seen[tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
g_case15_scalar_seen[0] != WU_BW_FP32_ONE) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case15_scalar_seen[0];
|
||||
g_case_mem[1] = 0x42u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE15_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE15_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x43u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case15_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_FIVE, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x44u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
# Build fragment for the case16 full FlashAttention-style pipeline test.
# Only the project name is set here; the rest comes from the included
# ../case.mk (presumably it consumes PROJECT — confirm against case.mk).
PROJECT = case16_flash_full_pipeline
|
||||
|
||||
include ../case.mk
|
||||
16
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
16
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# case16_flash_full_pipeline
|
||||
|
||||
Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
|
||||
Blackwell path.
|
||||
|
||||
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
|
||||
and `O_init=0.0`, producing a constant fp32 score of `32.0`. The scalar warp
|
||||
reads the score row through scalar TMEM loads, scans the row maximum and
|
||||
normalization denominator with scalar-only `FEXP.S`, converts each probability
|
||||
to packed fp16, writes `P` into TMEM A, refills SMEM with `V=2.0`, and releases
|
||||
the tensor warp. The tensor warp reloads TMEM C with `O=0.0`, then computes
|
||||
`O = P @ V`.
|
||||
|
||||
Every output word is expected to be fp32 `2.0`. This case covers the staged
|
||||
`QK -> scalar softmax handoff -> PV` loop without using the legacy
|
||||
`kernels/flash_attention` implementation.
|
||||
259
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
259
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
@@ -0,0 +1,259 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Token base the worker publishes (ORed with its warp id) into g_seen after
// its first GEMM step.
#define WU_CASE16_INIT_BASE 0x9600u
|
||||
// Value wu_main writes to g_case_mem[0] to release the worker for the
// second stage.
#define WU_CASE16_P_READY 0x9700u
|
||||
// Token base the worker publishes into g_seen after the output copy loop.
#define WU_CASE16_DONE_BASE 0x9800u
|
||||
|
||||
// IEEE-754 binary32 bit patterns for 0.0, 2.0 and 32.0.
#define WU_BW_FP32_ZERO 0x00000000u
|
||||
#define WU_BW_FP32_TWO 0x40000000u
|
||||
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
|
||||
// Number of score fragments the softmax helper scans per row.
#define WU_CASE16_ROW_N 32u
|
||||
|
||||
extern "C" {
|
||||
// One 16-byte row of packed fp16 1.0; source of the worker's first fill loop
// (the one that targets the per-warp TMEM base).
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||
WU_BW_FP16_ONE_PACKED};
|
||||
// One 16-byte row of fp32 0.0; source of both C-tile fill loops.
volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
|
||||
// Result buffer the worker copies into; wu_main zeroes it first.
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
// Per-lane value returned by the first scalar TMEM load, indexed by thread id.
// NOTE(review): sized for 4 lanes — confirm NUM_THREADS never exceeds 4.
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
|
||||
// Per-lane fp32 bits of the first probability computed by the softmax helper.
volatile uint32_t g_case16_p_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a 32-bit pattern as an IEEE-754 single-precision value.
// No numeric conversion takes place; the bits are returned unchanged.
static inline float wu_case16_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float value;
  } pun;
  pun.raw = bits;
  return pun.value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an IEEE-754 single-precision value.
// No numeric conversion takes place; the bits are returned unchanged.
static inline uint32_t wu_case16_f32_to_bits(float value) {
  union {
    float value;
    uint32_t raw;
  } pun;
  pun.value = value;
  return pun.raw;
}
|
||||
|
||||
static inline uint16_t wu_case16_f32_to_f16_positive(float value) {
|
||||
const uint32_t bits = wu_case16_f32_to_bits(value);
|
||||
const uint32_t exp = (bits >> 23) & 0xffu;
|
||||
uint32_t mant = bits & 0x7fffffu;
|
||||
|
||||
if (exp == 0 || value <= 0.0f) {
|
||||
return 0;
|
||||
}
|
||||
if (exp >= 143u) {
|
||||
return 0x7c00u;
|
||||
}
|
||||
if (exp <= 112u) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint32_t half_exp = exp - 112u;
|
||||
mant += 0x1000u;
|
||||
if (mant & 0x800000u) {
|
||||
mant = 0;
|
||||
++half_exp;
|
||||
}
|
||||
if (half_exp >= 31u) {
|
||||
return 0x7c00u;
|
||||
}
|
||||
return static_cast<uint16_t>((half_exp << 10) | (mant >> 13));
|
||||
}
|
||||
|
||||
static inline uint32_t wu_case16_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case16_f32_to_f16_positive(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
static inline void wu_case16_softmax_tmem_row_to_p(uint32_t score_frag_base,
|
||||
uint32_t p_byte_base) {
|
||||
float row_max = wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE16_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE16_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE16_ROW_N;
|
||||
const float score =
|
||||
wu_case16_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
if (frag == 0) {
|
||||
g_case16_p_bits[wu_tid()] = wu_case16_f32_to_bits(p);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case16_pack_f16x2(p));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case16_q_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
"la x3, g_case16_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
"la x3, g_case16_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"4:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 4b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case16_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"5:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 5b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"6: j 6b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE16_INIT_BASE),
|
||||
[p_ready] "i"(WU_CASE16_P_READY),
|
||||
[done_base] "i"(WU_CASE16_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case16_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case16_scalar_seen[i] = 0;
|
||||
g_case16_p_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case16_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE16_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x61u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case16_scalar_seen[tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
g_case16_scalar_seen[0] != WU_BW_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case16_scalar_seen[0];
|
||||
g_case_mem[1] = 0x62u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_case16_softmax_tmem_row_to_p(c_frag, wu_bw_tmem_a_byte_base(0));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0) {
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
g_case_mem[0] = WU_CASE16_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE16_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x63u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case16_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_TWO, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x64u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case17_flash_exp_softmax_probe
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,15 @@
|
||||
# case17_flash_exp_softmax_probe
|
||||
|
||||
Validates that the scalar Wu path can execute the `e^x` work needed by
|
||||
non-uniform FlashAttention softmax through the custom scalar-only `FEXP.S`
|
||||
instruction.
|
||||
|
||||
The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically
|
||||
stable softmax computes `exp(score - row_max)`, so the exp inputs are
|
||||
`{-ln(2), 0}` and the probabilities should be close to `{1/3, 2/3}`. The
|
||||
inputs are loaded from volatile memory so the compiler cannot fold the result
|
||||
into constants.
|
||||
|
||||
This is intentionally separate from the tensor PV path. If this case fails, the
|
||||
problem is in scalar fp32 `FEXP.S` execution or normalization rather than TMEM
|
||||
handoff or BWGMMA.
|
||||
@@ -0,0 +1,72 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
// Buffers for case17, given C linkage so the symbol names stay unmangled.
extern "C" {
|
||||
// fp32 bit patterns of the two softmax scores. The array is volatile, so
// wu_main has to load the scores from memory instead of folding them.
volatile uint32_t g_case17_scores_bits[2] __attribute__((aligned(16))) = {
|
||||
0x00000000u, 0x3f317218u}; // 0.0, ln(2)
|
||||
// fp32 bit patterns of the two resulting probabilities; cleared and then
// written by thread 0 in wu_main, and copied to g_aux on failure.
volatile uint32_t g_case17_out_bits[2] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value (no numeric conversion).
static inline float wu_case17_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t wu_case17_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value without libm: negates only values that compare below zero.
static inline float wu_case17_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
g_case17_out_bits[0] = 0;
|
||||
g_case17_out_bits[1] = 0;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]);
|
||||
const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]);
|
||||
const float row_max = score0 > score1 ? score0 : score1;
|
||||
const float e0 = wu_fexp_s(score0 - row_max);
|
||||
const float e1 = wu_fexp_s(score1 - row_max);
|
||||
const float inv_sum = 1.0f / (e0 + e1);
|
||||
const float p0 = e0 * inv_sum;
|
||||
const float p1 = e1 * inv_sum;
|
||||
|
||||
if (tid == 0) {
|
||||
g_case17_out_bits[0] = wu_case17_f32_to_bits(p0);
|
||||
g_case17_out_bits[1] = wu_case17_f32_to_bits(p1);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected0 = 0.3333333433f;
|
||||
const float expected1 = 0.6666666865f;
|
||||
const float tolerance = 0.0015f;
|
||||
const float err0 = wu_case17_absf(p0 - expected0);
|
||||
const float err1 = wu_case17_absf(p1 - expected1);
|
||||
if (err0 > tolerance || err1 > tolerance) {
|
||||
g_aux[0] = g_case17_out_bits[0];
|
||||
g_aux[1] = g_case17_out_bits[1];
|
||||
wu_case_fail(0x71u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case18_scalar_fexp/Makefile
Normal file
3
kernels/wu_arch_cases/case18_scalar_fexp/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case18_scalar_fexp
|
||||
|
||||
include ../case.mk
|
||||
5
kernels/wu_arch_cases/case18_scalar_fexp/README.md
Normal file
5
kernels/wu_arch_cases/case18_scalar_fexp/README.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# case18_scalar_fexp
|
||||
|
||||
Verifies scalar-warp execution of the custom `FEXP.S` instruction. The test
|
||||
checks representative fp32 inputs used by softmax-style code paths and confirms
|
||||
the result is close to `expf`.
|
||||
68
kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp
Normal file
68
kernels/wu_arch_cases/case18_scalar_fexp/kernel.cpp
Normal file
@@ -0,0 +1,68 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
// Buffer for case18, given C linkage so the symbol name stays unmangled.
extern "C" {
|
||||
// fp32 bit patterns of the four wu_fexp_s results; cleared and then written
// by thread 0 in wu_main, and copied to g_aux[0..3] on failure.
volatile uint32_t g_case18_out_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value without libm: negates only values that compare below zero.
static inline float absf_local(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case18_out_bits[i] = 0;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const float input0 = 0.0f;
|
||||
const float input1 = 1.0f;
|
||||
const float input2 = -0.6931471805599453f;
|
||||
const float input3 = -10.0f;
|
||||
|
||||
const float out0 = wu_fexp_s(input0);
|
||||
const float out1 = wu_fexp_s(input1);
|
||||
const float out2 = wu_fexp_s(input2);
|
||||
const float out3 = wu_fexp_s(input3);
|
||||
|
||||
if (tid == 0) {
|
||||
g_case18_out_bits[0] = f32_to_bits(out0);
|
||||
g_case18_out_bits[1] = f32_to_bits(out1);
|
||||
g_case18_out_bits[2] = f32_to_bits(out2);
|
||||
g_case18_out_bits[3] = f32_to_bits(out3);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float err0 = absf_local(out0 - 1.0f);
|
||||
const float err1 = absf_local(out1 - 2.7182817459f);
|
||||
const float err2 = absf_local(out2 - 0.5f);
|
||||
const float err3 = absf_local(out3 - 0.00004539993f);
|
||||
if (err0 > 0.00001f || err1 > 0.0002f || err2 > 0.00001f ||
|
||||
err3 > 0.000001f) {
|
||||
g_aux[0] = g_case18_out_bits[0];
|
||||
g_aux[1] = g_case18_out_bits[1];
|
||||
g_aux[2] = g_case18_out_bits[2];
|
||||
g_aux[3] = g_case18_out_bits[3];
|
||||
wu_case_fail(0x18u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case19_tensor_fexp_illegal
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,5 @@
|
||||
# case19_tensor_fexp_illegal
|
||||
|
||||
Negative test for `FEXP.S`: tensor warps must not execute this scalar FPU
|
||||
instruction. Running this case is expected to trip the existing tensor-FPU
|
||||
illegal-instruction path in decode/dispatch rather than complete normally.
|
||||
25
kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp
Normal file
25
kernels/wu_arch_cases/case19_tensor_fexp_illegal/kernel.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
// Tensor-warp entry point for the negative FEXP.S test.
// Naked function: the body is exactly the instructions below, with no
// prologue or epilogue, and it never returns to a caller.
// NOTE(review): the case README expects the custom1 instruction to trip the
// illegal-instruction path, so the lines after it are presumably only a
// fallback for hardware that wrongly lets it through - confirm.
extern "C" void __attribute__((naked, noinline, used)) tensor_worker() {
|
||||
asm volatile(
|
||||
// f1 = the all-zero bit pattern, i.e. +0.0f.
"fmv.w.x f1, x0\n\t"
|
||||
// R-type custom1 instruction: funct3=2, funct7=0x30, rd=f2, rs1=f1.
// Presumably the FEXP.S encoding under test - TODO confirm against the ISA.
".insn r %[custom1], 2, 0x30, f2, f1, x0\n\t"
|
||||
// R-type custom0 instruction with all-zero operands; presumably ends the
// warp - TODO confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
// Self-loop so execution can never run past the end of the function.
"1: j 1b\n\t"
|
||||
:
|
||||
: [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom1] "i"(RISCV_CUSTOM1)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_worker);
|
||||
|
||||
wu_case_fail(0x19u);
|
||||
return 1;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile
Normal file
3
kernels/wu_arch_cases/case20_flash_bwd_fused/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case20_flash_bwd_fused
|
||||
|
||||
include ../case.mk
|
||||
19
kernels/wu_arch_cases/case20_flash_bwd_fused/README.md
Normal file
19
kernels/wu_arch_cases/case20_flash_bwd_fused/README.md
Normal file
@@ -0,0 +1,19 @@
|
||||
# case20_flash_bwd_fused
|
||||
|
||||
FlashAttention backward-style fused pipeline smoke test.
|
||||
|
||||
The tensor warp performs one score MMA, then waits for the scalar warp to run
|
||||
softmax plus dsoftmax on the TMEM C row. The scalar warp writes the dS row back
|
||||
to TMEM A using signed fp16 values. The tensor warp then performs four more
|
||||
MMA steps, for five MMA operations total in this case.
|
||||
|
||||
This case verifies:
|
||||
|
||||
- tensor warp MMA sequencing around a scalar TMEM handoff;
|
||||
- scalar-only `FEXP.S` use for stable softmax;
|
||||
- dsoftmax shape `dS = P * (dP - sum(P * dP))`;
|
||||
- signed scalar TMEM stores feeding later tensor MMA operations.
|
||||
|
||||
The input score row is uniform, so `P = 1/32`. The synthetic upstream gradient
|
||||
uses two buckets, producing exact dS values `-1/32` for row entries 0..15 and
|
||||
`+1/32` for row entries 16..31.
|
||||
289
kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp
Normal file
289
kernels/wu_arch_cases/case20_flash_bwd_fused/kernel.cpp
Normal file
@@ -0,0 +1,289 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Handshake values. SCORE_READY and DONE_BASE are OR-ed with the warp id and
// stored into g_seen by the tensor worker; DSOFTMAX_READY is written to
// g_case_mem[0] by the scalar warp to release the worker.
#define WU_CASE20_SCORE_READY 0xa000u
|
||||
#define WU_CASE20_DSOFTMAX_READY 0xa100u
|
||||
#define WU_CASE20_DONE_BASE 0xa200u
|
||||
// Number of score entries reduced by the scalar softmax/dsoftmax pass.
#define WU_CASE20_ROW_N 32u
|
||||
// fp32 bit patterns of 0.0 and 32.0.
#define WU_CASE20_FP32_ZERO 0x00000000u
|
||||
#define WU_CASE20_FP32_THIRTY_TWO 0x42000000u
|
||||
|
||||
// Buffers shared between wu_main and the tensor worker's assembly, which
// refers to them by symbol name (hence C linkage).
extern "C" {
|
||||
// One 16-byte source row of packed fp16 1.0 values, used by the worker's
// first fill loop.
volatile uint32_t g_case20_q_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||
WU_BW_FP16_ONE_PACKED};
|
||||
// One 16-byte source row of zeros, used by the worker to clear the C region.
volatile uint32_t g_case20_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO, WU_CASE20_FP32_ZERO,
|
||||
WU_CASE20_FP32_ZERO};
|
||||
// Destination of the worker's final copy loop; verified word by word in wu_main.
volatile uint32_t g_case20_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
// Diagnostics: [0] holds the first score word observed by thread 0.
volatile uint32_t g_case20_score_bits[4] __attribute__((aligned(16)));
|
||||
// Diagnostics: [0] holds dS for row entry 0, [1] holds dS for row entry 16.
volatile uint32_t g_case20_dsoftmax_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value (no numeric conversion).
static inline float wu_case20_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t wu_case20_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value without libm: negates only values that compare below zero.
static inline float wu_case20_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
// Converts an fp32 value to an IEEE-754 binary16 bit pattern.
//
//   - NaN stays NaN (quiet NaN 0x7e00 plus sign) instead of collapsing into
//     infinity, so a bad upstream result remains visible downstream.
//   - Infinity and magnitudes too large for binary16 become signed infinity.
//   - Zero, fp32 subnormals and anything below the smallest normal binary16
//     value are flushed to signed zero (no binary16 subnormals are produced).
//   - Normal values are rounded by adding half of the discarded 13 bits
//     (ties round away from zero).
static inline uint16_t wu_case20_f32_to_f16(float value) {
  union {
    float f;
    uint32_t u;
  } pun = {value};
  const uint32_t bits = pun.u;
  const uint32_t sign = (bits >> 16) & 0x8000u;
  const uint32_t exp = (bits >> 23) & 0xffu;
  uint32_t mant = bits & 0x7fffffu;

  if (exp == 0xffu) {
    // fp32 Inf (mantissa 0) or NaN (mantissa non-zero).
    const uint32_t payload = (mant != 0) ? 0x7e00u : 0x7c00u;
    return static_cast<uint16_t>(sign | payload);
  }
  if (exp >= 143u) {
    // Unbiased exponent >= 16: beyond the binary16 range.
    return static_cast<uint16_t>(sign | 0x7c00u);
  }
  if (exp <= 112u) {
    // Covers +/-0, fp32 subnormals and values below 2^-14.
    return static_cast<uint16_t>(sign);
  }

  uint32_t half_exp = exp - 112u;  // re-bias from 127 to 15
  mant += 0x1000u;                 // half of the 13 bits dropped below
  if (mant & 0x800000u) {
    // Rounding carried out of the mantissa: bump the exponent.
    mant = 0;
    ++half_exp;
  }
  if (half_exp >= 31u) {
    return static_cast<uint16_t>(sign | 0x7c00u);
  }
  return static_cast<uint16_t>(sign | (half_exp << 10) | (mant >> 13));
}
|
||||
|
||||
static inline uint32_t wu_case20_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case20_f32_to_f16(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
// Synthetic upstream gradient dP: 0 for row entries below 16, 2 from 16 up.
static inline float wu_case20_dp(uint32_t row_idx) {
  if (row_idx >= 16u) {
    return 2.0f;
  }
  return 0.0f;
}
|
||||
|
||||
static inline void wu_case20_dsoftmax_tmem_row(uint32_t score_frag_base,
|
||||
uint32_t ds_byte_base) {
|
||||
float row_max = wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
float dot = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE20_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
dot += p * wu_case20_dp(i);
|
||||
}
|
||||
|
||||
const uint32_t ds_frag_base = ds_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE20_ROW_N;
|
||||
const float score =
|
||||
wu_case20_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
const float ds = p * (wu_case20_dp(row_idx) - dot);
|
||||
if (wu_tid() == 0 && row_idx == 0) {
|
||||
g_case20_dsoftmax_bits[0] = wu_case20_f32_to_bits(ds);
|
||||
}
|
||||
if (wu_tid() == 0 && row_idx == 16u) {
|
||||
g_case20_dsoftmax_bits[1] = wu_case20_f32_to_bits(ds);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(ds_frag_base + frag, wu_case20_pack_f16x2(ds));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
// Tensor-warp entry point for case20.
// Naked function: the body is exactly the instructions below, with no
// prologue or epilogue, and it ends in a self-loop instead of returning.
//
// Register use: x1 = this warp's base offset, x2 = x1 + C offset,
// x3 = source/destination symbol address, x4/x6/x7 = scratch, x5 = warp id.
// x1-x4 are the ABI's ra/sp/gp/tp, which is tolerable only because nothing
// is called and nothing returns.
// NOTE(review): x3 is gp. If the link uses gp-relative relaxation, a later
// `la` could be rewritten relative to the clobbered gp - confirm the kernels
// are built with relaxation disabled.
//
// The meaning of the custom3 funct3 codes is not defined in this file; the
// descriptions below are inferred from the operands and the case README and
// need confirming against the ISA.
extern "C" void __attribute__((naked, noinline, used)) tensor_case20_worker() {
|
||||
asm volatile(
|
||||
// x1 = (warp id - NUM_SCALAR_WARPS) << 11; x2 = x1 + WU_BW_TMEM_C_BYTE_OFFSET.
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
// Loop 1: for each 16-byte step below tile_bytes, custom3/funct3=2 with
// rs1 = x1 + offset and rs2 = &g_case20_q_row. Presumably copies the row of
// fp16 ones into this warp's A tile.
"la x3, g_case20_q_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
// Loop 2: same instruction over the region at x2 with the zero row as
// source. Presumably clears the C accumulator tile.
"la x3, g_case20_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
// custom3/funct3=3 with no operands: presumably waits for the copies above.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
// custom3/funct3=0 (rd=x2, rs1=x1, rs2=SMEM base) followed by funct3=1:
// presumably "issue MMA" and "wait for MMA". This is the single score MMA
// the README describes.
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
// Publish g_seen[warp id] = WU_CASE20_SCORE_READY | warp id.
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[score_ready]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
// Spin until the first word of g_case_mem equals WU_CASE20_DSOFTMAX_READY.
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[dsoftmax_ready]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
// Loop 4: write the zero row over the region at x2 again before the second
// phase.
"la x3, g_case20_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"4:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 4b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
// Four back-to-back funct3=0 / funct3=1 pairs with identical operands: the
// "four more MMA steps" from the README.
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
// Loop 5: custom3/funct3=6 with rs1 = x2 + offset and rs2 = g_case20_out +
// offset. Presumably copies the C tile back to memory, 16 bytes per step.
"la x3, g_case20_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"5:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 5b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
// Publish g_seen[warp id] = WU_CASE20_DONE_BASE | warp id.
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
// custom0 with all-zero operands (presumably ends the warp), then a
// self-loop as a backstop.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"6: j 6b\n\t"
|
||||
:
|
||||
// All operands are compile-time immediates; there are no register operands.
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[score_ready] "i"(WU_CASE20_SCORE_READY),
|
||||
[dsoftmax_ready] "i"(WU_CASE20_DSOFTMAX_READY),
|
||||
[done_base] "i"(WU_CASE20_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case20_out[i] = 0xffffffffu;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case20_score_bits[i] = 0;
|
||||
g_case20_dsoftmax_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case20_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE20_SCORE_READY) != 0) {
|
||||
g_case_mem[1] = 0x41u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
if (tid == 0) {
|
||||
g_case20_score_bits[0] = observed;
|
||||
if (g_case_mem[1] == 0 && observed != WU_CASE20_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = observed;
|
||||
g_case_mem[1] = 0x42u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_case20_dsoftmax_tmem_row(c_frag, wu_bw_tmem_a_byte_base(0));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0) {
|
||||
const float neg = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[0]);
|
||||
const float pos = wu_case20_bits_to_f32(g_case20_dsoftmax_bits[1]);
|
||||
if (wu_case20_absf(neg + 0.03125f) > 0.0002f ||
|
||||
wu_case20_absf(pos - 0.03125f) > 0.0002f) {
|
||||
g_aux[0] = g_case20_dsoftmax_bits[0];
|
||||
g_aux[1] = g_case20_dsoftmax_bits[1];
|
||||
g_case_mem[1] = 0x43u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
g_case_mem[0] = WU_CASE20_DSOFTMAX_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE20_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x44u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case20_out, WU_BW_OUT_WORDS,
|
||||
WU_CASE20_FP32_ZERO, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x45u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case21_moe_gating/Makefile
Normal file
3
kernels/wu_arch_cases/case21_moe_gating/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case21_moe_gating
|
||||
|
||||
include ../case.mk
|
||||
10
kernels/wu_arch_cases/case21_moe_gating/README.md
Normal file
10
kernels/wu_arch_cases/case21_moe_gating/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# case21_moe_gating
|
||||
|
||||
MoE gating scalar pipeline test.
|
||||
|
||||
This case runs `softmax -> Top-K -> scatter` on scalar warp 0 using `FEXP.S`.
|
||||
The logits are `log(1), log(2), log(4), log(8)`, so the expected probabilities
|
||||
are `1/15, 2/15, 4/15, 8/15`. Top-2 should select experts 3 and 2, then scatter
|
||||
the token id and weight into the selected expert slots.
|
||||
|
||||
No tensor warp is spawned in this case.
|
||||
128
kernels/wu_arch_cases/case21_moe_gating/kernel.cpp
Normal file
128
kernels/wu_arch_cases/case21_moe_gating/kernel.cpp
Normal file
@@ -0,0 +1,128 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
// Token id scattered into the slots of the selected experts.
#define WU_CASE21_TOKEN_ID 0x21u
// Sentinel for "no expert index / no token stored yet".
#define WU_CASE21_EMPTY 0xffffffffu

// Buffers for case21, given C linkage so the symbol names stay unmangled.
extern "C" {
// fp32 bit patterns of the gating logits ln(1), ln(2), ln(4), ln(8).
// ln(8) rounds to 0x40051592 in fp32; the value used before (0x40051d8f)
// decodes to about ln(8.0039) and consumed part of the test tolerance.
// Volatile so wu_main has to load the logits at run time.
volatile uint32_t g_case21_logits_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f317218u, 0x3fb17218u, 0x40051592u};
// fp32 bit patterns of the softmax probabilities, one per expert.
volatile uint32_t g_case21_prob_bits[4] __attribute__((aligned(16)));
// Indices of the two most probable experts, best first.
volatile uint32_t g_case21_top_idx[2] __attribute__((aligned(16)));
// Per-expert token slot: WU_CASE21_TOKEN_ID if selected, else WU_CASE21_EMPTY.
volatile uint32_t g_case21_expert_token[4] __attribute__((aligned(16)));
// Per-expert routing weight (fp32 bits); stays 0 for unselected experts.
volatile uint32_t g_case21_expert_weight_bits[4] __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value (no numeric conversion).
static inline float wu_case21_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t wu_case21_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value without libm: negates only values that compare below zero.
static inline float wu_case21_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case21_prob_bits[i] = 0;
|
||||
g_case21_expert_token[i] = WU_CASE21_EMPTY;
|
||||
g_case21_expert_weight_bits[i] = 0;
|
||||
}
|
||||
g_case21_top_idx[0] = WU_CASE21_EMPTY;
|
||||
g_case21_top_idx[1] = WU_CASE21_EMPTY;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
float logits[4];
|
||||
float row_max = wu_case21_bits_to_f32(g_case21_logits_bits[0]);
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
logits[i] = wu_case21_bits_to_f32(g_case21_logits_bits[i]);
|
||||
row_max = logits[i] > row_max ? logits[i] : row_max;
|
||||
}
|
||||
|
||||
float exp_values[4];
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
exp_values[i] = wu_fexp_s(logits[i] - row_max);
|
||||
denom += exp_values[i];
|
||||
}
|
||||
|
||||
float probs[4];
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
probs[i] = exp_values[i] / denom;
|
||||
}
|
||||
|
||||
uint32_t top0 = 0;
|
||||
uint32_t top1 = 1;
|
||||
if (probs[top1] > probs[top0]) {
|
||||
const uint32_t tmp = top0;
|
||||
top0 = top1;
|
||||
top1 = tmp;
|
||||
}
|
||||
for (uint32_t i = 2; i < 4; ++i) {
|
||||
if (probs[i] > probs[top0]) {
|
||||
top1 = top0;
|
||||
top0 = i;
|
||||
} else if (probs[i] > probs[top1]) {
|
||||
top1 = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case21_prob_bits[i] = wu_case21_f32_to_bits(probs[i]);
|
||||
}
|
||||
g_case21_top_idx[0] = top0;
|
||||
g_case21_top_idx[1] = top1;
|
||||
g_case21_expert_token[top0] = WU_CASE21_TOKEN_ID;
|
||||
g_case21_expert_token[top1] = WU_CASE21_TOKEN_ID;
|
||||
g_case21_expert_weight_bits[top0] = wu_case21_f32_to_bits(probs[top0]);
|
||||
g_case21_expert_weight_bits[top1] = wu_case21_f32_to_bits(probs[top1]);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected[4] = {0.0666666701f, 0.1333333403f,
|
||||
0.2666666806f, 0.5333333611f};
|
||||
const float tolerance = 0.0015f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
if (wu_case21_absf(probs[i] - expected[i]) > tolerance) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_case21_prob_bits[i];
|
||||
wu_case_fail(0x21u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
if (top0 != 3u || top1 != 2u ||
|
||||
g_case21_expert_token[3] != WU_CASE21_TOKEN_ID ||
|
||||
g_case21_expert_token[2] != WU_CASE21_TOKEN_ID ||
|
||||
g_case21_expert_token[0] != WU_CASE21_EMPTY ||
|
||||
g_case21_expert_token[1] != WU_CASE21_EMPTY) {
|
||||
g_aux[0] = top0;
|
||||
g_aux[1] = top1;
|
||||
g_aux[2] = g_case21_expert_token[3];
|
||||
g_aux[3] = g_case21_expert_token[2];
|
||||
wu_case_fail(0x22u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case22_gemm_silu/Makefile
Normal file
3
kernels/wu_arch_cases/case22_gemm_silu/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case22_gemm_silu
|
||||
|
||||
include ../case.mk
|
||||
10
kernels/wu_arch_cases/case22_gemm_silu/README.md
Normal file
10
kernels/wu_arch_cases/case22_gemm_silu/README.md
Normal file
@@ -0,0 +1,10 @@
|
||||
# case22_gemm_silu
|
||||
|
||||
GEMM plus SiLU fusion smoke test.
|
||||
|
||||
The tensor warp computes a compact GEMM with fp16 `A = 0.125` and fp16 `B = 1`,
|
||||
producing fp32 `C = 4`. Scalar warp 0 reads TMEM C and applies
|
||||
`SiLU(x) = x / (1 + exp(-x))` using scalar-only `FEXP.S`.
|
||||
|
||||
This case verifies the common `matmul -> nonlinear activation` fusion path
|
||||
without allowing tensor warp FPU execution.
|
||||
160
kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp
Normal file
160
kernels/wu_arch_cases/case22_gemm_silu/kernel.cpp
Normal file
@@ -0,0 +1,160 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
// Handshake values. DONE_BASE is OR-ed with the warp id and stored into
// g_seen by the tensor worker.
// NOTE(review): INIT_BASE is not referenced by the case22 code shown here.
#define WU_CASE22_INIT_BASE 0xa400u
|
||||
#define WU_CASE22_DONE_BASE 0xa500u
|
||||
// Two fp16 values of 0.125 (0x3000 each) packed into one 32-bit word.
#define WU_CASE22_FP16_ONE_EIGHTH_PACKED 0x30003000u
|
||||
// fp32 bit pattern of 4.0, the expected GEMM result.
#define WU_CASE22_FP32_FOUR 0x40800000u
|
||||
|
||||
// Buffers shared between wu_main and the tensor worker's assembly, which
// refers to them by symbol name (hence C linkage).
extern "C" {
|
||||
// One 16-byte source row of packed fp16 0.125 values, used by the worker's
// first fill loop.
volatile uint32_t g_case22_a_row[4] __attribute__((aligned(16))) = {
|
||||
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED,
|
||||
WU_CASE22_FP16_ONE_EIGHTH_PACKED, WU_CASE22_FP16_ONE_EIGHTH_PACKED};
|
||||
// One 16-byte source row of zeros, used by the worker to clear the C region.
volatile uint32_t g_case22_zero_row[4] __attribute__((aligned(16))) = {
|
||||
0x00000000u, 0x00000000u, 0x00000000u, 0x00000000u};
|
||||
// Destination of the worker's copy loop; verified word by word in wu_main.
volatile uint32_t g_case22_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
// Diagnostics: [0] holds the fp32 bits of the SiLU result from thread 0.
volatile uint32_t g_case22_silu_bits[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a raw 32-bit pattern as an fp32 value (no numeric conversion).
static inline float wu_case22_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float real;
  } pun;
  pun.raw = bits;
  return pun.real;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of an fp32 value (no numeric conversion).
static inline uint32_t wu_case22_f32_to_bits(float value) {
  union {
    float real;
    uint32_t raw;
  } pun;
  pun.real = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value without libm: negates only values that compare below zero.
static inline float wu_case22_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
// Tensor-warp entry point for case22.
// Naked function: the body is exactly the instructions below, with no
// prologue or epilogue, and it ends in a self-loop instead of returning.
//
// Register use: x1 = this warp's base offset, x2 = x1 + C offset,
// x3 = source/destination symbol address, x4/x6/x7 = scratch, x5 = warp id.
// x1-x4 are the ABI's ra/sp/gp/tp, which is tolerable only because nothing
// is called and nothing returns.
// NOTE(review): x3 is gp. If the link uses gp-relative relaxation, a later
// `la` could be rewritten relative to the clobbered gp - confirm the kernels
// are built with relaxation disabled.
//
// The meaning of the custom3 funct3 codes is not defined in this file; the
// descriptions below are inferred from the operands and the case README and
// need confirming against the ISA.
extern "C" void __attribute__((naked, noinline, used)) tensor_case22_worker() {
|
||||
asm volatile(
|
||||
// x1 = (warp id - NUM_SCALAR_WARPS) << 11; x2 = x1 + WU_BW_TMEM_C_BYTE_OFFSET.
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
// Loop 1: for each 16-byte step below tile_bytes, custom3/funct3=2 with
// rs1 = x1 + offset and rs2 = &g_case22_a_row. Presumably copies the row of
// fp16 0.125 values into this warp's A tile.
"la x3, g_case22_a_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
// Loop 2: same instruction over the region at x2 with the zero row as
// source. Presumably clears the C accumulator tile.
"la x3, g_case22_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
// custom3/funct3=3 with no operands: presumably waits for the copies above.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
// custom3/funct3=0 (rd=x2, rs1=x1, rs2=SMEM base) followed by funct3=1:
// presumably "issue MMA" and "wait for MMA"; the single GEMM of this case.
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
// Loop 3: custom3/funct3=6 with rs1 = x2 + offset and rs2 = g_case22_out +
// offset. Presumably copies the C tile back to memory, 16 bytes per step.
"la x3, g_case22_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
// Publish g_seen[warp id] = WU_CASE22_DONE_BASE | warp id.
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
// custom0 with all-zero operands (presumably ends the warp), then a
// self-loop as a backstop.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
// All operands are compile-time immediates; there are no register operands.
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[done_base] "i"(WU_CASE22_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case22_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case22_silu_bits[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case22_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE22_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x51u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed_bits = wu_bw_scalar_tmem_ld(c_frag);
|
||||
const float observed = wu_case22_bits_to_f32(observed_bits);
|
||||
const float silu = observed / (1.0f + wu_fexp_s(-observed));
|
||||
|
||||
if (tid == 0) {
|
||||
g_case22_silu_bits[0] = wu_case22_f32_to_bits(silu);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0 && observed_bits != WU_CASE22_FP32_FOUR) {
|
||||
g_aux[0] = observed_bits;
|
||||
g_case_mem[1] = 0x52u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
const float expected = 3.9280550480f;
|
||||
if (wu_case22_absf(silu - expected) > 0.004f) {
|
||||
g_aux[0] = g_case22_silu_bits[0];
|
||||
g_case_mem[1] = 0x53u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case22_out, WU_BW_OUT_WORDS,
|
||||
WU_CASE22_FP32_FOUR, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x54u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case23_softmax_only/Makefile
Normal file
3
kernels/wu_arch_cases/case23_softmax_only/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case23_softmax_only
|
||||
|
||||
include ../case.mk
|
||||
9
kernels/wu_arch_cases/case23_softmax_only/README.md
Normal file
9
kernels/wu_arch_cases/case23_softmax_only/README.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# case23_softmax_only
|
||||
|
||||
Scalar softmax-only test.
|
||||
|
||||
This case runs a stable 4-way softmax on scalar warp 0 using `FEXP.S`. The logits
|
||||
are `log(1), log(3), log(5), log(7)`, giving expected probabilities
|
||||
`1/16, 3/16, 5/16, 7/16`.
|
||||
|
||||
No tensor warp is spawned in this case.
|
||||
83
kernels/wu_arch_cases/case23_softmax_only/kernel.cpp
Normal file
83
kernels/wu_arch_cases/case23_softmax_only/kernel.cpp
Normal file
@@ -0,0 +1,83 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
extern "C" {
// Softmax logits as IEEE-754 binary32 bit patterns: ln(1), ln(3), ln(5),
// ln(7). exp() of these yields 1, 3, 5, 7, so the expected probabilities are
// 1/16, 3/16, 5/16 and 7/16.
// The ln(5) entry is the float nearest to 1.6094379 (0x3fce0210); the previous
// pattern 0x3fcdf854 decoded to ~1.60914, i.e. exp() ~= 4.9985 instead of 5.
volatile uint32_t g_case23_scores_bits[4] __attribute__((aligned(16))) = {
    0x00000000u, 0x3f8c9f54u, 0x3fce0210u, 0x3ff91395u};
// Softmax probabilities written by the leader thread, as fp32 bit patterns.
volatile uint32_t g_case23_out_bits[4] __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a 32-bit pattern as a float without numeric conversion.
static inline float wu_case23_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float value;
  } pun;
  pun.raw = bits;
  return pun.value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of a float without numeric conversion.
static inline uint32_t wu_case23_f32_to_bits(float value) {
  union {
    float number;
    uint32_t raw;
  } pun;
  pun.number = value;
  return pun.raw;
}
|
||||
|
||||
// Absolute value of a float, implemented with a compare and negate.
static inline float wu_case23_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case23_out_bits[i] = 0;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
float scores[4];
|
||||
float row_max = wu_case23_bits_to_f32(g_case23_scores_bits[0]);
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
scores[i] = wu_case23_bits_to_f32(g_case23_scores_bits[i]);
|
||||
row_max = scores[i] > row_max ? scores[i] : row_max;
|
||||
}
|
||||
|
||||
float exp_values[4];
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
exp_values[i] = wu_fexp_s(scores[i] - row_max);
|
||||
denom += exp_values[i];
|
||||
}
|
||||
|
||||
float probs[4];
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
probs[i] = exp_values[i] / denom;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case23_out_bits[i] = wu_case23_f32_to_bits(probs[i]);
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected[4] = {0.0625f, 0.1875f, 0.3125f, 0.4375f};
|
||||
const float tolerance = 0.0015f;
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
if (wu_case23_absf(probs[i] - expected[i]) > tolerance) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_case23_out_bits[i];
|
||||
wu_case_fail(0x23u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile
Normal file
3
kernels/wu_arch_cases/case24_flash_sw_pipeline/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case24_flash_sw_pipeline
|
||||
|
||||
include ../case.mk
|
||||
24
kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md
Normal file
24
kernels/wu_arch_cases/case24_flash_sw_pipeline/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# case24_flash_sw_pipeline
|
||||
|
||||
Software-pipelined FlashAttention-style multi-iteration case.
|
||||
|
||||
This case keeps `case16_flash_full_pipeline` as the single-tile
|
||||
producer/consumer baseline and adds a four-iteration ping-pong pipeline:
|
||||
|
||||
```text
|
||||
tensor warp 2 / slot 0: QK(0) -> wait P(0) -> PV(0) -> QK(2) -> wait P(2) -> PV(2)
|
||||
tensor warp 3 / slot 1: QK(1) -> wait P(1) -> PV(1) -> QK(3) -> wait P(3) -> PV(3)
|
||||
scalar warp 0: softmax(0) -> softmax(1) -> softmax(2) -> softmax(3)
|
||||
```
|
||||
|
||||
Each tensor warp owns one TMEM slot. The tensor warp writes `S = Q @ K` into
|
||||
TMEM C for its slot, marks `score_ready[iter]`, waits for scalar-generated
|
||||
`P`, then computes `O = P @ V`. Scalar warp 0 waits on each score in order,
|
||||
uses scalar-only `FEXP.S` for stable softmax, writes packed fp16 probabilities
|
||||
back to the same slot's TMEM A, and marks `p_ready[iter]`.
|
||||
|
||||
The first version intentionally uses constant `Q`, `K`, and `V` so the expected
|
||||
numeric result is simple: every score is fp32 `32.0`, every softmax row is
|
||||
uniform `1/32`, and every output word is fp32 `1.0`. The test objective is the
|
||||
multi-iteration overlap structure and per-slot handoff, not non-uniform
|
||||
FlashAttention numerics.
|
||||
318
kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp
Normal file
318
kernels/wu_arch_cases/case24_flash_sw_pipeline/kernel.cpp
Normal file
@@ -0,0 +1,318 @@
|
||||
#define WU_CASE_WAIT_SPIN 16384u
|
||||
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE24_ITER_N 4u
|
||||
#define WU_CASE24_ROW_N 32u
|
||||
#define WU_CASE24_SCORE_READY_BASE 0xb000u
|
||||
#define WU_CASE24_P_READY_BASE 0xb100u
|
||||
#define WU_CASE24_DONE_BASE 0xb200u
|
||||
#define WU_CASE24_FP32_ZERO 0x00000000u
|
||||
#define WU_CASE24_FP32_THIRTY_TWO 0x42000000u
|
||||
#define WU_CASE24_OUT_WORDS (WU_CASE24_ITER_N * WU_BW_OUT_WORDS)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case24_q_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||
WU_BW_FP16_ONE_PACKED};
|
||||
volatile uint32_t g_case24_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO, WU_CASE24_FP32_ZERO,
|
||||
WU_CASE24_FP32_ZERO};
|
||||
volatile uint32_t g_case24_score_ready[WU_CASE24_ITER_N]
|
||||
__attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_p_ready[WU_CASE24_ITER_N]
|
||||
__attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_done[WU_CASE24_ITER_N] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_score_bits[WU_CASE24_ITER_N * NUM_THREADS]
|
||||
__attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_p_bits[WU_CASE24_ITER_N * NUM_THREADS]
|
||||
__attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_overlap_hint __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case24_out[WU_CASE24_OUT_WORDS]
|
||||
__attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
// Reinterprets a 32-bit pattern as a float without numeric conversion.
static inline float wu_case24_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float value;
  } pun;
  pun.raw = bits;
  return pun.value;
}
|
||||
|
||||
// Returns the raw 32-bit pattern of a float without numeric conversion.
static inline uint32_t wu_case24_f32_to_bits(float value) {
  union {
    float number;
    uint32_t raw;
  } pun;
  pun.number = value;
  return pun.raw;
}
|
||||
|
||||
// Converts a positive fp32 value to an IEEE-754 binary16 bit pattern.
//
//  - NaN                             -> quiet NaN (0x7e00)
//  - zero, negative, fp32 subnormal  -> 0
//  - below the smallest normal half  -> 0 (half subnormals are flushed)
//  - 65536 and above (incl. +inf)    -> +inf (0x7c00)
//  - everything else                 -> round to nearest, ties to even
//
// Earlier revisions rounded ties upward (e.g. 1 + 2^-11 became 0x3c01) and
// turned NaN into +inf; both now follow the IEEE default behaviour.
static inline uint16_t wu_case24_f32_to_f16_positive(float value) {
  union {
    float number;
    uint32_t raw;
  } pun = {value};
  const uint32_t exp = (pun.raw >> 23) & 0xffu;
  const uint32_t mant = pun.raw & 0x7fffffu;

  if (exp == 0xffu && mant != 0) {
    return 0x7e00u;
  }
  if (exp == 0 || value <= 0.0f) {
    return 0;
  }
  if (exp >= 143u) {  // unbiased exponent >= 16: not representable in half
    return 0x7c00u;
  }
  if (exp <= 112u) {  // unbiased exponent <= -15: below the normal range
    return 0;
  }

  // Re-bias the exponent (127 -> 15) and keep the top 10 mantissa bits.
  uint32_t half = ((exp - 112u) << 10) | (mant >> 13);

  // Round on the 13 discarded bits. A carry out of the mantissa bumps the
  // exponent field, and overflowing the largest exponent yields 0x7c00 (+inf).
  const uint32_t dropped = mant & 0x1fffu;
  if (dropped > 0x1000u || (dropped == 0x1000u && (half & 1u) != 0)) {
    ++half;
  }
  return (uint16_t)half;
}
|
||||
|
||||
static inline uint32_t wu_case24_pack_f16x2(float value) {
|
||||
const uint32_t h = wu_case24_f32_to_f16_positive(value);
|
||||
return h | (h << 16);
|
||||
}
|
||||
|
||||
// Polls status[iter] until it equals (base | iter).
// Returns 0 once the value is observed, or 1 after WU_CASE_WAIT_SPIN polls
// without a match.
static inline int wu_case24_wait_status(volatile uint32_t *status,
                                        uint32_t iter, uint32_t base) {
  const uint32_t wanted = base | iter;
  uint32_t polls_left = WU_CASE_WAIT_SPIN;
  while (polls_left != 0) {
    if (status[iter] == wanted) {
      return 0;
    }
    --polls_left;
  }
  return 1;
}
|
||||
|
||||
static inline void wu_case24_softmax_tmem_row_to_p(uint32_t iter,
|
||||
uint32_t score_frag_base,
|
||||
uint32_t p_byte_base) {
|
||||
float row_max = wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base));
|
||||
for (uint32_t i = 1; i < WU_CASE24_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
row_max = score > row_max ? score : row_max;
|
||||
}
|
||||
|
||||
float denom = 0.0f;
|
||||
for (uint32_t i = 0; i < WU_CASE24_ROW_N; ++i) {
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + i));
|
||||
denom += wu_fexp_s(score - row_max);
|
||||
}
|
||||
|
||||
const uint32_t p_frag_base = p_byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row_idx = frag % WU_CASE24_ROW_N;
|
||||
const float score =
|
||||
wu_case24_bits_to_f32(wu_bw_scalar_tmem_ld(score_frag_base + row_idx));
|
||||
const float p = wu_fexp_s(score - row_max) / denom;
|
||||
if (frag == 0) {
|
||||
g_case24_p_bits[iter * NUM_THREADS + wu_tid()] =
|
||||
wu_case24_f32_to_bits(p);
|
||||
}
|
||||
wu_bw_scalar_tmem_st(p_frag_base + frag, wu_case24_pack_f16x2(p));
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case24_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x8, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x8, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"mv x9, x8\n\t"
|
||||
"1:\n\t"
|
||||
"li x10, %[iter_n]\n\t"
|
||||
"bge x9, x10, 9f\n\t"
|
||||
"la x3, g_case24_q_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
"la x3, g_case24_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x9, 2\n\t"
|
||||
"la x7, g_case24_score_ready\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[score_ready_base]\n\t"
|
||||
"or x6, x6, x9\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"slli x6, x9, 2\n\t"
|
||||
"la x7, g_case24_p_ready\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x4, %[p_ready_base]\n\t"
|
||||
"or x4, x4, x9\n\t"
|
||||
"4:\n\t"
|
||||
"lw x6, 0(x7)\n\t"
|
||||
"bne x6, x4, 4b\n\t"
|
||||
"la x3, g_case24_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"5:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 5b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x9, 10\n\t"
|
||||
"la x3, g_case24_out\n\t"
|
||||
"add x3, x3, x6\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"6:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 6b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x9, 2\n\t"
|
||||
"la x7, g_case24_done\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x9\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"addi x9, x9, 2\n\t"
|
||||
"j 1b\n\t"
|
||||
"9:\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"10: j 10b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[iter_n] "i"(WU_CASE24_ITER_N),
|
||||
[score_ready_base] "i"(WU_CASE24_SCORE_READY_BASE),
|
||||
[p_ready_base] "i"(WU_CASE24_P_READY_BASE),
|
||||
[done_base] "i"(WU_CASE24_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE24_ITER_N; ++i) {
|
||||
g_case24_score_ready[i] = 0;
|
||||
g_case24_p_ready[i] = 0;
|
||||
g_case24_done[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < WU_CASE24_ITER_N * NUM_THREADS; ++i) {
|
||||
g_case24_score_bits[i] = 0;
|
||||
g_case24_p_bits[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < WU_CASE24_OUT_WORDS; ++i) {
|
||||
g_case24_out[i] = 0;
|
||||
}
|
||||
g_case24_overlap_hint = 0;
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case24_worker);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
|
||||
if (wu_case24_wait_status(g_case24_score_ready, iter,
|
||||
WU_CASE24_SCORE_READY_BASE) != 0) {
|
||||
if (tid == 0) {
|
||||
g_case_mem[1] = 0x81u;
|
||||
g_aux[0] = iter;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t slot = iter & 1u;
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(slot) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case24_score_bits[iter * NUM_THREADS + tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
observed != WU_CASE24_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = iter;
|
||||
g_aux[1] = observed;
|
||||
g_case_mem[1] = 0x82u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_case24_softmax_tmem_row_to_p(iter, c_frag,
|
||||
wu_bw_tmem_a_byte_base(slot));
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0) {
|
||||
if (iter == 0 &&
|
||||
g_case24_score_ready[1] == (WU_CASE24_SCORE_READY_BASE | 1u)) {
|
||||
g_case24_overlap_hint = 1;
|
||||
}
|
||||
g_case24_p_ready[iter] = WU_CASE24_P_READY_BASE | iter;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t iter = 0; iter < WU_CASE24_ITER_N; ++iter) {
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_case24_wait_status(g_case24_done, iter, WU_CASE24_DONE_BASE) !=
|
||||
0) {
|
||||
g_aux[0] = iter;
|
||||
g_case_mem[1] = 0x83u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case24_out, WU_CASE24_OUT_WORDS,
|
||||
WU_BW_FP32_ONE, &bad_actual);
|
||||
if (bad != WU_CASE24_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x84u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
101
kernels/wu_arch_cases/common_wu_blackwell_fa.h
Normal file
101
kernels/wu_arch_cases/common_wu_blackwell_fa.h
Normal file
@@ -0,0 +1,101 @@
|
||||
#ifndef WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
|
||||
#define WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
|
||||
|
||||
#include "common_wu_min.h"
|
||||
|
||||
#define WU_BW_DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_BW_TMEM_TILE_BYTES 1024u
|
||||
#define WU_BW_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_BW_TMEM_WARP_BYTES 2048u
|
||||
#define WU_BW_TMEM_C_BYTE_OFFSET 1024u
|
||||
#define WU_BW_TMEM_FRAGMENTS \
|
||||
(WU_BW_TMEM_TILE_BYTES / WU_BW_TMEM_FRAGMENT_BYTES)
|
||||
#define WU_BW_TMEM_WORDS (WU_BW_TMEM_TILE_BYTES / sizeof(uint32_t))
|
||||
#define WU_BW_OUT_WORDS WU_BW_TMEM_WORDS
|
||||
|
||||
#define WU_BW_FP16_ONE_PACKED 0x3c003c00u
|
||||
#define WU_BW_FP16_TWO_PACKED 0x40004000u
|
||||
#define WU_BW_FP32_ONE 0x3f800000u
|
||||
#define WU_BW_FP32_THREE 0x40400000u
|
||||
#define WU_BW_FP32_FIVE 0x40a00000u
|
||||
#define WU_BW_FP32_SIXTY_FIVE 0x42820000u
|
||||
#define WU_BW_FP32_SIXTY_SEVEN 0x42860000u
|
||||
#define WU_BW_FP32_ONE_THIRTY_THREE 0x43050000u
|
||||
|
||||
static_assert(NUM_SCALAR_WARPS == 2,
|
||||
"Wu Blackwell FA staged cases expect two scalar warps");
|
||||
static_assert(NUM_TENSOR_WARPS == 2,
|
||||
"Wu Blackwell FA staged cases expect two tensor warps");
|
||||
static_assert(NUM_THREADS == 4,
|
||||
"Wu Blackwell FA staged cases expect the 4-lane tensor core");
|
||||
|
||||
static inline uint32_t wu_bw_all_lanes_mask() {
|
||||
return (1u << NUM_THREADS) - 1u;
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_bw_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
|
||||
[value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_tmem_a_byte_base(uint32_t tensor_block) {
|
||||
return tensor_block * WU_BW_TMEM_WARP_BYTES;
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_tmem_c_byte_base(uint32_t tensor_block) {
|
||||
return wu_bw_tmem_a_byte_base(tensor_block) + WU_BW_TMEM_C_BYTE_OFFSET;
|
||||
}
|
||||
|
||||
static inline void wu_bw_fill_tmem_tile(uint32_t byte_base, uint32_t value) {
|
||||
const uint32_t first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
// Caller must run this with all TMEM fragment lanes active.
|
||||
wu_bw_scalar_tmem_st(first_frag + frag, value);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_bw_store_tmem_fragment_lane_values(uint32_t byte_base,
|
||||
uint32_t frag,
|
||||
uint32_t lane_value) {
|
||||
const uint32_t first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
wu_bw_scalar_tmem_st(first_frag + frag, lane_value);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
// Writes `value` to the first WU_BW_TMEM_WORDS 32-bit words at `smem`, then
// issues a full read/write fence.
static inline void wu_bw_fill_smem_tile(volatile uint32_t *smem,
                                        uint32_t value) {
  uint32_t word = 0;
  while (word < WU_BW_TMEM_WORDS) {
    smem[word] = value;
    ++word;
  }
  asm volatile("fence rw, rw" ::: "memory");
}
|
||||
|
||||
// Checks that the first `words` entries of `data` all equal `expected`.
//
// Returns the index of the first mismatching word and stores that word in
// *bad_actual. When every word matches (including words == 0) it returns
// `words` and stores 0 in *bad_actual.
static inline uint32_t wu_bw_verify_constant(const volatile uint32_t *data,
                                             uint32_t words,
                                             uint32_t expected,
                                             volatile uint32_t *bad_actual) {
  uint32_t idx = 0;
  while (idx < words) {
    const uint32_t seen = data[idx];
    if (seen != expected) {
      *bad_actual = seen;
      return idx;
    }
    ++idx;
  }
  *bad_actual = 0;
  return words;
}
|
||||
|
||||
#endif
|
||||
@@ -5,8 +5,12 @@
|
||||
#include <vx_intrinsics.h>
|
||||
|
||||
#define WU_CASE_MAX_WARPS 8u
|
||||
#ifndef WU_CASE_WAIT_SPIN
|
||||
#define WU_CASE_WAIT_SPIN 1024u
|
||||
#endif
|
||||
#ifndef WU_CASE_SHORT_SPIN
|
||||
#define WU_CASE_SHORT_SPIN 8u
|
||||
#endif
|
||||
|
||||
#define WU_CASE_PASS 0x600du
|
||||
#define WU_CASE_FAIL_BASE 0xe000u
|
||||
@@ -15,6 +19,10 @@
|
||||
#define WU_CASE_TENSOR_CSR_BASE 0x7300u
|
||||
#define WU_CASE_TENSOR_LSU_BASE 0x7400u
|
||||
|
||||
#ifndef WU_START_BRANCH_TO_MAIN
|
||||
#define WU_START_BRANCH_TO_MAIN 0
|
||||
#endif
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_status[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_seen[WU_CASE_MAX_WARPS] __attribute__((aligned(32)));
|
||||
@@ -35,8 +43,12 @@ extern "C" void __attribute__((naked, section(".init"), used)) _start() {
|
||||
"csrr t0, %[csr_core]\n\t"
|
||||
"bnez t0, 2f\n\t"
|
||||
"li sp, %[stack_base]\n\t"
|
||||
#if WU_START_BRANCH_TO_MAIN
|
||||
"beq zero, zero, wu_main\n\t"
|
||||
#else
|
||||
"call wu_main\n\t"
|
||||
"mv gp, a0\n\t"
|
||||
#endif
|
||||
"2:\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"1: j 1b\n\t"
|
||||
@@ -113,7 +125,7 @@ static inline void wu_mark_seen(uint32_t base) {
|
||||
}
|
||||
}
|
||||
|
||||
static inline void wu_stop_warp() {
|
||||
static inline void __attribute__((noreturn)) wu_stop_warp() {
|
||||
vx_tmc_zero();
|
||||
while (1) {}
|
||||
}
|
||||
@@ -122,6 +134,14 @@ static inline int wu_is_leader() {
|
||||
return vx_core_id() == 0 && vx_warp_id() == 0 && vx_thread_id() == 0;
|
||||
}
|
||||
|
||||
static inline float wu_fexp_s(float value) {
|
||||
float result;
|
||||
asm volatile(".insn r %[custom1], 2, 0x30, %[rd], %[rs1], x0"
|
||||
: [rd] "=f"(result)
|
||||
: [rs1] "f"(value), [custom1] "i"(RISCV_CUSTOM1));
|
||||
return result;
|
||||
}
|
||||
|
||||
static inline void wu_report_tohost(uint32_t exit_code) {
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
tohost = (static_cast<uint64_t>(exit_code) << 1) | 1u;
|
||||
|
||||
9
kernels/wu_arch_hgemm/Makefile
Normal file
9
kernels/wu_arch_hgemm/Makefile
Normal file
@@ -0,0 +1,9 @@
|
||||
PROJECT = wu_arch_hgemm
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
|
||||
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||
cp $< $@
|
||||
15
kernels/wu_arch_hgemm/README.md
Normal file
15
kernels/wu_arch_hgemm/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# wu_arch_hgemm
|
||||
|
||||
Two-tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp
|
||||
configuration with the 4-lane Blackwell tensor-core path.
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B operand for a 32x16x32 GEMM,
|
||||
spawns only the warps selected by the tensor warp mask, waits for tensor warps
|
||||
`NUM_SCALAR_WARPS..NUM_WARPS-1`, verifies the combined 32x16 fp32 output, and
|
||||
reports completion through `tohost`.
|
||||
|
||||
Each tensor warp computes one 16x16x32 BWGMMA M-block using 16-byte fragments:
|
||||
warp `NUM_SCALAR_WARPS` produces rows 0..15 and warp `NUM_SCALAR_WARPS + 1`
|
||||
produces rows 16..31. Both tensor warps share the same B tile, copy their C
|
||||
tiles back into one contiguous output matrix, mark their block IDs, and then
|
||||
stop themselves.
|
||||
1
kernels/wu_arch_hgemm/args.bin
Normal file
1
kernels/wu_arch_hgemm/args.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.a.bin
Normal file
1
kernels/wu_arch_hgemm/input.a.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.b.bin
Normal file
1
kernels/wu_arch_hgemm/input.b.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
1
kernels/wu_arch_hgemm/input.c.bin
Normal file
1
kernels/wu_arch_hgemm/input.c.bin
Normal file
@@ -0,0 +1 @@
|
||||
0
|
||||
183
kernels/wu_arch_hgemm/kernel.cpp
Normal file
183
kernels/wu_arch_hgemm/kernel.cpp
Normal file
@@ -0,0 +1,183 @@
|
||||
#include "../wu_arch_cases/common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TENSOR_HGEMM_BASE 0x7500u
|
||||
#define WU_CASE_TENSOR_HGEMM_BLOCK_BASE 0x7600u
|
||||
|
||||
#define WU_HGEMM_M_BLOCKS 2u
|
||||
#define WU_HGEMM_M_PER_BLOCK 16u
|
||||
#define WU_HGEMM_N 16u
|
||||
#define WU_HGEMM_K 32u
|
||||
#define WU_HGEMM_TILE_BYTES 1024u
|
||||
#define WU_HGEMM_FRAGMENT_BYTES 16u
|
||||
#define WU_HGEMM_B_WORDS (WU_HGEMM_TILE_BYTES / sizeof(uint32_t))
|
||||
#define WU_HGEMM_OUT_WORDS \
|
||||
(WU_HGEMM_M_BLOCKS * WU_HGEMM_M_PER_BLOCK * WU_HGEMM_N)
|
||||
#define WU_HGEMM_EXPECTED 0x42820000u
|
||||
#define WU_HGEMM_NO_FAIL WU_HGEMM_OUT_WORDS
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS == WU_HGEMM_M_BLOCKS,
|
||||
"wu_arch_hgemm expects two tensor warps");
|
||||
static_assert((NUM_THREADS * 2u) <= WU_CASE_MAX_WARPS,
|
||||
"wu_arch_hgemm uses g_aux[2 * tid + {0,1}] for check scratch");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_hgemm_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_hgemm_control_fail __attribute__((aligned(16)));
|
||||
volatile uint32_t g_hgemm_return_code __attribute__((aligned(16)));
|
||||
volatile uint32_t g_hgemm_out[WU_HGEMM_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_hgemm_a_row\n\t"
|
||||
"la x3, g_hgemm_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, 1024\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_hgemm_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x1, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x7, x6, 2\n\t"
|
||||
"la x3, g_case_mem\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, %[hgemm_block_base]\n\t"
|
||||
"or x7, x7, x6\n\t"
|
||||
"sw x7, 0(x3)\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[hgemm_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE),
|
||||
[hgemm_block_base] "i"(WU_CASE_TENSOR_HGEMM_BLOCK_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_HGEMM_TILE_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
g_hgemm_control_fail = 0;
|
||||
g_hgemm_return_code = 0;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
|
||||
g_hgemm_out[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_HGEMM_B_WORDS; ++i) {
|
||||
smem_b[i] = g_hgemm_b_row[i & 3u];
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
|
||||
|
||||
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
|
||||
WU_CASE_TENSOR_HGEMM_BASE) != 0) {
|
||||
g_hgemm_control_fail = 0x09u;
|
||||
}
|
||||
|
||||
if (g_hgemm_control_fail == 0) {
|
||||
for (uint32_t block = 0; block < WU_HGEMM_M_BLOCKS; ++block) {
|
||||
if (g_case_mem[block] != (WU_CASE_TENSOR_HGEMM_BLOCK_BASE | block)) {
|
||||
g_hgemm_control_fail = 0x0au + block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_hgemm_control_fail != 0) {
|
||||
if (tid == 0) {
|
||||
g_hgemm_return_code = 1;
|
||||
wu_case_fail(g_hgemm_control_fail);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
return static_cast<int>(g_hgemm_return_code);
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
uint32_t bad_index = WU_HGEMM_NO_FAIL;
|
||||
uint32_t bad_actual = 0;
|
||||
|
||||
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
|
||||
const uint32_t actual = g_hgemm_out[i];
|
||||
if (actual != WU_HGEMM_EXPECTED) {
|
||||
bad_index = i;
|
||||
bad_actual = actual;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (bad_index != WU_HGEMM_NO_FAIL) {
|
||||
g_aux[0] = bad_index;
|
||||
g_aux[1] = bad_actual;
|
||||
g_hgemm_return_code = 1;
|
||||
wu_case_fail(0x20u);
|
||||
} else {
|
||||
g_hgemm_return_code = 0;
|
||||
wu_case_pass();
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
return static_cast<int>(g_hgemm_return_code);
|
||||
}
|
||||
@@ -84,15 +84,32 @@
|
||||
#endif
|
||||
|
||||
#ifndef NUM_CORES
|
||||
#define NUM_CORES 8
|
||||
#define NUM_CORES 1
|
||||
#endif
|
||||
|
||||
#ifndef NUM_WARPS
|
||||
#define NUM_WARPS 8
|
||||
#define NUM_WARPS 4
|
||||
#endif
|
||||
|
||||
#ifndef NUM_TENSOR_WARPS
|
||||
#define NUM_TENSOR_WARPS 2
|
||||
#endif
|
||||
|
||||
#define NUM_SCALAR_WARPS (NUM_WARPS - NUM_TENSOR_WARPS)
|
||||
|
||||
#define IS_SCALAR_WARP(wid) ((wid) < NUM_SCALAR_WARPS)
|
||||
#define IS_TENSOR_WARP(wid) ((wid) >= NUM_SCALAR_WARPS)
|
||||
|
||||
#ifndef TENSOR_NUM_GPRS
|
||||
#define TENSOR_NUM_GPRS 8
|
||||
#endif
|
||||
|
||||
#ifndef TENSOR_NUM_FPRS
|
||||
#define TENSOR_NUM_FPRS 8
|
||||
#endif
|
||||
|
||||
#ifndef NUM_THREADS
|
||||
#define NUM_THREADS 8
|
||||
#define NUM_THREADS 4
|
||||
#endif
|
||||
|
||||
#ifndef NUM_BARRIERS
|
||||
@@ -682,4 +699,3 @@
|
||||
#define IMPLEMENTATION_ID 0
|
||||
|
||||
#endif // VX_CONFIG_VH
|
||||
|
||||
|
||||
@@ -136,6 +136,19 @@ inline void vx_wspawn(unsigned num_warps, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 1, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(num_warps), "r"(func_ptr));
|
||||
}
|
||||
|
||||
// Spawn an explicit warp mask. The current warp bit is ignored by hardware.
|
||||
inline void vx_wspawn_mask(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
|
||||
asm volatile (".insn r %0, 6, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(warp_mask), "r"(func_ptr));
|
||||
}
|
||||
|
||||
// Spawn func_ptr on the requested warps, keeping only the low NUM_SCALAR_WARPS
// bits of warp_mask before handing it to vx_wspawn_mask().
inline void vx_spawn_scalar(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
    const unsigned scalar_bits = (1u << NUM_SCALAR_WARPS) - 1u;
    const unsigned selected = warp_mask & scalar_bits;
    vx_wspawn_mask(selected, func_ptr);
}
|
||||
|
||||
// Spawn func_ptr on the requested warps, keeping only the NUM_TENSOR_WARPS bits
// that start at bit NUM_SCALAR_WARPS of warp_mask before calling vx_wspawn_mask().
inline void vx_spawn_tensor(unsigned warp_mask, vx_wspawn_pfn func_ptr) {
    const unsigned tensor_bits = ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
    const unsigned selected = warp_mask & tensor_bits;
    vx_wspawn_mask(selected, func_ptr);
}
|
||||
|
||||
// Split on a predicate
|
||||
inline unsigned vx_split(unsigned predicate) {
|
||||
unsigned ret;
|
||||
@@ -151,7 +164,34 @@ inline void vx_join(unsigned stack_ptr) {
|
||||
// Warp Barrier
|
||||
// Issues the RISCV_CUSTOM0 funct3 = 4 instruction with barried_id in rs1 and a
// warp count in rs2. scalar_warps is num_warps clamped to at most NUM_SCALAR_WARPS.
// NOTE(review): this listing is a rendered diff whose hunk header reads
// "-151,7 +164,34"; the first asm statement below (unclamped num_warps) appears
// to be the removed pre-change line shown next to its replacement rather than a
// second barrier. Confirm against the actual header before relying on it.
__attribute__((convergent))
|
||||
inline void vx_barrier(unsigned barried_id, unsigned num_warps) {
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(num_warps));
|
||||
// Never wait for more warps than the scalar partition contains.
unsigned scalar_warps = (num_warps > NUM_SCALAR_WARPS) ? NUM_SCALAR_WARPS : num_warps;
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barried_id), "r"(scalar_warps));
|
||||
}
|
||||
|
||||
// Barrier domain selectors. vx_barrier_domain() shifts the selector left by
// VX_BARRIER_DOMAIN_SHIFT and ORs it into the barrier id it passes to hardware.
#define VX_BARRIER_DOMAIN_SHIFT 28
|
||||
#define VX_BARRIER_DOMAIN_ALL 0u
|
||||
#define VX_BARRIER_DOMAIN_SCALAR 1u
|
||||
#define VX_BARRIER_DOMAIN_TENSOR 2u
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_domain(unsigned barrier_id, unsigned num_warps, unsigned domain) {
|
||||
unsigned encoded_id = barrier_id | (domain << VX_BARRIER_DOMAIN_SHIFT);
|
||||
asm volatile (".insn r %0, 4, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(encoded_id), "r"(num_warps));
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_scalar(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_SCALAR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_tensor(unsigned barrier_id, unsigned num_warps) {
|
||||
vx_barrier_domain(barrier_id, num_warps, VX_BARRIER_DOMAIN_TENSOR);
|
||||
}
|
||||
|
||||
__attribute__((convergent))
|
||||
inline void vx_barrier_mask(unsigned barrier_id, unsigned warp_mask) {
|
||||
asm volatile (".insn r %0, 7, 0, x0, %1, %2" :: "i"(RISCV_CUSTOM0), "r"(barrier_id), "r"(warp_mask));
|
||||
}
|
||||
|
||||
// Return current thread identifier
|
||||
@@ -203,6 +243,22 @@ inline int vx_num_warps() {
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline int vx_num_scalar_warps() {
|
||||
return NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
inline int vx_num_tensor_warps() {
|
||||
return NUM_TENSOR_WARPS;
|
||||
}
|
||||
|
||||
inline unsigned vx_scalar_warp_mask() {
|
||||
return (1u << NUM_SCALAR_WARPS) - 1u;
|
||||
}
|
||||
|
||||
inline unsigned vx_tensor_warp_mask() {
|
||||
return ((1u << NUM_TENSOR_WARPS) - 1u) << NUM_SCALAR_WARPS;
|
||||
}
|
||||
|
||||
// Return the number of cores per cluster
|
||||
inline int vx_num_cores() {
|
||||
int ret;
|
||||
|
||||
@@ -76,7 +76,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_stub() {
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
@@ -96,7 +96,7 @@ static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_cluster_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
@@ -187,7 +187,7 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg) {
|
||||
// device specs
|
||||
const int NC = vx_num_cores();
|
||||
const int NW = vx_num_warps();
|
||||
const int NW = NUM_SCALAR_WARPS;
|
||||
const int NT = vx_num_threads();
|
||||
// NOTE: assumes divisible
|
||||
const int num_cluster = NC / CORES_PER_CLUSTER;
|
||||
@@ -243,7 +243,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
||||
const int num_full_waves = num_warps_this_core / NW;
|
||||
const int rem_full_warps_in_last_wave = num_warps_this_core % NW;
|
||||
|
||||
const const int offset = cluster_id * num_tasks_this_cluster;
|
||||
const int offset = cluster_id * num_tasks_this_cluster;
|
||||
wspawn_tasks_args_t wspawn_args = {callback, arg, offset, num_full_waves,
|
||||
rem_full_warps_in_last_wave};
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
@@ -289,7 +289,7 @@ void vx_spawn_tasks_cluster(int num_tasks, vx_spawn_tasks_cb callback, void *arg
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
@@ -361,7 +361,7 @@ void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
@@ -515,7 +515,7 @@ void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg) {
|
||||
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NW = NUM_SCALAR_WARPS;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
|
||||
@@ -22,9 +22,9 @@
|
||||
_start:
|
||||
|
||||
# initialize per-thread registers
|
||||
csrr t0, VX_CSR_NUM_WARPS # get num warps
|
||||
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
|
||||
la t1, init_regs_all
|
||||
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
|
||||
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
|
||||
li t0, -1
|
||||
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
|
||||
jal init_regs
|
||||
@@ -35,9 +35,9 @@ _start:
|
||||
jal vx_wspawn_wait
|
||||
|
||||
# initialize TLS for all warps
|
||||
csrr t0, VX_CSR_NUM_WARPS # get num warps
|
||||
li t0, ((1 << NUM_SCALAR_WARPS) - 1) # scalar warp mask
|
||||
la t1, init_tls_all
|
||||
.insn r RISCV_CUSTOM0, 1, 0, x0, t0, t1 # wspawn t0, t1
|
||||
.insn r RISCV_CUSTOM0, 6, 0, x0, t0, t1 # wspawn_mask t0, t1
|
||||
li t0, -1
|
||||
.insn r RISCV_CUSTOM0, 0, 0, x0, t0, x0 # tmc t0
|
||||
call __init_tls
|
||||
@@ -150,4 +150,3 @@ vx_wspawn_wait:
|
||||
.weak __dso_handle
|
||||
__dso_handle:
|
||||
.long 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user