Add Wu TMEM FlashAttention validation cases

This commit is contained in:
Zhongdi LUO
2026-06-24 06:26:30 +00:00
parent ed16541c8e
commit d6fbd447c3
49 changed files with 2395 additions and 26 deletions

View File

@@ -0,0 +1,7 @@
PROJECT = blackwell_multi_tc
VX_SRCS = kernel.cpp
OPTS ?= -n1
include ../common.mk

View File

@@ -0,0 +1,188 @@
#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_spawn.h>
#define RISCV_CUSTOM3 0x7B
#define DEV_SMEM_START_ADDR 0xff000000u
#define MAX_CORES 4
#define MAX_WARPS 8
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
// C linkage keeps these symbol names unmangled; tensor_kernel_entry() below
// refers to g_a_row, g_c_row and g_status by name from inline assembly.
extern "C" {
// Source row for the A operand: 8 words, 32-byte aligned.
volatile uint32_t g_a_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3c003c00u)}; // two fp16 1.0 values per word
// Source row for the B operand; main() replicates it into shared memory.
volatile uint32_t g_b_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x40004000u)}; // two fp16 2.0 values per word
// Source row for the C (accumulator) operand.
volatile uint32_t g_c_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3f800000u)}; // one fp32 1.0 value per word
// One word per (core, warp) slot, MAX_CORES * MAX_WARPS slots in total.
volatile uint32_t g_result[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
// Per-(core, warp) status words; main() clears core 0's slots and later polls
// them for the 0x600d marker.
volatile uint32_t g_status[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
}
#undef BW_REP2
#undef BW_REP4
#undef BW_REP8
// Emits the CUSTOM3 R-type instruction with func3=2, func7=0 and rd=x0,
// carrying addr_tmem in rs1 and addr_gmem in rs2.
// NOTE(review): going by the name and the call sites, this presumably queues a
// copy of one row from global memory into TMEM -- confirm against the ISA notes.
static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
  asm volatile(".insn r %[opc], 2, 0, x0, %[tmem], %[gmem]"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM3), [tmem] "r"(addr_tmem),
                 [gmem] "r"(addr_gmem)
               : "memory");
}
// Emits the CUSTOM3 R-type instruction with func3=3, func7=0 and all register
// fields set to x0.
// NOTE(review): presumably blocks until previously issued tcgen05_cp copies
// have completed -- confirm against the ISA notes.
static inline void tcgen05_cp_wait() {
  asm volatile(".insn r %[opc], 3, 0, x0, x0, x0"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM3)
               : "memory");
}
// Emits the CUSTOM3 R-type instruction with func3=0, func7=0. All three
// register fields are inputs: the rd field carries addr_tmem_c, rs1 carries
// addr_tmem_a and rs2 carries addr_smem_b.
// NOTE(review): presumably launches the tensor-core matrix multiply-accumulate
// C += A * B -- confirm against the ISA notes.
static inline void bwgmma(uint32_t addr_tmem_c,
                          uint32_t addr_tmem_a,
                          uint32_t addr_smem_b) {
  asm volatile(".insn r %[opc], 0, 0, %[acc], %[lhs], %[rhs]"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM3), [acc] "r"(addr_tmem_c),
                 [lhs] "r"(addr_tmem_a), [rhs] "r"(addr_smem_b)
               : "memory");
}
// Emits the CUSTOM3 R-type instruction with func3=1, func7=0 and all register
// fields set to x0.
// NOTE(review): presumably blocks until the preceding bwgmma has finished --
// confirm against the ISA notes.
static inline void bwgmma_wait() {
  asm volatile(".insn r %[opc], 1, 0, x0, x0, x0"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM3)
               : "memory");
}
// Emits the CUSTOM3 R-type instruction with func3=4, func7=0, rs1=addr_tmem and
// rs2=x0; the destination is a floating-point register whose value is returned.
// NOTE(review): presumably loads one fp32 element from TMEM -- confirm against
// the ISA notes.
static inline float tcgen05_ld_f32(uint32_t addr_tmem) {
  float loaded;
  asm volatile(".insn r %[opc], 4, 0, %[dst], %[src], x0"
               : [dst] "=f"(loaded)
               : [opc] "i"(RISCV_CUSTOM3), [src] "r"(addr_tmem)
               : "memory");
  return loaded;
}
// Returns the raw 32-bit object representation of `value`.
// The bytes are copied rather than read through a union: this file is compiled
// as C++, where reading a union member other than the one last written is
// undefined behaviour. The builtin avoids a new header dependency and is
// lowered to a plain register move by GCC/Clang.
static inline uint32_t f32_bits(float value) {
  uint32_t bits;
  __builtin_memcpy(&bits, &value, sizeof bits);
  return bits;
}
// Empty C-linkage definition of vx_perf_dump(); the body does nothing.
// NOTE(review): presumably stubs out the runtime's performance-dump hook so this
// bare-metal test does not pull it in -- confirm against the kernel library.
extern "C" void vx_perf_dump() {}
// Entry point handed to vx_spawn_tensor() by main(). The function is naked (no
// prologue/epilogue) and never returns: it ends in an infinite self-branch, so
// x1..x7 are used freely as scratch registers.
//
// Sequence, per warp:
//   1. stage g_a_row / g_c_row into the warp's TMEM A / C regions,
//   2. run one CUSTOM3 func3=0 operation against the shared-memory base,
//   3. store 0x600d into g_status[warp_id],
//   4. issue a CUSTOM0 instruction and spin.
// The CUSTOM3 encodings used here match the tcgen05_cp (func3=2),
// tcgen05_cp_wait (func3=3), bwgmma (func3=0) and bwgmma_wait (func3=1)
// wrappers defined earlier in this file.
extern "C" void __attribute__((naked, noinline, used)) tensor_kernel_entry() {
asm volatile(
// x5 = warp id.
"csrr x5, %[csr_wid]\n\t"
"slli x1, x5, 11\n\t" // tmem_a = wid * 0x800
"addi x2, x1, 1024\n\t" // tmem_c = tmem_a + 0x400
// x6 / x3 = addresses of the A and C source rows in global memory.
"la x6, g_a_row\n\t"
"la x3, g_c_row\n\t"
// x7 = byte offset into the region; loop covers 0..1023 in 32-byte steps.
"li x7, 0\n\t"
"1:\n\t"
// func3=2 with rs1 = tmem_a + offset, rs2 = &g_a_row.
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
// func3=2 with rs1 = tmem_c + offset, rs2 = &g_c_row.
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 32\n\t"
"li x4, 1024\n\t"
"blt x7, x4, 1b\n\t"
// func3=3: same encoding as tcgen05_cp_wait().
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// func3=0 with rd field = tmem_c (x2), rs1 = tmem_a (x1), rs2 = SMEM base;
// same operand layout as bwgmma(). func3=1 matches bwgmma_wait().
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// g_status[warp_id] = 0x600d. The index has no core offset, which matches
// main() polling only core 0's slots.
"csrr x5, %[csr_wid]\n\t"
"slli x5, x5, 2\n\t"
"la x6, g_status\n\t"
"add x6, x6, x5\n\t"
"li x7, 0x600d\n\t"
"sw x7, 0(x6)\n\t"
// NOTE(review): CUSTOM0 func3=0 with rs1=x0 is presumably the Vortex TMC
// instruction with a zero mask (halts the warp) -- confirm. The self-branch
// below is the fallback if execution continues.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR)
: "memory");
}
// Argument bundle consumed by kernel_body(); every field points at volatile
// 32-bit words.
struct kernel_arg_t {
volatile uint32_t *a_row; // source address passed to tcgen05_cp for TMEM A
volatile uint32_t *b_row; // not read by kernel_body() in the visible code
volatile uint32_t *c_row; // source address passed to tcgen05_cp for TMEM C
volatile uint32_t *result; // receives the raw bits loaded back from TMEM C
volatile uint32_t *status; // progress / pass-fail marker, indexed per (core, warp)
};
// C++ version of the per-warp tensor test: stage A/C rows into TMEM, run one
// bwgmma against shared memory, read back the first C element and record
// pass/fail.
//
// task_id is split into warp_id (task_id / threads-per-warp) and lane_id
// (task_id % threads-per-warp); only lane 0 of warps below vx_num_warps() does
// any work.
//
// Status values written to status[core_id * MAX_WARPS + warp_id]:
//   0x100  before the TMEM copies are issued
//   0x200  after the copies have been waited on
//   0x600d result word equals 0x42820000 (the bit pattern of fp32 65.0)
//   0xe000 result word differs
//
// NOTE(review): this function is not referenced by any code visible in this
// file (main() spawns tensor_kernel_entry instead) -- confirm whether it is
// still needed.
static void __attribute__((convergent)) kernel_body(int task_id,
kernel_arg_t *__UNIFORM__ arg) {
const int warp_id = task_id / vx_num_threads();
const int lane_id = task_id % vx_num_threads();
if (lane_id != 0)
return;
const int num_warps = vx_num_warps();
if (warp_id >= num_warps)
return;
volatile uint32_t *a_row = arg->a_row;
volatile uint32_t *c_row = arg->c_row;
volatile uint32_t *result = arg->result;
volatile uint32_t *status = arg->status;
const int core_id = vx_core_id();
const int status_idx = core_id * MAX_WARPS + warp_id;
const uint32_t smem_b = DEV_SMEM_START_ADDR;
// Each warp owns a 0x800-byte TMEM window: A in the first 0x400 bytes, C in
// the second.
const uint32_t base = static_cast<uint32_t>(warp_id) * 0x800u;
const uint32_t tmem_a = base + 0x000u;
const uint32_t tmem_c = base + 0x400u;
status[status_idx] = 0x100u;
// 32 steps of 32 bytes cover the 0x400-byte A and C regions; the same source
// row is used for every step.
for (int frag = 0; frag < 32; ++frag) {
const uint32_t offset = static_cast<uint32_t>(frag * 32);
tcgen05_cp(tmem_a + offset, reinterpret_cast<uint32_t>(a_row));
tcgen05_cp(tmem_c + offset, reinterpret_cast<uint32_t>(c_row));
}
tcgen05_cp_wait();
status[status_idx] = 0x200u;
bwgmma(tmem_c, tmem_a, smem_b);
bwgmma_wait();
// Only the element at the start of the C region is checked.
result[status_idx] = f32_bits(tcgen05_ld_f32(tmem_c));
status[status_idx] = (result[status_idx] == 0x42820000u) ? 0x600du : 0xe000u;
}
int main() {
const int core_id = vx_core_id();
if (core_id != 0 || vx_warp_id() != 0 || vx_thread_id() != 0)
return 0;
volatile uint32_t *smem_b_ptr =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (int frag = 0; frag < 32; ++frag) {
const int row = frag * 8;
for (int i = 0; i < 8; ++i)
smem_b_ptr[row + i] = g_b_row[i];
}
for (int i = 0; i < MAX_WARPS; ++i)
g_status[core_id * MAX_WARPS + i] = 0;
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_kernel_entry);
for (int spin = 0; spin < 100000; ++spin) {
int done = 1;
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i)
done &= (g_status[core_id * MAX_WARPS + i] == 0x600du);
if (done)
break;
}
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) {
if (g_status[core_id * MAX_WARPS + i] != 0x600du)
return 1;
}
return 0;
}

View File

@@ -1,3 +1,6 @@
# Get the directory where this common.mk file is located
COMMON_MK_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
XLEN ?= 32 XLEN ?= 32
TOOLDIR ?= /opt TOOLDIR ?= /opt
@@ -7,7 +10,7 @@ RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv64-gnu-toolchain
VX_CFLAGS += -march=rv64imafd -mabi=lp64d VX_CFLAGS += -march=rv64imafd -mabi=lp64d
STARTUP_ADDR ?= 0x180000000 STARTUP_ADDR ?= 0x180000000
else else
RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv-gnu-toolchain RISCV_TOOLCHAIN_PATH ?= $(realpath $(COMMON_MK_DIR)../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
VX_CFLAGS += -march=rv32imaf -mabi=ilp32f VX_CFLAGS += -march=rv32imaf -mabi=ilp32f
STARTUP_ADDR ?= 0x80000000 STARTUP_ADDR ?= 0x80000000
endif endif
@@ -18,7 +21,7 @@ RISCV_SYSROOT ?= $(RISCV_TOOLCHAIN_PATH)/$(RISCV_PREFIX)
VORTEX_KN_PATH ?= $(realpath ../../lib) VORTEX_KN_PATH ?= $(realpath ../../lib)
GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini) GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini)
LLVM_VORTEX ?= $(TOOLDIR)/llvm-vortex LLVM_VORTEX ?= $(realpath $(COMMON_MK_DIR)../../toolchain/llvm-r8)
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT) LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH) LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)

View File

@@ -0,0 +1,9 @@
PROJECT = hgemm_validation
VX_SRCS = kernel.cpp
OPTS ?= -n1
include ../common.mk
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
cp $< $@

View File

@@ -0,0 +1,14 @@
# hgemm_validation
Small HGEMM correctness validation for the 4-lane Blackwell tensor-core path.
The test runs one tensor warp on a single 16x16x32 HGEMM tile:
- A is fp16 1.0
- B is fp16 2.0
- C starts as fp32 1.0
- expected output is fp32 65.0 for all 16x16 C elements
Scalar warp 0 initializes the shared-memory B tile, spawns only the first tensor
warp, waits for completion, and checks all 256 output words copied back from
TMEM. Success reports `WU_CASE_PASS` through `tohost`.

View File

@@ -0,0 +1,119 @@
#include "../wu_arch_cases/common_wu_min.h"
#define DEV_SMEM_START_ADDR 0xff000000u
#define HGEMM_VALIDATION_BASE 0x7700u
#define HGEMM_M 16u
#define HGEMM_N 16u
#define HGEMM_K 32u
#define HGEMM_TILE_BYTES 1024u
#define HGEMM_FRAGMENT_BYTES 16u
#define HGEMM_FRAGMENTS (HGEMM_TILE_BYTES / HGEMM_FRAGMENT_BYTES)
#define HGEMM_OUT_WORDS (HGEMM_M * HGEMM_N)
#define HGEMM_EXPECTED 0x42820000u
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
// C linkage keeps the names unmangled; hgemm_validation_worker() loads the
// addresses of g_hgemm_a_frag, g_hgemm_c_frag and g_hgemm_out with `la`.
extern "C" {
// One 16-byte fragment of A: each word packs two fp16 1.0 values (0x3c00).
volatile uint32_t g_hgemm_a_frag[4] __attribute__((aligned(16))) = {
BW_REP4(0x3c003c00u)};
// One 16-byte fragment of B: each word packs two fp16 2.0 values (0x4000).
// wu_main() replicates it across the shared-memory B tile.
volatile uint32_t g_hgemm_b_frag[4] __attribute__((aligned(16))) = {
BW_REP4(0x40004000u)};
// One 16-byte fragment of C: each word is fp32 1.0.
volatile uint32_t g_hgemm_c_frag[4] __attribute__((aligned(16))) = {
BW_REP4(0x3f800000u)};
// Destination of the worker's copy-back; wu_main() zeroes it before the run
// and compares every word against HGEMM_EXPECTED afterwards.
volatile uint32_t g_hgemm_out[HGEMM_OUT_WORDS] __attribute__((aligned(16)));
}
#undef BW_REP2
#undef BW_REP4
// Tensor-warp entry point spawned by wu_main(). Naked, never returns (ends in
// an infinite self-branch), and uses x1..x7 as scratch.
//
// Sequence:
//   1. stage g_hgemm_a_frag into TMEM bytes [0, tile_bytes) and g_hgemm_c_frag
//      into TMEM bytes [tile_bytes, 2*tile_bytes), one fragment per step,
//   2. wait, run CUSTOM3 func3=0 with (rd field = C base, rs1 = A base,
//      rs2 = shared-memory base), wait,
//   3. issue CUSTOM3 func3=6 once per fragment with rs1 = TMEM C address and
//      rs2 = matching address inside g_hgemm_out, then wait,
//   4. store (done_base | warp_id) into g_seen[warp_id],
//   5. issue a CUSTOM0 instruction and spin.
// NOTE(review): func3=2 / func3=6 are presumably the global->TMEM copy and the
// TMEM->global copy-back (tcgen05_cp / tcgen05_cb) -- confirm against the ISA.
extern "C" void __attribute__((naked, noinline, used)) hgemm_validation_worker() {
asm volatile(
// x1 = TMEM A base (0), x2 = TMEM C base (tile_bytes).
"li x1, 0\n\t"
"li x2, %[tile_bytes]\n\t"
"la x6, g_hgemm_a_frag\n\t"
"la x3, g_hgemm_c_frag\n\t"
// x7 = byte offset, stepping by frag_bytes up to tile_bytes.
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// Wait for the copies, then multiply and wait again.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Copy-back loop: x3 = &g_hgemm_out, x1 is reused as the global destination.
"la x3, g_hgemm_out\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
"add x1, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = done_base | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// NOTE(review): CUSTOM0 func3=0 with rs1=x0 is presumably TMC with a zero
// mask (halts the warp) -- confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"3: j 3b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[done_base] "i"(HGEMM_VALIDATION_BASE),
[tile_bytes] "i"(HGEMM_TILE_BYTES),
[frag_bytes] "i"(HGEMM_FRAGMENT_BYTES)
: "memory");
}
// Scalar-side driver for the single-tile HGEMM validation.
//
// Only the leader (per wu_is_leader()) runs the test; everyone else returns 0.
// The leader zeroes g_hgemm_out, replicates g_hgemm_b_frag across the
// shared-memory B tile, spawns hgemm_validation_worker on the first tensor warp
// (warp id NUM_SCALAR_WARPS), waits for its completion marker and then checks
// every output word against HGEMM_EXPECTED.
//
// Returns 0 on pass. On failure returns 1 after reporting:
//   0x09  the worker never posted its completion marker
//   0x20  an output word mismatched (g_aux[0] = index, g_aux[1] = value read)
extern "C" int wu_main() {
  if (!wu_is_leader()) {
    return 0;
  }
  wu_case_reset();
  for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
    g_hgemm_out[i] = 0;
  }
  volatile uint32_t *smem_b =
      reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
  for (uint32_t frag = 0; frag < HGEMM_FRAGMENTS; ++frag) {
    const uint32_t row = frag * 4u;
    for (uint32_t i = 0; i < 4u; ++i) {
      smem_b[row + i] = g_hgemm_b_frag[i];
    }
  }
  // Order the output reset and the B-tile stores before the tensor warp is
  // started, matching the case10/case11 drivers which fence at this point.
  asm volatile("fence rw, rw" ::: "memory");
  const uint32_t tensor_wid = NUM_SCALAR_WARPS;
  vx_spawn_tensor(1u << tensor_wid, hgemm_validation_worker);
  if (wu_wait_seen_mask(1u << tensor_wid, HGEMM_VALIDATION_BASE) != 0) {
    wu_case_fail(0x09u);
    return 1;
  }
  for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
    if (g_hgemm_out[i] != HGEMM_EXPECTED) {
      g_aux[0] = i;
      g_aux[1] = g_hgemm_out[i];
      wu_case_fail(0x20u);
      return 1;
    }
  }
  wu_case_pass();
  return 0;
}

View File

@@ -7,7 +7,19 @@ CASES := \
case05_tensor_barrier \ case05_tensor_barrier \
case06_masked_barrier \ case06_masked_barrier \
case07_tensor_csr_tmc \ case07_tensor_csr_tmc \
case08_tensor_lsu_optional case08_tensor_lsu_optional \
case09_scalar_tmem_ldst \
case10_tensor_scalar_tmem_handoff \
case11_scalar_tmem_softmax_stage \
case12_flash_pv_accum \
case12_1_scalar_tmem_cb_probe \
case12_2_flash_pv_p_probe \
case12_3_scalar_tmem_lane_store \
case13_flash_pv_two_warps \
case14_flash_pv_k64 \
case15_flash_softmax_pv_stage \
case16_flash_full_pipeline \
case17_flash_exp_softmax_probe
SMOKE_CASES := \ SMOKE_CASES := \
case00_boot_scalar \ case00_boot_scalar \

View File

@@ -13,9 +13,27 @@ This directory contains small bare-metal kernels for incremental Wu architecture
- `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps. - `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps.
- `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior. - `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior.
- `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower. - `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower.
- `case09_scalar_tmem_ldst`: scalar warp direct TMEM store/load path for the banked TMEM softmax mechanism.
- `case10_tensor_scalar_tmem_handoff`: tensor BWGMMA result in TMEM C observed by scalar TMEM loads.
- `case11_scalar_tmem_softmax_stage`: scalar TMEM transform written back for tensor-side copy-out.
- `case12_flash_pv_accum`: one tensor warp consumes scalar-written `P` in TMEM A for `O = O + P @ V`.
- `case12_1_scalar_tmem_cb_probe`: scalar-written TMEM A rows copied back by tensor `tcgen05_cb`.
- `case12_2_flash_pv_p_probe`: case12 P-write diagnostic using the same scalar fill path.
- `case12_3_scalar_tmem_lane_store`: scalar TMEM store lane-coalesced fragment write diagnostic.
- `case13_flash_pv_two_warps`: both tensor warps consume scalar-written `P` tiles for two row blocks.
- `case14_flash_pv_k64`: two consecutive `K=32` BWGMMA steps accumulate into one PV output.
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker. Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.
## Debug Notes
- `TMC_DEBUG_NOTES.md`: reusable notes for diagnosing lane-mask/TMC operand
bugs, including the `case12_3_scalar_tmem_lane_store` failure where a
lane0-only register value was consumed under an all-lane mask.
## Build ## Build
Use the suite Makefile from this directory: Use the suite Makefile from this directory:

View File

@@ -0,0 +1,174 @@
# TMC Operand Debug Notes
This note records a failure pattern found while debugging
`case12_3_scalar_tmem_lane_store`.
## Symptom
The simulator keeps producing trace output, but the kernel makes no forward
progress:
- a tensor worker repeatedly polls a ready flag such as `g_case_mem[0]`;
- the poll load always returns the old value;
- the scalar warp has already committed the instructions that precede the ready flag store;
- the expected ready flag store never appears in the LSU trace.
For `case12_3_scalar_tmem_lane_store`, the tensor worker was looping at
`PC=0x80000034..0x80000048`, repeatedly reading `g_case_mem[0]` at
`0x20000430` as `0`. The scalar warp committed the scalar TMEM store and the
following `TMC`, but never issued the next ready flag store at `PC=0x8000015c`.
## Root Cause Pattern
Vortex scalar registers are lane-local. A value computed while only lane 0 is
active is only known to be valid in lane 0. If that register is later consumed
by an instruction while multiple lanes are active, the inactive lanes may hold
stale or zero values.
This is especially easy to miss around `TMC`:
```asm
# Bad when xN was defined while only lane 0 was active.
vx_tmc all_lanes
...
vx_tmc xN
```
The failed `case12_3` sequence used a C operand for `1u` that the compiler
materialized before switching to all lanes. Runtime trace showed the source
register for the final `TMC(1)` as:
```text
rs1_data={0x0, 0x0, 0x0, 0x1}
```
After that `TMC`, the scalar warp did not fetch the ready flag store.
## Correct Pattern
When a value is consumed under an all-lane mask, define that value under the
same all-lane mask unless the instruction semantics explicitly use only lane 0.
For switching back to lane 0 from an all-lane region, keep the immediate
materialization and the `TMC` adjacent inside the all-lane region:
```asm
vx_tmc all_lanes
...
fence rw, rw
li t2, 1
vx_tmc t2
```
The expected trace at the final `TMC` is:
```text
rs1_data={0x1, 0x1, 0x1, 0x1}
```
The library helper `vx_tmc_one()` follows this pattern because it emits
`li a0, 1` and `vx_tmc a0` in the same volatile asm block. It is safe when
called while all lanes are active.
## Fast Log Checks
Use the simulation log to distinguish a simulator stall from a kernel-level
wait loop:
```sh
stat -c '%s %y' chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
tail -n 80 chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
```
If the file is still growing but the tail repeatedly shows the same PC range,
the simulator is alive and the kernel is probably spinning.
For a ready flag handoff, check both the polling load and the producer store:
```sh
rg -n "0x20000430|0x9900|wid=0, PC=0x8000014|wid=0, PC=0x8000015" \
chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
```
Interpretation:
- load repeatedly returns `0` and no later `0x9900` store exists: producer did
not reach the ready flag store;
- ready flag store exists but consumer still reads old data: investigate memory
ordering or address aliasing;
- producer stops immediately after a `TMC`: inspect that `TMC` source register
across all lanes.
## Dump Checks
Check that the final lane-mask narrowing defines `1` immediately before `TMC`
inside the all-lane region:
```sh
rg -n "li\\s+.*1|vx_tmc" \
kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.radiance.dump
```
The fixed `case12_3` dump has:
```text
li t2, 1
vx_tmc t2
auipc a3, 1
sw a5, ...
```
## Case12-Case15 Audit
The following cases were checked after fixing `case12_3`:
- `case12_flash_pv_accum`
- `case12_2_flash_pv_p_probe`
- `case13_flash_pv_two_warps`
- `case14_flash_pv_k64`
- `case15_flash_softmax_pv_stage`
They switch to all lanes with `vx_tmc(wu_bw_all_lanes_mask())` and switch back
with `vx_tmc_one()`. Their dumps show the safe adjacent sequence:
```text
li a0, 1
vx_tmc a0
```
No analogous source change is needed for those cases.
## Scalar TMEM Fill Lane Coverage
`wu_bw_fill_tmem_tile()` must run with all fragment lanes active. Scalar TMEM
store writes the active lane's source word into the matching TMEM fragment word:
```text
TMEM[addr].word[lane] = rs2_data[lane]
```
Therefore a fill loop executed with `tmask=0001` only initializes word 0 of
each 16-byte fragment. Word 1 and later can retain old TMEM data. In
`case12_1_scalar_tmem_cb_probe`, this showed up as a normal completion path
followed by verification failure `0x14`; `g_aux[0]` was `1`, and `g_aux[1]`
held the stale copied-back word.
The correct pattern is:
```c++
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
```
The dump should show the fill loop bracketed by all-lane and lane-0 masks:
```text
li a5, 15
vx_tmc a5
...
vx_cmov zero, a0, a5, a2
...
li a0, 1
vx_tmc a0
```

View File

@@ -3,6 +3,8 @@ VX_SRCS = kernel.cpp
VX_CFLAGS += -I.. VX_CFLAGS += -I..
VORTEX_KN_PATH ?= $(realpath ../../../lib) VORTEX_KN_PATH ?= $(realpath ../../../lib)
GEMMINI_SW_PATH ?= $(realpath ../../../lib/gemmini) GEMMINI_SW_PATH ?= $(realpath ../../../lib/gemmini)
LLVM_VORTEX ?= $(realpath ../../../../toolchain/llvm-r8)
RISCV_TOOLCHAIN_PATH ?= $(realpath ../../../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
OPTS ?= -n1 OPTS ?= -n1
include ../../common.mk include ../../common.mk

View File

@@ -0,0 +1,3 @@
PROJECT = wu_arch_case09_scalar_tmem_ldst
include ../case.mk

View File

@@ -0,0 +1,7 @@
# case09_scalar_tmem_ldst
Smoke test for scalar warp direct TMEM access.
The leader scalar lane writes a lane-distinct word vector to a TMEM row with
`CUSTOM1 func7=0x30 func3=1`, reads it back with `CUSTOM1 func7=0x30 func3=0`,
and reports pass only if the lane 0 value matches.

View File

@@ -0,0 +1,39 @@
#include "common_wu_min.h"
// Emits the CUSTOM1 R-type instruction with func3=0, func7=0x30, rs1=addr and
// rs2=x0, and returns the value the instruction leaves in rd.
// NOTE(review): by name this is the scalar-warp TMEM load; the addressing unit
// of `addr` is not visible here -- confirm against the ISA/RTL notes.
static inline uint32_t wu_scalar_tmem_ld(uint32_t addr) {
  uint32_t loaded;
  asm volatile(".insn r %[opc], 0, 0x30, %[rd], %[rs1], x0"
               : [rd] "=r"(loaded)
               : [opc] "i"(RISCV_CUSTOM1), [rs1] "r"(addr)
               : "memory");
  return loaded;
}
// Emits the CUSTOM1 R-type instruction with func3=1, func7=0x30, rd=x0,
// rs1=addr and rs2=value.
// NOTE(review): by name this is the scalar-warp TMEM store; the addressing unit
// of `addr` is not visible here -- confirm against the ISA/RTL notes.
static inline void wu_scalar_tmem_st(uint32_t addr, uint32_t value) {
  asm volatile(".insn r %[opc], 1, 0x30, x0, %[rs1], %[rs2]"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM1), [rs1] "r"(addr), [rs2] "r"(value)
               : "memory");
}
// Scalar TMEM store/load round trip.
//
// The leader writes (0x5a090000 | wu_tid()) to TMEM address 0x24 with
// wu_scalar_tmem_st(), reads the same address back with wu_scalar_tmem_ld() and
// passes only when the two values are equal. On mismatch the observed value is
// left in g_aux[0] and failure code 0x09 is reported.
//
// Returns 0 for non-leaders and on pass, 1 on failure.
extern "C" int wu_main() {
  if (!wu_is_leader()) {
    return 0;
  }
  wu_case_reset();

  const uint32_t tmem_addr = 0x24u;
  const uint32_t pattern = 0x5a090000u | wu_tid();
  wu_scalar_tmem_st(tmem_addr, pattern);
  const uint32_t readback = wu_scalar_tmem_ld(tmem_addr);

  if (readback == pattern) {
    wu_case_pass();
    return 0;
  }
  g_aux[0] = readback;
  wu_case_fail(0x09u);
  return 1;
}

View File

@@ -0,0 +1,3 @@
PROJECT = wu_arch_case10_tensor_scalar_tmem_handoff
include ../case.mk

View File

@@ -0,0 +1,8 @@
# case10_tensor_scalar_tmem_handoff
Validates the tensor-to-scalar TMEM handoff needed by FlashAttention.
The tensor warp initializes TMEM A/C, runs one BWGMMA, and leaves the result in
TMEM C. The scalar leader waits for the tensor warp and then reads several C
fragments through the scalar TMEM load instruction. Each lane is expected to
observe the HGEMM value `0x42820000` (fp32 65.0, i.e. 1.0 + 32 × 1.0 × 2.0).

View File

@@ -0,0 +1,111 @@
#include "../common_wu_min.h"
#define DEV_SMEM_START_ADDR 0xff000000u
#define WU_CASE_TMEM_HANDOFF_BASE 0x7700u
#define WU_TMEM_TILE_BYTES 1024u
#define WU_TMEM_FRAGMENT_BYTES 16u
#define WU_TMEM_C_BYTE_BASE 1024u
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
static_assert(NUM_TENSOR_WARPS >= 1, "case10 requires at least one tensor warp");
static_assert(NUM_THREADS == 4, "case10 expects the 4-lane Blackwell tensor core");
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
// C linkage keeps the names unmangled; tensor_case10_worker() loads the
// addresses of g_case10_a_row and g_case10_c_row with `la`.
extern "C" {
// A source fragment: each word packs two fp16 1.0 values (0x3c00).
volatile uint32_t g_case10_a_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3c003c00u)};
// B source fragment: each word packs two fp16 2.0 values (0x4000); wu_main()
// replicates it across the shared-memory tile.
volatile uint32_t g_case10_b_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x40004000u)};
// C source fragment: each word is fp32 1.0.
volatile uint32_t g_case10_c_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3f800000u)};
}
#undef BW_REP2
#undef BW_REP4
// Emits the CUSTOM1 R-type instruction with func3=0, func7=0x30,
// rs1=frag_addr and rs2=x0, and returns the value left in rd.
// wu_main() passes a fragment index (byte address / WU_TMEM_FRAGMENT_BYTES).
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
  uint32_t loaded;
  asm volatile(".insn r %[opc], 0, 0x30, %[rd], %[rs1], x0"
               : [rd] "=r"(loaded)
               : [opc] "i"(RISCV_CUSTOM1), [rs1] "r"(frag_addr)
               : "memory");
  return loaded;
}
// Tensor-warp entry point for case10. Naked, never returns (ends in an
// infinite self-branch), and uses x1..x7 as scratch.
//
// Sequence:
//   1. x1 = (warp_id - NUM_SCALAR_WARPS) << 11 is this warp's TMEM A base,
//      x2 = x1 + 1024 its TMEM C base,
//   2. stage g_case10_a_row / g_case10_c_row into every 16-byte step of the
//      A / C regions (CUSTOM3 func3=2), then wait (func3=3),
//   3. run CUSTOM3 func3=0 with (rd field = C base, rs1 = A base,
//      rs2 = shared-memory base), then wait (func3=1),
//   4. store (handoff_base | warp_id) into g_seen[warp_id],
//   5. issue a CUSTOM0 instruction and spin.
// The result is left in TMEM C for the scalar warp to read.
extern "C" void __attribute__((naked, noinline, used)) tensor_case10_worker() {
asm volatile(
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, 1024\n\t"
"la x6, g_case10_a_row\n\t"
"la x3, g_case10_c_row\n\t"
// x7 = byte offset, 0 .. tile_bytes in 16-byte steps.
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = handoff_base | warp_id (x5 still holds the warp id).
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[handoff_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// NOTE(review): CUSTOM0 func3=0 with rs1=x0 is presumably TMC with a zero
// mask (halts the warp) -- confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[handoff_base] "i"(WU_CASE_TMEM_HANDOFF_BASE),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[tile_bytes] "i"(WU_TMEM_TILE_BYTES)
: "memory");
}
// Scalar-side driver for the tensor -> scalar TMEM handoff check.
//
// The leader fills the shared-memory B tile from g_case10_b_row, fences,
// spawns tensor_case10_worker on tensor warp NUM_SCALAR_WARPS and waits for its
// handoff marker. It then reads the first eight C fragments through
// wu_scalar_tmem_ld() and expects WU_TMEM_EXPECTED_HGEMM from each.
//
// Returns 0 for non-leaders and on pass. On failure returns 1 after reporting:
//   0x10  the worker never posted its handoff marker
//   0x11  a fragment mismatched (g_aux[0] = fragment offset, g_aux[1] = value)
extern "C" int wu_main() {
  if (!wu_is_leader()) {
    return 0;
  }
  wu_case_reset();

  volatile uint32_t *smem_words =
      reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
  const uint32_t tile_words = WU_TMEM_TILE_BYTES / sizeof(uint32_t);
  for (uint32_t word = 0; word < tile_words; ++word) {
    smem_words[word] = g_case10_b_row[word & 3u];
  }
  asm volatile("fence rw, rw" ::: "memory");

  const uint32_t worker_mask = 1u << NUM_SCALAR_WARPS;
  vx_spawn_tensor(worker_mask, tensor_case10_worker);
  if (wu_wait_seen_mask(worker_mask, WU_CASE_TMEM_HANDOFF_BASE) != 0) {
    wu_case_fail(0x10u);
    return 1;
  }

  // Scalar TMEM loads take a fragment index, not a byte address.
  const uint32_t first_c_frag = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
  uint32_t frag = 0;
  while (frag < 8) {
    const uint32_t seen = wu_scalar_tmem_ld(first_c_frag + frag);
    if (seen != WU_TMEM_EXPECTED_HGEMM) {
      g_aux[0] = frag;
      g_aux[1] = seen;
      wu_case_fail(0x11u);
      return 1;
    }
    ++frag;
  }
  wu_case_pass();
  return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = wu_arch_case11_scalar_tmem_softmax_stage
include ../case.mk

View File

@@ -0,0 +1,10 @@
# case11_scalar_tmem_softmax_stage
Validates the FlashAttention softmax-stage TMEM path.
The tensor warp writes HGEMM results into TMEM C. The scalar warp reads TMEM C,
performs a deterministic lane-wise transform, and writes all four lanes back
into one TMEM A-region fragment. The tensor warp then uses TCGEN05_CB to copy
that TMEM fragment to global memory so scalar code can verify that tensor-side
TMEM reads observe the scalar writeback with the correct lane ordering and
write mask.

View File

@@ -0,0 +1,200 @@
#include "../common_wu_min.h"
#define DEV_SMEM_START_ADDR 0xff000000u
#define WU_CASE_TMEM_STAGE_HGEMM_BASE 0x7800u
#define WU_CASE_TMEM_STAGE_VERIFY_BASE 0x7900u
#define WU_TMEM_TILE_BYTES 1024u
#define WU_TMEM_FRAGMENT_BYTES 16u
#define WU_TMEM_C_BYTE_BASE 1024u
#define WU_TMEM_SCALAR_WRITE_BYTE_ADDR 512u
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
#define WU_TMEM_STAGE_VALUE_BASE 0x42820001u
static_assert(NUM_TENSOR_WARPS >= 1, "case11 requires at least one tensor warp");
static_assert(NUM_THREADS == 4, "case11 expects the 4-lane Blackwell tensor core");
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
// C linkage keeps the names unmangled; tensor_case11_worker() loads the
// addresses of g_case11_a_row, g_case11_c_row and g_case11_out with `la`.
extern "C" {
// A source fragment: each word packs two fp16 1.0 values (0x3c00).
volatile uint32_t g_case11_a_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3c003c00u)};
// B source fragment: each word packs two fp16 2.0 values (0x4000); wu_main()
// replicates it across the shared-memory tile.
volatile uint32_t g_case11_b_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x40004000u)};
// C source fragment: each word is fp32 1.0.
volatile uint32_t g_case11_c_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3f800000u)};
// Destination of the worker's CUSTOM3 func3=6 copy of the scalar-written
// TMEM fragment.
volatile uint32_t g_case11_out[4] __attribute__((aligned(16)));
// Value each scalar thread read from TMEM C, indexed by wu_tid().
volatile uint32_t g_case11_scalar_seen[4] __attribute__((aligned(16)));
}
#undef BW_REP2
#undef BW_REP4
// Emits the CUSTOM1 R-type instruction with func3=0, func7=0x30,
// rs1=frag_addr and rs2=x0, and returns the value left in rd.
// wu_main() passes a fragment index (byte address / WU_TMEM_FRAGMENT_BYTES).
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
  uint32_t loaded;
  asm volatile(".insn r %[opc], 0, 0x30, %[rd], %[rs1], x0"
               : [rd] "=r"(loaded)
               : [opc] "i"(RISCV_CUSTOM1), [rs1] "r"(frag_addr)
               : "memory");
  return loaded;
}
// Emits the CUSTOM1 R-type instruction with func3=1, func7=0x30, rd=x0,
// rs1=frag_addr and rs2=value.
// wu_main() passes a fragment index (byte address / WU_TMEM_FRAGMENT_BYTES).
static inline void wu_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
  asm volatile(".insn r %[opc], 1, 0x30, x0, %[rs1], %[rs2]"
               : /* no outputs */
               : [opc] "i"(RISCV_CUSTOM1), [rs1] "r"(frag_addr),
                 [rs2] "r"(value)
               : "memory");
}
// Tensor-warp entry point for case11. Naked, never returns (ends in an
// infinite self-branch), and uses x1..x7 as scratch.
//
// Phase 1 (HGEMM): stage A/C fragments into this warp's TMEM regions, run one
// CUSTOM3 func3=0 operation, and post (hgemm_base | warp_id) to
// g_seen[warp_id].
// Phase 2 (verify): spin until g_case_mem[0] equals verify_base, copy the TMEM
// fragment at (A base + write_byte_addr) to g_case11_out with CUSTOM3 func3=6,
// wait, and post (verify_base | warp_id) to g_seen[warp_id].
extern "C" void __attribute__((naked, noinline, used)) tensor_case11_worker() {
asm volatile(
// x1 = TMEM A base for this tensor warp, x2 = TMEM C base.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, 1024\n\t"
"la x6, g_case11_a_row\n\t"
"la x3, g_case11_c_row\n\t"
// x7 = byte offset, 0 .. tile_bytes in 16-byte steps.
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = hgemm_base | warp_id.
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[hgemm_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Poll g_case_mem[0] until the scalar side publishes verify_base.
"3:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[verify_base]\n\t"
"bne x7, x4, 3b\n\t"
// Copy the scalar-written fragment out to g_case11_out, then wait.
"addi x4, x1, %[write_byte_addr]\n\t"
"la x6, g_case11_out\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = verify_base | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[verify_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// NOTE(review): CUSTOM0 func3=0 with rs1=x0 is presumably TMC with a zero
// mask (halts the warp) -- confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[hgemm_base] "i"(WU_CASE_TMEM_STAGE_HGEMM_BASE),
[verify_base] "i"(WU_CASE_TMEM_STAGE_VERIFY_BASE),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[tile_bytes] "i"(WU_TMEM_TILE_BYTES),
[write_byte_addr] "i"(WU_TMEM_SCALAR_WRITE_BYTE_ADDR)
: "memory");
}
// Scalar-side driver for the softmax-stage TMEM path.
//
// Runs on warp 0 of core 0. Work that must happen once is guarded by
// `tid == 0`; g_case_mem[1] carries a failure code so every thread that reaches
// a check point can take the same exit, and g_case_mem[0] is the flag the
// tensor worker polls before its copy-out phase.
//
// Failure codes reported through wu_case_fail():
//   0x20  worker never posted the HGEMM marker
//   0x21  thread 0 did not read WU_TMEM_EXPECTED_HGEMM from TMEM C
//   0x22  worker never posted the verify marker
//   0x23  g_case11_out[0] is not WU_TMEM_STAGE_VALUE_BASE
//
// NOTE(review): only index 0 of g_case11_scalar_seen and g_case11_out is
// checked; words 1..3 are written by the code below but never compared --
// confirm whether that is intentional.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// One-time setup: reset case state, clear outputs, fill the SMEM B tile.
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < 4; ++i) {
g_case11_out[i] = 0;
g_case11_scalar_seen[i] = 0;
}
volatile uint32_t *smem_b =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
smem_b[i] = g_case11_b_row[i & 3u];
}
}
asm volatile("fence rw, rw" ::: "memory");
// Start the tensor worker and wait for its HGEMM-done marker.
if (tid == 0) {
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case11_worker);
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
WU_CASE_TMEM_STAGE_HGEMM_BASE) != 0) {
g_case_mem[1] = 0x20u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] != 0) {
if (tid == 0) {
wu_case_fail(g_case_mem[1]);
}
return 1;
}
// Read the first C fragment, record it per thread, and write
// observed + 1 + tid into the scalar-write fragment of the A region.
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base);
g_case11_scalar_seen[tid] = observed;
const uint32_t write_frag =
WU_TMEM_SCALAR_WRITE_BYTE_ADDR / WU_TMEM_FRAGMENT_BYTES;
wu_scalar_tmem_st(write_frag, observed + 1u + tid);
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
if (g_case11_scalar_seen[0] != WU_TMEM_EXPECTED_HGEMM) {
g_aux[0] = 0;
g_aux[1] = g_case11_scalar_seen[0];
g_case_mem[1] = 0x21u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] != 0) {
if (tid == 0) {
wu_case_fail(g_case_mem[1]);
}
return 1;
}
// Release the tensor worker into its copy-out phase.
if (tid == 0) {
g_case_mem[0] = WU_CASE_TMEM_STAGE_VERIFY_BASE;
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
WU_CASE_TMEM_STAGE_VERIFY_BASE) != 0) {
g_case_mem[1] = 0x22u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] != 0) {
if (tid == 0) {
wu_case_fail(g_case_mem[1]);
}
return 1;
}
// Final check of the word the tensor side copied back.
if (tid == 0) {
if (g_case11_out[0] != WU_TMEM_STAGE_VALUE_BASE) {
g_aux[0] = 0;
g_aux[1] = g_case11_out[0];
wu_case_fail(0x23u);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
# Project name follows the case directory, as the neighbouring case12_2
# Makefile does; the README and kernel alongside this file are the case12_1
# scalar TMEM CB probe, not case12_flash_pv_accum.
PROJECT = case12_1_scalar_tmem_cb_probe
include ../case.mk

View File

@@ -0,0 +1,10 @@
# case12_1_scalar_tmem_cb_probe
Validates the scalar-to-tensor TMEM handoff without BWGMMA. Scalar warp 0 fills
tensor block 0 TMEM A rows `0..63` with packed fp16 `1.0`. Tensor warp
`NUM_SCALAR_WARPS` waits for that fill, then uses `tcgen05_cb` to copy the same
TMEM A rows back to global memory.
The scalar leader verifies all 256 copied words are `0x3c003c00`. A mismatch
means scalar TMEM stores are not visible to the tensor TMEM read/copy path at
the expected row addresses.

View File

@@ -0,0 +1,108 @@
#include "../common_wu_blackwell_fa.h"
#define WU_CASE12_1_INIT_BASE 0x8d00u
#define WU_CASE12_1_DONE_BASE 0x8e00u
#define WU_CASE12_1_COPY_READY 0x8f00u
// C linkage keeps the name unmangled for the `la x3, g_case12_1_out` in
// tensor_case12_1_worker().
extern "C" {
// Destination of the worker's TMEM copy-back; wu_main() zeroes it before the
// run and then checks every word against WU_BW_FP16_ONE_PACKED.
volatile uint32_t g_case12_1_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
}
// Tensor-warp entry point for case12_1. Naked, never returns (ends in an
// infinite self-branch), and uses x1..x7 as scratch.
//
// Sequence:
//   1. x1 = (warp_id - NUM_SCALAR_WARPS) << 11 is this warp's TMEM base,
//   2. post (init_base | warp_id) to g_seen[warp_id] to announce start-up,
//   3. spin until g_case_mem[0] equals copy_ready,
//   4. for every 16-byte step of the tile, issue CUSTOM3 func3=6 with
//      rs1 = TMEM address and rs2 = matching address in g_case12_1_out,
//      then wait (func3=3),
//   5. post (done_base | warp_id) to g_seen[warp_id],
//   6. issue a CUSTOM0 instruction and spin.
// No CUSTOM3 func3=0 (matrix multiply) is issued in this worker.
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_1_worker() {
asm volatile(
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
// g_seen[warp_id] = init_base | warp_id.
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Wait for the scalar side to publish copy_ready in g_case_mem[0].
"1:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[copy_ready]\n\t"
"bne x7, x4, 1b\n\t"
// Copy-back loop: x7 = byte offset, x4 = TMEM address, x6 = global address.
"la x3, g_case12_1_out\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x1, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp_id] = done_base | warp_id.
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// NOTE(review): CUSTOM0 func3=0 with rs1=x0 is presumably TMC with a zero
// mask (halts the warp) -- confirm.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"3: j 3b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[init_base] "i"(WU_CASE12_1_INIT_BASE),
[done_base] "i"(WU_CASE12_1_DONE_BASE),
[copy_ready] "i"(WU_CASE12_1_COPY_READY)
: "memory");
}
// Scalar-side driver for the scalar-write / tensor-copy-back TMEM probe.
//
// Runs on warp 0 of core 0. g_case_mem[1] carries the failure code and
// g_case_mem[0] is the flag the tensor worker polls before copying TMEM out.
//
// Failure codes reported through wu_case_fail():
//   0x12  worker never posted its init marker
//   0x13  worker never posted its done marker
//   0x14  a copied word differed from WU_BW_FP16_ONE_PACKED
//         (g_aux[0] = index returned by wu_bw_verify_constant,
//          g_aux[1] = value it reported)
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// One-time setup: clear the output buffer and start the tensor worker.
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case12_1_out[i] = 0;
}
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_1_worker);
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_INIT_BASE) !=
0) {
g_case_mem[1] = 0x12u;
}
}
asm volatile("fence rw, rw" ::: "memory");
// The TMEM fill is bracketed by a switch to the all-lanes mask and a switch
// back via vx_tmc_one(); keep these three statements together and in order.
if (g_case_mem[1] == 0) {
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The ready flag is published even after an earlier failure, so the worker
// leaves its polling loop either way.
g_case_mem[0] = WU_CASE12_1_COPY_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_DONE_BASE) !=
0) {
g_case_mem[1] = 0x13u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// A return value equal to WU_BW_OUT_WORDS is treated as "no mismatch".
const uint32_t bad =
wu_bw_verify_constant(g_case12_1_out, WU_BW_OUT_WORDS,
WU_BW_FP16_ONE_PACKED, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x14u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case12_2_flash_pv_p_probe
include ../case.mk

View File

@@ -0,0 +1,11 @@
# case12_2_flash_pv_p_probe
Diagnostic for `case12_flash_pv_accum`.
Scalar warp 0 fills tensor block 0 TMEM A with packed fp16 `1.0`, using the
same `wu_bw_fill_tmem_tile()` path as case12. Tensor warp `NUM_SCALAR_WARPS`
then copies the full TMEM A tile back to global memory with `tcgen05_cb`.
The scalar leader verifies all copied words are `0x3c003c00`. A mismatch means
case12's scalar-written P tile is not reaching the tensor TMEM read/copy path as
expected.

View File

@@ -0,0 +1,109 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case12_2. The tensor worker stores (INIT_BASE | warp id)
// and later (DONE_BASE | warp id) into g_seen[warp id]; wu_main writes
// COPY_READY to g_case_mem[0] to release the worker.
#define WU_CASE12_2_INIT_BASE 0x9600u
#define WU_CASE12_2_DONE_BASE 0x9700u
#define WU_CASE12_2_COPY_READY 0x9800u
extern "C" {
// Destination of the TMEM copy-back; one word per TMEM tile word. extern "C"
// keeps the symbol name unmangled for the `la` in the worker's assembly.
volatile uint32_t g_case12_2_out[WU_BW_OUT_WORDS]
__attribute__((aligned(16)));
}
// Tensor-warp entry point for case12_2. Naked: no prologue/epilogue is emitted
// and the routine never returns (it parks in a spin loop at the end).
// Flow: publish (INIT_BASE | warp id) to g_seen[warp id]; spin until
// g_case_mem[0] == COPY_READY; copy WU_BW_TMEM_TILE_BYTES of TMEM, starting at
// this warp's TMEM base, to g_case12_2_out in 16-byte steps; publish
// (DONE_BASE | warp id); then issue the CUSTOM0 funct3=0 op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used))
tensor_case12_2_worker() {
asm volatile(
// x5 = warp id; x1 = (warp id - NUM_SCALAR_WARPS) << 11, i.e. a 2048-byte
// TMEM stride per tensor warp.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
// g_seen[warp id] = INIT_BASE | warp id
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds COPY_READY.
"1:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[copy_ready]\n\t"
"bne x7, x4, 1b\n\t"
// Copy loop: x7 is the byte offset; rs1 = TMEM address, rs2 = memory address.
// CUSTOM3 funct3=6 is presumably the TMEM -> memory copy-back op (the case
// README calls it `tcgen05_cb`) - TODO confirm against the ISA spec.
"la x3, g_case12_2_out\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x1, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
// CUSTOM3 funct3=3: same encoding as tcgen05_cp_wait(); waits for the copies.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"3: j 3b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[init_base] "i"(WU_CASE12_2_INIT_BASE),
[done_base] "i"(WU_CASE12_2_DONE_BASE),
[copy_ready] "i"(WU_CASE12_2_COPY_READY)
: "memory");
}
// Scalar-side driver for case12_2. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears g_case12_2_out, spawns the tensor worker
// on warp NUM_SCALAR_WARPS and waits on its INIT marker. The warp then fills
// the TMEM A tile of tensor block 0 with packed fp16 1.0 with all lanes
// enabled, releases the worker through g_case_mem[0], waits on the DONE marker
// and checks that every copied-back word equals WU_BW_FP16_ONE_PACKED.
// g_case_mem[1] latches the failure code: 0x51 (init wait), 0x52 (done wait),
// 0x53 (data mismatch; g_aux[0]/g_aux[1] receive the verifier outputs).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// Only the first tensor warp takes part.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case12_2_out[i] = 0;
}
vx_spawn_tensor(tensor_mask, tensor_case12_2_worker);
// presumably non-zero means the marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_2_INIT_BASE) != 0) {
g_case_mem[1] = 0x51u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The worker is released even after an earlier failure so it does not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE12_2_COPY_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE12_2_DONE_BASE) != 0) {
g_case_mem[1] = 0x52u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case12_2_out, WU_BW_OUT_WORDS,
WU_BW_FP16_ONE_PACKED, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x53u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case12_3_scalar_tmem_lane_store
include ../case.mk

View File

@@ -0,0 +1,14 @@
# case12_3_scalar_tmem_lane_store
Validates the lane-coalesced semantics of scalar TMEM stores.
One scalar warp writes a single 16-byte TMEM fragment. Each active scalar lane supplies one 32-bit word:
```text
word0 = lane0.rs2
word1 = lane1.rs2
word2 = lane2.rs2
word3 = lane3.rs2
```
The tensor copy-back path then copies that fragment to memory. This catches implementations that only write lane0 data or broadcast one scalar value.

View File

@@ -0,0 +1,91 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case12_3. wu_main writes COPY_READY to g_case_mem[0] to
// release the tensor worker; the worker stores (DONE_BASE | warp id) into
// g_seen[warp id] when the copy-back has finished.
#define WU_CASE12_3_COPY_READY 0x9900u
#define WU_CASE12_3_DONE_BASE 0x9a00u
extern "C" {
// Receives the single copied-back TMEM fragment; wu_main compares one word per
// lane. extern "C" keeps the symbol unmangled for the worker's `la`.
volatile uint32_t g_case12_3_out[4] __attribute__((aligned(16)));
}
// Tensor-warp entry point for case12_3. Naked: no prologue/epilogue is emitted
// and the routine never returns (it parks in a spin loop at the end).
// Flow: spin until g_case_mem[0] == COPY_READY; issue one copy-back from this
// warp's TMEM base to g_case12_3_out; wait; publish (DONE_BASE | warp id) to
// g_seen[warp id]; then issue the CUSTOM0 funct3=0 op with a zero operand.
// Unlike the sibling workers, no INIT marker is published.
// x1 and x4..x7 are used as scratch and are neither saved nor listed as clobbers.
extern "C" void __attribute__((naked, noinline, used))
tensor_case12_3_copy_worker() {
asm volatile(
// x1 = (warp id - NUM_SCALAR_WARPS) << 11: this warp's TMEM byte base.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
// Spin until g_case_mem[0] holds COPY_READY.
"1:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[copy_ready]\n\t"
"bne x7, x4, 1b\n\t"
// CUSTOM3 funct3=6 (rs1 = TMEM address, rs2 = memory address): presumably the
// TMEM -> memory copy-back of one fragment - TODO confirm against the ISA spec.
"la x6, g_case12_3_out\n\t"
".insn r %[custom3], 6, 0, x0, x1, x6\n\t"
// CUSTOM3 funct3=3: same encoding as tcgen05_cp_wait().
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"2: j 2b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[copy_ready] "i"(WU_CASE12_3_COPY_READY),
[done_base] "i"(WU_CASE12_3_DONE_BASE)
: "memory");
}
// Scalar-side driver for case12_3. Only core 0 / warp 0 proceeds.
// Spawns the copy worker, then performs ONE scalar TMEM store to the first
// fragment of tensor block 0's A tile with all lanes enabled, each lane
// supplying 0x5a120000 | lane id. After releasing the worker and waiting for
// its DONE marker, word `lane` of g_case12_3_out must equal 0x5a120000 | lane.
// Failure codes: 0x51 (done wait), 0x52 (word mismatch; g_aux[0] = lane,
// g_aux[1] = observed word). Returns 1 on failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
wu_case_reset();
for (uint32_t i = 0; i < 4; ++i) {
g_case12_3_out[i] = 0;
}
vx_spawn_tensor(tensor_mask, tensor_case12_3_copy_worker);
asm volatile("fence rw, rw" ::: "memory");
// Scalar TMEM ops address 16-byte fragments, hence the division.
const uint32_t frag_addr =
wu_bw_tmem_a_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
// The whole sequence lives in one asm statement so that nothing
// compiler-generated runs while the extra lanes are enabled.
// NOTE(review): %[addr] is read by every lane after the mask widens; confirm
// the non-leader lanes hold the same register value at that point.
asm volatile(
// Enable all lanes (same CUSTOM0 encoding is used below to go back to mask 1).
".insn r %[custom0], 0, 0, x0, %[all_lanes], x0\n\t"
// t1 = 0x5a120000 | lane id
"csrr t0, %[csr_tid]\n\t"
"lui t1, 0x5a120\n\t"
"or t1, t1, t0\n\t"
// Scalar TMEM store: same encoding as wu_bw_scalar_tmem_st(addr, t1).
".insn r %[custom1], 1, 0x30, x0, %[addr], t1\n\t"
"fence rw, rw\n\t"
// Restore the single-lane mask.
"li t2, 1\n\t"
".insn r %[custom0], 0, 0, x0, t2, x0"
:
: [custom0] "i"(RISCV_CUSTOM0), [custom1] "i"(RISCV_CUSTOM1),
[csr_tid] "i"(VX_CSR_THREAD_ID), [addr] "r"(frag_addr),
[all_lanes] "r"(wu_bw_all_lanes_mask())
: "t0", "t1", "t2", "memory");
// Release the worker, which spins on g_case_mem[0].
g_case_mem[0] = WU_CASE12_3_COPY_READY;
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_3_DONE_BASE) != 0) {
wu_case_fail(0x51u);
return 1;
}
for (uint32_t lane = 0; lane < NUM_THREADS; ++lane) {
const uint32_t expected = 0x5a120000u | lane;
if (g_case12_3_out[lane] != expected) {
g_aux[0] = lane;
g_aux[1] = g_case12_3_out[lane];
wu_case_fail(0x52u);
return 1;
}
}
wu_case_pass();
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case12_flash_pv_accum
include ../case.mk

View File

@@ -0,0 +1,7 @@
# case12_flash_pv_accum
Validates the first Wu Blackwell FlashAttention PV substage.
Scalar warp 0 writes a full `16x32` fp16 `P` tile into TMEM A. One tensor warp
uses BWGMMA with a `32x16` fp16 `V` tile in SMEM and a preloaded fp32 TMEM C
tile. With `P=1.0`, `V=2.0`, and `K=32`, the expected result is
`O = 3.0 + P @ V = 3.0 + 32 * 1.0 * 2.0 = 67.0` for every output element.

View File

@@ -0,0 +1,127 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case12. The tensor worker stores (INIT_BASE | warp id)
// and later (DONE_BASE | warp id) into g_seen[warp id]; wu_main writes P_READY
// to g_case_mem[0] once the P tile is in TMEM A.
#define WU_CASE12_INIT_BASE 0x8a00u
#define WU_CASE12_DONE_BASE 0x8b00u
#define WU_CASE12_P_READY 0x8c00u
extern "C" {
// One 16-byte row of fp32 3.0; the worker copies it repeatedly into TMEM C as
// the initial accumulator value.
volatile uint32_t g_case12_o_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
WU_BW_FP32_THREE};
// Destination of the TMEM C copy-back that wu_main verifies.
volatile uint32_t g_case12_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
}
// Tensor-warp entry point for case12. Naked: no prologue/epilogue is emitted
// and the routine never returns (it parks in a spin loop at the end).
// Flow: preload this warp's TMEM C tile from g_case12_o_row; publish
// (INIT_BASE | warp id); spin until g_case_mem[0] == P_READY; run one BWGMMA
// (C tile, A tile, SMEM base); copy the C tile to g_case12_out; publish
// (DONE_BASE | warp id); then issue the CUSTOM0 funct3=0 op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_worker() {
asm volatile(
// x1 = this warp's TMEM A byte base (2048 bytes per tensor warp);
// x2 = TMEM C byte base = x1 + c_offset.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Preload C: CUSTOM3 funct3=2 has the tcgen05_cp() encoding (rs1 = TMEM
// address, rs2 = memory address). x3 is not advanced, so the same 16-byte row
// is written to every 16-byte slot of the tile.
"la x3, g_case12_o_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// tcgen05_cp_wait() encoding.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = INIT_BASE | warp id (x5 still holds the warp id).
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds P_READY.
"2:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[p_ready]\n\t"
"bne x7, x4, 2b\n\t"
// bwgmma() encoding: rd = TMEM C, rs1 = TMEM A, rs2 = SMEM B; followed by the
// bwgmma_wait() encoding.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Copy the C tile out in 16-byte steps. CUSTOM3 funct3=6 is presumably the
// TMEM -> memory copy-back op - TODO confirm against the ISA spec.
"la x3, g_case12_out\n\t"
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"4: j 4b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[init_base] "i"(WU_CASE12_INIT_BASE),
[done_base] "i"(WU_CASE12_DONE_BASE),
[p_ready] "i"(WU_CASE12_P_READY)
: "memory");
}
// Scalar-side driver for case12. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears g_case12_out, fills the SMEM tile with
// packed fp16 2.0 (V), spawns the tensor worker and waits on its INIT marker.
// The warp then fills the TMEM A tile of tensor block 0 with packed fp16 1.0
// (P) with all lanes enabled, signals P_READY, waits on the DONE marker and
// checks that every output word equals fp32 67.0.
// g_case_mem[1] latches the failure code: 0x12 (init wait), 0x13 (done wait),
// 0x14 (data mismatch; g_aux[0]/g_aux[1] receive the verifier outputs).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case12_out[i] = 0;
}
// V must be in SMEM before the worker can be released to run BWGMMA.
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_TWO_PACKED);
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_worker);
// presumably non-zero means the marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_INIT_BASE) != 0) {
g_case_mem[1] = 0x12u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The worker is released even after an earlier failure so it does not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE12_P_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_DONE_BASE) != 0) {
g_case_mem[1] = 0x13u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case12_out, WU_BW_OUT_WORDS,
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x14u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case13_flash_pv_two_warps
include ../case.mk

View File

@@ -0,0 +1,8 @@
# case13_flash_pv_two_warps
Validates the Wu Blackwell FlashAttention PV substage across both tensor warps.
Scalar warp 0 writes one `P` tile into each tensor warp's TMEM A partition.
Both tensor warps consume the same SMEM `V` tile, accumulate into their own TMEM
C tiles, copy out contiguous row blocks, and stop. With `O_init=3.0`, `P=1.0`,
`V=2.0`, and `K=32`, every fp32 output is expected to be
`3.0 + 32 * 1.0 * 2.0 = 67.0`.

View File

@@ -0,0 +1,135 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case13. Each tensor worker stores (INIT_BASE | warp id)
// and later (DONE_BASE | warp id) into g_seen[warp id]; wu_main writes P_READY
// to g_case_mem[0] once the P tiles are in TMEM A.
#define WU_CASE13_INIT_BASE 0x8d00u
#define WU_CASE13_DONE_BASE 0x8e00u
#define WU_CASE13_P_READY 0x8f00u
// One output tile per tensor warp, stored back to back.
#define WU_CASE13_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
extern "C" {
// One 16-byte row of fp32 3.0; the workers copy it repeatedly into TMEM C as
// the initial accumulator value.
volatile uint32_t g_case13_o_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE,
WU_BW_FP32_THREE};
// Destination of the per-warp TMEM C copy-backs that wu_main verifies.
volatile uint32_t g_case13_out[WU_CASE13_OUT_WORDS]
__attribute__((aligned(16)));
}
// Tensor-warp entry point for case13, run by every spawned tensor warp. Naked:
// no prologue/epilogue is emitted and the routine never returns.
// Flow: preload this warp's TMEM C tile from g_case13_o_row; publish
// (INIT_BASE | warp id); spin until g_case_mem[0] == P_READY; run one BWGMMA
// (C tile, A tile, SMEM base); copy the C tile into this warp's 1024-byte slice
// of g_case13_out; publish (DONE_BASE | warp id); then issue the CUSTOM0
// funct3=0 op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used)) tensor_case13_worker() {
asm volatile(
// x1 = this warp's TMEM A byte base (2048 bytes per tensor warp);
// x2 = TMEM C byte base = x1 + c_offset.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Preload C: CUSTOM3 funct3=2 has the tcgen05_cp() encoding (rs1 = TMEM
// address, rs2 = memory address). x3 is not advanced, so the same 16-byte row
// is written to every 16-byte slot of the tile.
"la x3, g_case13_o_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// tcgen05_cp_wait() encoding.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = INIT_BASE | warp id (x5 still holds the warp id).
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds P_READY.
"2:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[p_ready]\n\t"
"bne x7, x4, 2b\n\t"
// bwgmma() encoding: rd = TMEM C, rs1 = TMEM A, rs2 = SMEM B; followed by the
// bwgmma_wait() encoding.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// x3 = g_case13_out + (tensor warp index << 10): a 1024-byte slice per warp.
"csrr x5, %[csr_wid]\n\t"
"addi x6, x5, -%[num_scalar_warps]\n\t"
"slli x7, x6, 10\n\t"
"la x3, g_case13_out\n\t"
"add x3, x3, x7\n\t"
// Copy the C tile out in 16-byte steps. CUSTOM3 funct3=6 is presumably the
// TMEM -> memory copy-back op - TODO confirm against the ISA spec.
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"4: j 4b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[init_base] "i"(WU_CASE13_INIT_BASE),
[done_base] "i"(WU_CASE13_DONE_BASE),
[p_ready] "i"(WU_CASE13_P_READY)
: "memory");
}
// Scalar-side driver for case13. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears g_case13_out, fills the SMEM tile with
// packed fp16 2.0 (V), spawns the workers selected by vx_tensor_warp_mask() and
// waits on their INIT markers. The warp then fills the TMEM A tiles of tensor
// blocks 0 and 1 with packed fp16 1.0 (P) with all lanes enabled, signals
// P_READY, waits on the DONE markers and checks that every output word equals
// fp32 67.0.
// g_case_mem[1] latches the failure code: 0x21 (init wait), 0x22 (done wait),
// 0x23 (data mismatch; g_aux[0]/g_aux[1] receive the verifier outputs).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
const uint32_t tensor_mask = vx_tensor_warp_mask();
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_CASE13_OUT_WORDS; ++i) {
g_case13_out[i] = 0;
}
// V must be in SMEM before the workers can be released to run BWGMMA.
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_TWO_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case13_worker);
// presumably non-zero means a marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(tensor_mask, WU_CASE13_INIT_BASE) != 0) {
g_case_mem[1] = 0x21u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The workers are released even after an earlier failure so they do not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE13_P_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE13_DONE_BASE) != 0) {
g_case_mem[1] = 0x22u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case13_out, WU_CASE13_OUT_WORDS,
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
if (bad != WU_CASE13_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x23u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case14_flash_pv_k64
include ../case.mk

View File

@@ -0,0 +1,7 @@
# case14_flash_pv_k64
Validates PV accumulation over two Blackwell `K=32` BWGMMA steps.
Both tensor warps run two consecutive BWGMMA operations against the same TMEM C
tile, modeling a `K=64` FlashAttention PV accumulation. With `P=1.0`,
`V=2.0`, and `O_init=5.0`, every fp32 output is expected to be `133.0`.

View File

@@ -0,0 +1,136 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case14. Each tensor worker stores (INIT_BASE | warp id)
// and later (DONE_BASE | warp id) into g_seen[warp id]; wu_main writes P_READY
// to g_case_mem[0] once the P tiles are in TMEM A.
#define WU_CASE14_INIT_BASE 0x9000u
#define WU_CASE14_DONE_BASE 0x9100u
#define WU_CASE14_P_READY 0x9200u
// One output tile per tensor warp, stored back to back.
#define WU_CASE14_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)
extern "C" {
// One 16-byte row of fp32 5.0; the workers copy it repeatedly into TMEM C as
// the initial accumulator value.
volatile uint32_t g_case14_o_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE};
// Destination of the per-warp TMEM C copy-backs that wu_main verifies.
volatile uint32_t g_case14_out[WU_CASE14_OUT_WORDS]
__attribute__((aligned(16)));
}
// Tensor-warp entry point for case14, run by every spawned tensor warp. Naked:
// no prologue/epilogue is emitted and the routine never returns.
// Flow: preload this warp's TMEM C tile from g_case14_o_row; publish
// (INIT_BASE | warp id); spin until g_case_mem[0] == P_READY; run TWO
// back-to-back BWGMMA + wait pairs with identical operands, accumulating into
// the same C tile; copy the C tile into this warp's 1024-byte slice of
// g_case14_out; publish (DONE_BASE | warp id); then issue the CUSTOM0 funct3=0
// op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used)) tensor_case14_worker() {
asm volatile(
// x1 = this warp's TMEM A byte base (2048 bytes per tensor warp);
// x2 = TMEM C byte base = x1 + c_offset.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Preload C: CUSTOM3 funct3=2 has the tcgen05_cp() encoding (rs1 = TMEM
// address, rs2 = memory address). x3 is not advanced, so the same 16-byte row
// is written to every 16-byte slot of the tile.
"la x3, g_case14_o_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// tcgen05_cp_wait() encoding.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = INIT_BASE | warp id (x5 still holds the warp id).
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds P_READY.
"2:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[p_ready]\n\t"
"bne x7, x4, 2b\n\t"
// Two bwgmma() + bwgmma_wait() pairs (rd = TMEM C, rs1 = TMEM A, rs2 = SMEM B).
// The second reuses the same A and B operands and the updated C tile.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// x3 = g_case14_out + (tensor warp index << 10): a 1024-byte slice per warp.
"csrr x5, %[csr_wid]\n\t"
"addi x6, x5, -%[num_scalar_warps]\n\t"
"slli x7, x6, 10\n\t"
"la x3, g_case14_out\n\t"
"add x3, x3, x7\n\t"
// Copy the C tile out in 16-byte steps. CUSTOM3 funct3=6 is presumably the
// TMEM -> memory copy-back op - TODO confirm against the ISA spec.
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"4: j 4b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[init_base] "i"(WU_CASE14_INIT_BASE),
[done_base] "i"(WU_CASE14_DONE_BASE),
[p_ready] "i"(WU_CASE14_P_READY)
: "memory");
}
// Scalar-side driver for case14. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears g_case14_out, fills the SMEM tile with
// packed fp16 2.0 (V), spawns the workers selected by vx_tensor_warp_mask() and
// waits on their INIT markers. The warp then fills the TMEM A tiles of tensor
// blocks 0 and 1 with packed fp16 1.0 (P) with all lanes enabled, signals
// P_READY, waits on the DONE markers and checks that every output word equals
// fp32 133.0.
// g_case_mem[1] latches the failure code: 0x31 (init wait), 0x32 (done wait),
// 0x33 (data mismatch; g_aux[0]/g_aux[1] receive the verifier outputs).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
const uint32_t tensor_mask = vx_tensor_warp_mask();
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_CASE14_OUT_WORDS; ++i) {
g_case14_out[i] = 0;
}
// V must be in SMEM before the workers can be released to run BWGMMA.
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_TWO_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case14_worker);
// presumably non-zero means a marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(tensor_mask, WU_CASE14_INIT_BASE) != 0) {
g_case_mem[1] = 0x31u;
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The workers are released even after an earlier failure so they do not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE14_P_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE14_DONE_BASE) != 0) {
g_case_mem[1] = 0x32u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case14_out, WU_CASE14_OUT_WORDS,
WU_BW_FP32_ONE_THIRTY_THREE, &bad_actual);
if (bad != WU_CASE14_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x33u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case15_flash_softmax_pv_stage
include ../case.mk

View File

@@ -0,0 +1,8 @@
# case15_flash_softmax_pv_stage
Validates the scalar softmax-stage handoff into the Blackwell PV GEMM.
The tensor warp initializes TMEM C with a score/O value of `1.0`. Scalar warp 0
reads that value through the scalar TMEM load path, writes a full fp16 `P=1.0`
tile into TMEM A, and releases the tensor warp. The tensor warp then runs PV
against `V=2.0` with `K=32`; every fp32 output is expected to be
`1.0 + 32 * 1.0 * 2.0 = 65.0`.

View File

@@ -0,0 +1,145 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case15. The tensor worker stores (INIT_BASE | warp id)
// and later (DONE_BASE | warp id) into g_seen[warp id]; wu_main writes P_READY
// to g_case_mem[0] once the P tile is in TMEM A.
#define WU_CASE15_INIT_BASE 0x9300u
#define WU_CASE15_DONE_BASE 0x9400u
#define WU_CASE15_P_READY 0x9500u
extern "C" {
// One 16-byte row of fp32 1.0; the worker copies it repeatedly into TMEM C.
volatile uint32_t g_case15_score_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE};
// Destination of the TMEM C copy-back that wu_main verifies.
volatile uint32_t g_case15_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
// Per-lane record of the value returned by the scalar TMEM load in wu_main.
volatile uint32_t g_case15_scalar_seen[4] __attribute__((aligned(16)));
}
// Tensor-warp entry point for case15. Naked: no prologue/epilogue is emitted
// and the routine never returns (it parks in a spin loop at the end).
// Flow: preload this warp's TMEM C tile from g_case15_score_row; publish
// (INIT_BASE | warp id); spin until g_case_mem[0] == P_READY; run one BWGMMA
// (C tile, A tile, SMEM base); copy the C tile to g_case15_out; publish
// (DONE_BASE | warp id); then issue the CUSTOM0 funct3=0 op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used)) tensor_case15_worker() {
asm volatile(
// x1 = this warp's TMEM A byte base (2048 bytes per tensor warp);
// x2 = TMEM C byte base = x1 + c_offset.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Preload C: CUSTOM3 funct3=2 has the tcgen05_cp() encoding (rs1 = TMEM
// address, rs2 = memory address). x3 is not advanced, so the same 16-byte row
// is written to every 16-byte slot of the tile.
"la x3, g_case15_score_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// tcgen05_cp_wait() encoding.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = INIT_BASE | warp id (x5 still holds the warp id).
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds P_READY.
"2:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[p_ready]\n\t"
"bne x7, x4, 2b\n\t"
// bwgmma() encoding: rd = TMEM C, rs1 = TMEM A, rs2 = SMEM B; followed by the
// bwgmma_wait() encoding.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Copy the C tile out in 16-byte steps. CUSTOM3 funct3=6 is presumably the
// TMEM -> memory copy-back op - TODO confirm against the ISA spec.
"la x3, g_case15_out\n\t"
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"4: j 4b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[init_base] "i"(WU_CASE15_INIT_BASE),
[done_base] "i"(WU_CASE15_DONE_BASE),
[p_ready] "i"(WU_CASE15_P_READY)
: "memory");
}
// Scalar-side driver for case15. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears the output and scalar-seen buffers,
// fills the SMEM tile with packed fp16 2.0 (V), spawns the tensor worker and
// waits on its INIT marker (published after the worker preloaded TMEM C). The
// warp then reads the first C fragment through the scalar TMEM load and lane 0
// checks that its value is fp32 1.0. Next the TMEM A tile of tensor block 0 is
// filled with packed fp16 1.0 (P), the worker is released via P_READY, and the
// copied-back C tile must be fp32 65.0 in every word.
// g_case_mem[1] latches the failure code: 0x41 (init wait), 0x42 (scalar load
// mismatch), 0x43 (done wait), 0x44 (output mismatch).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// Only the first tensor warp takes part.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case15_out[i] = 0;
}
for (uint32_t i = 0; i < 4; ++i) {
g_case15_scalar_seen[i] = 0;
}
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_TWO_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case15_worker);
// presumably non-zero means the marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(tensor_mask, WU_CASE15_INIT_BASE) != 0) {
g_case_mem[1] = 0x41u;
}
}
asm volatile("fence rw, rw" ::: "memory");
// Scalar TMEM ops address 16-byte fragments, hence the division.
const uint32_t c_frag =
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
// The load is issued unconditionally, even if the init wait already failed.
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
g_case15_scalar_seen[tid] = observed;
asm volatile("fence rw, rw" ::: "memory");
// Only lane 0's observation is checked; the other slots are diagnostics.
if (tid == 0 && g_case_mem[1] == 0 &&
g_case15_scalar_seen[0] != WU_BW_FP32_ONE) {
g_aux[0] = 0;
g_aux[1] = g_case15_scalar_seen[0];
g_case_mem[1] = 0x42u;
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// The worker is released even after an earlier failure so it does not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE15_P_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE15_DONE_BASE) != 0) {
g_case_mem[1] = 0x43u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case15_out, WU_BW_OUT_WORDS,
WU_BW_FP32_SIXTY_FIVE, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x44u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case16_flash_full_pipeline
include ../case.mk

View File

@@ -0,0 +1,14 @@
# case16_flash_full_pipeline
Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
Blackwell path.
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads
the score through scalar TMEM load, records it, writes the uniform softmax result
`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp.
The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`.
Every output word is expected to be fp32 `2.0`. This case covers the staged
`QK -> scalar softmax handoff -> PV` loop without using the legacy
`kernels/flash_attention` implementation.

View File

@@ -0,0 +1,180 @@
#include "../common_wu_blackwell_fa.h"
// Handshake values for case16. The tensor worker stores (INIT_BASE | warp id)
// after the first GEMM and (DONE_BASE | warp id) after the copy-back into
// g_seen[warp id]; wu_main writes P_READY to g_case_mem[0] in between.
#define WU_CASE16_INIT_BASE 0x9600u
#define WU_CASE16_P_READY 0x9700u
#define WU_CASE16_DONE_BASE 0x9800u
// Two fp16 values of 2^-5 (1/32) per word.
#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u
// fp32 bit patterns for 0.0, 2.0 and 32.0.
#define WU_BW_FP32_ZERO 0x00000000u
#define WU_BW_FP32_TWO 0x40000000u
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
extern "C" {
// One 16-byte row of packed fp16 1.0; the worker copies it repeatedly into
// TMEM A as the Q tile.
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
WU_BW_FP16_ONE_PACKED};
// One 16-byte row of fp32 0.0; used to clear TMEM C before each GEMM.
volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
// Destination of the final TMEM C copy-back that wu_main verifies.
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
// Per-lane record of the value returned by the scalar TMEM load in wu_main.
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
}
// Tensor-warp entry point for case16. Naked: no prologue/epilogue is emitted
// and the routine never returns (it parks in a spin loop at the end).
// Stage 1: load TMEM A from g_case16_q_row and TMEM C from g_case16_zero_row,
// run BWGMMA against the SMEM tile, then publish (INIT_BASE | warp id).
// Stage 2: after g_case_mem[0] == P_READY, clear TMEM C again, run a second
// BWGMMA with whatever the scalar side placed in TMEM A and SMEM, copy the C
// tile to g_case16_out, publish (DONE_BASE | warp id), and issue the CUSTOM0
// funct3=0 op with a zero operand.
// x1..x7 are used as scratch and are neither saved nor listed as clobbers.
// NOTE(review): x3 (gp) is overwritten while `la` is still used afterwards;
// confirm the link does not relax `la` into gp-relative addressing.
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
asm volatile(
// x1 = this warp's TMEM A byte base (2048 bytes per tensor warp);
// x2 = TMEM C byte base = x1 + c_offset.
"csrr x5, %[csr_wid]\n\t"
"addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, %[c_offset]\n\t"
// Fill the A tile with the Q row (tcgen05_cp() encoding; x3 is not advanced,
// so the same 16-byte row lands in every slot).
"la x3, g_case16_q_row\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// Fill the C tile with zeros, then wait for both sets of copies.
"la x3, g_case16_zero_row\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// First GEMM: bwgmma() encoding (rd = TMEM C, rs1 = TMEM A, rs2 = SMEM B),
// followed by the bwgmma_wait() encoding.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// g_seen[warp id] = INIT_BASE | warp id (x5 still holds the warp id). The
// marker is published only after the first GEMM has completed.
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[init_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// Spin until g_case_mem[0] holds P_READY.
"3:\n\t"
"la x6, g_case_mem\n\t"
"lw x7, 0(x6)\n\t"
"li x4, %[p_ready]\n\t"
"bne x7, x4, 3b\n\t"
// Clear the C tile again so the second GEMM starts from zero.
"la x3, g_case16_zero_row\n\t"
"li x7, 0\n\t"
"4:\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 4b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// Second GEMM with the same operand addresses.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Copy the C tile out in 16-byte steps. CUSTOM3 funct3=6 is presumably the
// TMEM -> memory copy-back op - TODO confirm against the ISA spec.
"la x3, g_case16_out\n\t"
"li x7, 0\n\t"
"5:\n\t"
"add x4, x2, x7\n\t"
"add x6, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 5b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// g_seen[warp id] = DONE_BASE | warp id
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// CUSTOM0 funct3=0 with rs1 = x0: presumably the thread-mask op with mask 0,
// which would deactivate the warp; the trailing loop is a safety net.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"6: j 6b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
[init_base] "i"(WU_CASE16_INIT_BASE),
[p_ready] "i"(WU_CASE16_P_READY),
[done_base] "i"(WU_CASE16_DONE_BASE)
: "memory");
}
// Scalar-side driver for case16. Only core 0 / warp 0 proceeds.
// Lane 0 resets the case state, clears the output and scalar-seen buffers,
// fills the SMEM tile with packed fp16 1.0 (K), spawns the tensor worker and
// waits on its INIT marker, which the worker publishes after the first GEMM.
// The warp reads the first C fragment through the scalar TMEM load and lane 0
// checks that it is fp32 32.0. Then the TMEM A tile of tensor block 0 is
// filled with packed fp16 1/32 (P), lane 0 refills SMEM with packed fp16 2.0
// (V) and signals P_READY; the copied-back C tile must be fp32 2.0 everywhere.
// g_case_mem[1] latches the failure code: 0x61 (init wait), 0x62 (score
// mismatch), 0x63 (done wait), 0x64 (output mismatch).
// Returns 1 after reporting a failure, otherwise 0.
extern "C" int wu_main() {
if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0;
}
const uint32_t tid = wu_tid();
// Only the first tensor warp takes part.
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
if (tid == 0) {
wu_case_reset();
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
g_case16_out[i] = 0;
}
for (uint32_t i = 0; i < 4; ++i) {
g_case16_scalar_seen[i] = 0;
}
// K tile for the first GEMM; it must be in place before the worker starts.
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_ONE_PACKED);
vx_spawn_tensor(tensor_mask, tensor_case16_worker);
// presumably non-zero means the marker was not observed - verify against
// wu_wait_seen_mask().
if (wu_wait_seen_mask(tensor_mask, WU_CASE16_INIT_BASE) != 0) {
g_case_mem[1] = 0x61u;
}
}
asm volatile("fence rw, rw" ::: "memory");
// Scalar TMEM ops address 16-byte fragments, hence the division.
const uint32_t c_frag =
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
// The load is issued unconditionally, even if the init wait already failed.
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
g_case16_scalar_seen[tid] = observed;
asm volatile("fence rw, rw" ::: "memory");
// Only lane 0's observation is checked; the other slots are diagnostics.
if (tid == 0 && g_case_mem[1] == 0 &&
g_case16_scalar_seen[0] != WU_BW_FP32_THIRTY_TWO) {
g_aux[0] = 0;
g_aux[1] = g_case16_scalar_seen[0];
g_case_mem[1] = 0x62u;
}
asm volatile("fence rw, rw" ::: "memory");
if (g_case_mem[1] == 0) {
// wu_bw_fill_tmem_tile() requires all lanes active; restore one lane after.
vx_tmc(wu_bw_all_lanes_mask());
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0),
WU_BW_FP16_ONE_OVER_32_PACKED);
vx_tmc_one();
}
asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
// Replace K with V in SMEM before the second GEMM is released.
if (g_case_mem[1] == 0) {
wu_bw_fill_smem_tile(
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
WU_BW_FP16_TWO_PACKED);
}
asm volatile("fence rw, rw" ::: "memory");
// The worker is released even after an earlier failure so it does not spin
// forever on g_case_mem[0].
g_case_mem[0] = WU_CASE16_P_READY;
if (g_case_mem[1] == 0 &&
wu_wait_seen_mask(tensor_mask, WU_CASE16_DONE_BASE) != 0) {
g_case_mem[1] = 0x63u;
}
if (g_case_mem[1] == 0) {
volatile uint32_t bad_actual = 0;
// Any return value other than the word count is treated as a mismatch;
// presumably it is the index of the first bad word - TODO confirm.
const uint32_t bad =
wu_bw_verify_constant(g_case16_out, WU_BW_OUT_WORDS,
WU_BW_FP32_TWO, &bad_actual);
if (bad != WU_BW_OUT_WORDS) {
g_aux[0] = bad;
g_aux[1] = bad_actual;
g_case_mem[1] = 0x64u;
}
}
if (g_case_mem[1] != 0) {
wu_case_fail(g_case_mem[1]);
return 1;
}
wu_case_pass();
}
return 0;
}

View File

@@ -0,0 +1,3 @@
PROJECT = case17_flash_exp_softmax_probe
include ../case.mk

View File

@@ -0,0 +1,16 @@
# case17_flash_exp_softmax_probe
Validates that the current scalar Wu path can execute the `e^x` work needed by
non-uniform FlashAttention softmax. The current ISA/RTL does not expose a
dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to
approximate exp.
The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically
stable softmax computes `exp(score - row_max)`, so the exp inputs are
`{-ln(2), 0}` and the probabilities should be close to `{1/3, 2/3}`. The
inputs are loaded from volatile memory so the compiler cannot fold the result
into constants.
This is intentionally separate from the tensor PV path. If this case fails, the
problem is in scalar fp32 arithmetic, exp approximation, or normalization rather
than TMEM handoff or BWGMMA.

View File

@@ -0,0 +1,81 @@
#include "../common_wu_min.h"
extern "C" {
// Input scores as raw fp32 bit patterns. They are volatile so that wu_main
// must load them at run time instead of constant-folding the softmax.
volatile uint32_t g_case17_scores_bits[2] __attribute__((aligned(16))) = {
0x00000000u, 0x3f317218u}; // 0.0, ln(2)
// Computed probabilities {p0, p1} as raw fp32 bit patterns, written by lane 0.
volatile uint32_t g_case17_out_bits[2] __attribute__((aligned(16)));
}
// Reinterprets a 32-bit pattern as an fp32 value without numeric conversion.
// The pun goes through a union, relying on the compiler's union-punning support.
static inline float wu_case17_bits_to_f32(uint32_t bits) {
  union {
    uint32_t raw;
    float as_float;
  } pun;
  pun.raw = bits;
  return pun.as_float;
}
// Returns the raw 32-bit pattern of an fp32 value without numeric conversion.
// The pun goes through a union, relying on the compiler's union-punning support.
static inline uint32_t wu_case17_f32_to_bits(float value) {
  union {
    float as_float;
    uint32_t raw;
  } pun;
  pun.as_float = value;
  return pun.raw;
}
// Absolute value of an fp32 number, implemented with a compare and a negate
// so no math library is needed.
static inline float wu_case17_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
// Degree-5 Taylor polynomial of e^x around 0:
//   1 + x + x^2/2 + x^3/6 + x^4/24 + x^5/120
// The name states the intended input range, [-ln(2), 0]. The terms are added
// in ascending order of degree, one at a time.
static inline float wu_case17_exp_neg_ln2_to_0(float x) {
  const float sq = x * x;
  const float cube = sq * x;
  const float quad = sq * sq;
  const float quint = quad * x;

  float acc = 1.0f + x;
  acc = acc + (0.5f * sq);
  acc = acc + (0.1666666716f * cube);
  acc = acc + (0.0416666679f * quad);
  acc = acc + (0.0083333338f * quint);
  return acc;
}
// Scalar-only probe for case17. Only core 0 / warp 0 proceeds.
// Computes a numerically stable two-element softmax, exp(score - row_max)
// normalised by the sum, using the polynomial exp approximation. Lane 0
// publishes the two probabilities as raw bit patterns in g_case17_out_bits and
// compares them with {1/3, 2/3} within an absolute tolerance of 0.0015.
// On failure g_aux[0]/g_aux[1] receive the two computed bit patterns and the
// case fails with code 0x71. Returns 1 on failure, otherwise 0.
extern "C" int wu_main() {
  if (vx_core_id() != 0 || vx_warp_id() != 0) {
    return 0;
  }
  const uint32_t tid = wu_tid();
  if (tid == 0) {
    wu_case_reset();
    g_case17_out_bits[0] = 0;
    g_case17_out_bits[1] = 0;
  }
  asm volatile("fence rw, rw" ::: "memory");
  // The scores come from volatile memory so the arithmetic below really runs.
  const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]);
  const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]);
  // Subtracting the row maximum keeps every exp argument <= 0.
  const float row_max = score0 > score1 ? score0 : score1;
  const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max);
  const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max);
  const float inv_sum = 1.0f / (e0 + e1);
  const float p0 = e0 * inv_sum;
  const float p1 = e1 * inv_sum;
  if (tid == 0) {
    g_case17_out_bits[0] = wu_case17_f32_to_bits(p0);
    g_case17_out_bits[1] = wu_case17_f32_to_bits(p1);
  }
  asm volatile("fence rw, rw" ::: "memory");
  if (tid == 0) {
    const float expected0 = 0.3333333433f;
    const float expected1 = 0.6666666865f;
    const float tolerance = 0.0015f;
    const float err0 = wu_case17_absf(p0 - expected0);
    const float err1 = wu_case17_absf(p1 - expected1);
    // Written as !(err <= tolerance) rather than (err > tolerance): every
    // ordered comparison with NaN is false, so the negated form makes a NaN
    // result fail the case instead of slipping through as a pass.
    if (!(err0 <= tolerance) || !(err1 <= tolerance)) {
      g_aux[0] = g_case17_out_bits[0];
      g_aux[1] = g_case17_out_bits[1];
      wu_case_fail(0x71u);
      return 1;
    }
    wu_case_pass();
  }
  return 0;
}

View File

@@ -0,0 +1,101 @@
#ifndef WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
#define WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
#include "common_wu_min.h"
// Base address the cases use for the SMEM operand tile.
#define WU_BW_DEV_SMEM_START_ADDR 0xff000000u
// TMEM layout in bytes: each tensor warp owns WU_BW_TMEM_WARP_BYTES; its A tile
// starts at offset 0 and its C tile at WU_BW_TMEM_C_BYTE_OFFSET. Scalar TMEM
// loads/stores address WU_BW_TMEM_FRAGMENT_BYTES-sized fragments.
#define WU_BW_TMEM_TILE_BYTES 1024u
#define WU_BW_TMEM_FRAGMENT_BYTES 16u
#define WU_BW_TMEM_WARP_BYTES 2048u
#define WU_BW_TMEM_C_BYTE_OFFSET 1024u
// Number of fragments, and of 32-bit words, in one tile.
#define WU_BW_TMEM_FRAGMENTS \
(WU_BW_TMEM_TILE_BYTES / WU_BW_TMEM_FRAGMENT_BYTES)
#define WU_BW_TMEM_WORDS (WU_BW_TMEM_TILE_BYTES / sizeof(uint32_t))
#define WU_BW_OUT_WORDS WU_BW_TMEM_WORDS
// Packed constants: two fp16 values per 32-bit word.
#define WU_BW_FP16_ONE_PACKED 0x3c003c00u
#define WU_BW_FP16_TWO_PACKED 0x40004000u
// fp32 bit patterns for the values named in each macro.
#define WU_BW_FP32_ONE 0x3f800000u
#define WU_BW_FP32_THREE 0x40400000u
#define WU_BW_FP32_FIVE 0x40a00000u
#define WU_BW_FP32_SIXTY_FIVE 0x42820000u
#define WU_BW_FP32_SIXTY_SEVEN 0x42860000u
#define WU_BW_FP32_ONE_THIRTY_THREE 0x43050000u
// The cases hard-code this configuration (e.g. 4-element per-lane buffers and
// fixed shifts in the worker assembly), so reject any other build.
static_assert(NUM_SCALAR_WARPS == 2,
"Wu Blackwell FA staged cases expect two scalar warps");
static_assert(NUM_TENSOR_WARPS == 2,
"Wu Blackwell FA staged cases expect two tensor warps");
static_assert(NUM_THREADS == 4,
"Wu Blackwell FA staged cases expect the 4-lane tensor core");
// Returns a bitmask with one set bit per thread of a warp, i.e. the low
// NUM_THREADS bits set and all higher bits clear.
static inline uint32_t wu_bw_all_lanes_mask() {
  const uint32_t one_past_top_lane = 1u << NUM_THREADS;
  return one_past_top_lane - 1u;
}
// Issues one R-type custom instruction (opcode RISCV_CUSTOM1, funct3 = 0,
// funct7 = 0x30) with rs1 = frag_addr and rs2 = x0, and returns the value the
// instruction leaves in its integer destination register.
// NOTE(review): judging by the name this is the scalar-warp TMEM load; the
// exact hardware semantics are not visible in this header.
// NOTE(review): frag_addr is presumably a 16-byte fragment index (as the
// store helper's callers pass), not a byte address -- confirm with callers.
static inline uint32_t wu_bw_scalar_tmem_ld(uint32_t frag_addr) {
uint32_t value;
// .insn r operand order: opcode, funct3, funct7, rd, rs1, rs2.
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
: [value] "=r"(value)
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
// "memory" clobber keeps the compiler from moving memory accesses across
// the instruction.
: "memory");
return value;
}
// Issues one R-type custom instruction (opcode RISCV_CUSTOM1, funct3 = 1,
// funct7 = 0x30) with rs1 = frag_addr, rs2 = value and rd = x0 (no result).
// NOTE(review): judging by the name this is the scalar-warp TMEM store; the
// exact hardware semantics are not visible in this header.
// The callers in this header pass a fragment index (byte offset divided by
// WU_BW_TMEM_FRAGMENT_BYTES) as frag_addr, not a byte address.
static inline void wu_bw_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
// .insn r operand order: opcode, funct3, funct7, rd, rs1, rs2.
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
:
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
[value] "r"(value)
// "memory" clobber keeps the compiler from moving memory accesses across
// the instruction.
: "memory");
}
// Byte offset of the A tile belonging to `tensor_block`: consecutive blocks
// are spaced WU_BW_TMEM_WARP_BYTES apart and the A tile sits at the start of
// each block's window.
static inline uint32_t wu_bw_tmem_a_byte_base(uint32_t tensor_block) {
  const uint32_t block_stride = WU_BW_TMEM_WARP_BYTES;
  return block_stride * tensor_block;
}
// Byte offset of the C tile belonging to `tensor_block`: the block's A-tile
// base plus the fixed WU_BW_TMEM_C_BYTE_OFFSET.
static inline uint32_t wu_bw_tmem_c_byte_base(uint32_t tensor_block) {
  const uint32_t block_base = wu_bw_tmem_a_byte_base(tensor_block);
  return block_base + WU_BW_TMEM_C_BYTE_OFFSET;
}
// Stores `value` into every fragment of the tile that begins at byte offset
// `byte_base`, visiting fragments in ascending order, then issues a full
// read/write fence.
//
// byte_base is converted to a fragment index by dividing by
// WU_BW_TMEM_FRAGMENT_BYTES; WU_BW_TMEM_FRAGMENTS stores are issued.
static inline void wu_bw_fill_tmem_tile(uint32_t byte_base, uint32_t value) {
  uint32_t frag_addr = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
  uint32_t remaining = WU_BW_TMEM_FRAGMENTS;
  while (remaining != 0) {
    // Caller must run this with all TMEM fragment lanes active.
    wu_bw_scalar_tmem_st(frag_addr, value);
    ++frag_addr;
    --remaining;
  }
  asm volatile("fence rw, rw" ::: "memory");
}
// Stores `lane_value` into fragment number `frag` of the tile that begins at
// byte offset `byte_base`, then issues a full read/write fence.
static inline void wu_bw_store_tmem_fragment_lane_values(uint32_t byte_base,
                                                         uint32_t frag,
                                                         uint32_t lane_value) {
  const uint32_t tile_first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
  const uint32_t target_frag = tile_first_frag + frag;
  wu_bw_scalar_tmem_st(target_frag, lane_value);
  asm volatile("fence rw, rw" ::: "memory");
}
// Writes `value` to each of the WU_BW_TMEM_WORDS consecutive 32-bit words
// starting at `smem`, in ascending address order, then issues a full
// read/write fence.
static inline void wu_bw_fill_smem_tile(volatile uint32_t *smem,
                                        uint32_t value) {
  volatile uint32_t *const tile_end = smem + WU_BW_TMEM_WORDS;
  for (volatile uint32_t *word = smem; word != tile_end; ++word) {
    *word = value;
  }
  asm volatile("fence rw, rw" ::: "memory");
}
// Checks that the first `words` entries of `data` all equal `expected`.
//
// Returns the index of the first mismatching word, or `words` when every
// entry matches (including when `words` is 0). `*bad_actual` receives the
// mismatching value on failure and 0 on success.
static inline uint32_t wu_bw_verify_constant(const volatile uint32_t *data,
                                             uint32_t words,
                                             uint32_t expected,
                                             volatile uint32_t *bad_actual) {
  uint32_t idx = 0;
  uint32_t offender = 0;
  while (idx < words) {
    const uint32_t word = data[idx];
    if (word != expected) {
      offender = word;
      break;
    }
    ++idx;
  }
  *bad_actual = offender;
  return idx;
}
#endif

View File

@@ -1,9 +1,15 @@
# wu_arch_hgemm # wu_arch_hgemm
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration Two-tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp
with the 4-lane Blackwell tensor-core path. configuration with the 4-lane Blackwell tensor-core path.
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor Scalar warp 0 initializes the shared-memory B operand for a 32x16x32 GEMM,
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports spawns only the tensor warp mask, waits for tensor warps
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM `NUM_SCALAR_WARPS..NUM_WARPS-1`, verifies the combined 32x16 fp32 output, and
instruction sequence using 16-byte fragments and then stop themselves. reports completion through `tohost`.
Each tensor warp computes one 16x16x32 BWGMMA M-block using 16-byte fragments:
warp `NUM_SCALAR_WARPS` produces rows 0..15 and warp `NUM_SCALAR_WARPS + 1`
produces rows 16..31. Both tensor warps share the same B tile, copy their C
tiles back into one contiguous output matrix, mark their block IDs, and then
stop themselves.

View File

@@ -2,6 +2,24 @@
#define DEV_SMEM_START_ADDR 0xff000000u #define DEV_SMEM_START_ADDR 0xff000000u
#define WU_CASE_TENSOR_HGEMM_BASE 0x7500u #define WU_CASE_TENSOR_HGEMM_BASE 0x7500u
#define WU_CASE_TENSOR_HGEMM_BLOCK_BASE 0x7600u
#define WU_HGEMM_M_BLOCKS 2u
#define WU_HGEMM_M_PER_BLOCK 16u
#define WU_HGEMM_N 16u
#define WU_HGEMM_K 32u
#define WU_HGEMM_TILE_BYTES 1024u
#define WU_HGEMM_FRAGMENT_BYTES 16u
#define WU_HGEMM_B_WORDS (WU_HGEMM_TILE_BYTES / sizeof(uint32_t))
#define WU_HGEMM_OUT_WORDS \
(WU_HGEMM_M_BLOCKS * WU_HGEMM_M_PER_BLOCK * WU_HGEMM_N)
#define WU_HGEMM_EXPECTED 0x42820000u
#define WU_HGEMM_NO_FAIL WU_HGEMM_OUT_WORDS
static_assert(NUM_TENSOR_WARPS == WU_HGEMM_M_BLOCKS,
"wu_arch_hgemm expects two tensor warps");
static_assert((NUM_THREADS * 2u) <= WU_CASE_MAX_WARPS,
"wu_arch_hgemm uses g_aux[2 * tid + {0,1}] for check scratch");
#define BW_REP2(x) x, x #define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x) #define BW_REP4(x) BW_REP2(x), BW_REP2(x)
@@ -13,6 +31,9 @@ volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x40004000u)}; BW_REP4(0x40004000u)};
volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = { volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3f800000u)}; BW_REP4(0x3f800000u)};
volatile uint32_t g_hgemm_control_fail __attribute__((aligned(16)));
volatile uint32_t g_hgemm_return_code __attribute__((aligned(16)));
volatile uint32_t g_hgemm_out[WU_HGEMM_OUT_WORDS] __attribute__((aligned(16)));
} }
#undef BW_REP2 #undef BW_REP2
@@ -21,7 +42,8 @@ volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() { extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
asm volatile( asm volatile(
"csrr x5, %[csr_wid]\n\t" "csrr x5, %[csr_wid]\n\t"
"slli x1, x5, 11\n\t" "addi x1, x5, -%[num_scalar_warps]\n\t"
"slli x1, x1, 11\n\t"
"addi x2, x1, 1024\n\t" "addi x2, x1, 1024\n\t"
"la x6, g_hgemm_a_row\n\t" "la x6, g_hgemm_a_row\n\t"
"la x3, g_hgemm_c_row\n\t" "la x3, g_hgemm_c_row\n\t"
@@ -39,6 +61,25 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
".insn r %[custom3], 0, 0, x2, x1, x4\n\t" ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t" ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
"csrr x5, %[csr_wid]\n\t" "csrr x5, %[csr_wid]\n\t"
"addi x6, x5, -%[num_scalar_warps]\n\t"
"slli x7, x6, 10\n\t"
"la x3, g_hgemm_out\n\t"
"add x3, x3, x7\n\t"
"li x7, 0\n\t"
"3:\n\t"
"add x4, x2, x7\n\t"
"add x1, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
"addi x7, x7, 16\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 3b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"slli x7, x6, 2\n\t"
"la x3, g_case_mem\n\t"
"add x3, x3, x7\n\t"
"li x7, %[hgemm_block_base]\n\t"
"or x7, x7, x6\n\t"
"sw x7, 0(x3)\n\t"
"slli x6, x5, 2\n\t" "slli x6, x5, 2\n\t"
"la x7, g_seen\n\t" "la x7, g_seen\n\t"
"add x7, x7, x6\n\t" "add x7, x7, x6\n\t"
@@ -52,34 +93,91 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
[custom0] "i"(RISCV_CUSTOM0), [custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3), [custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR), [smem_base] "i"(DEV_SMEM_START_ADDR),
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE) [hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE),
[hgemm_block_base] "i"(WU_CASE_TENSOR_HGEMM_BLOCK_BASE),
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
[tile_bytes] "i"(WU_HGEMM_TILE_BYTES)
: "memory"); : "memory");
} }
extern "C" int wu_main() { extern "C" int wu_main() {
if (!wu_is_leader()) { if (vx_core_id() != 0 || vx_warp_id() != 0) {
return 0; return 0;
} }
wu_case_reset(); const uint32_t tid = wu_tid();
volatile uint32_t *smem_b = if (tid == 0) {
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR); wu_case_reset();
for (uint32_t frag = 0; frag < 64u; ++frag) { g_hgemm_control_fail = 0;
const uint32_t row = frag * 4u; g_hgemm_return_code = 0;
for (uint32_t i = 0; i < 4u; ++i) { }
smem_b[row + i] = g_hgemm_b_row[i]; asm volatile("fence rw, rw" ::: "memory");
if (tid == 0) {
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
g_hgemm_out[i] = 0;
}
volatile uint32_t *smem_b =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t i = 0; i < WU_HGEMM_B_WORDS; ++i) {
smem_b[i] = g_hgemm_b_row[i & 3u];
} }
} }
asm volatile("fence rw, rw" ::: "memory");
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker); if (tid == 0) {
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
WU_CASE_TENSOR_HGEMM_BASE) != 0) { WU_CASE_TENSOR_HGEMM_BASE) != 0) {
wu_case_fail(0x09u); g_hgemm_control_fail = 0x09u;
return 1; }
if (g_hgemm_control_fail == 0) {
for (uint32_t block = 0; block < WU_HGEMM_M_BLOCKS; ++block) {
if (g_case_mem[block] != (WU_CASE_TENSOR_HGEMM_BLOCK_BASE | block)) {
g_hgemm_control_fail = 0x0au + block;
break;
}
}
}
}
asm volatile("fence rw, rw" ::: "memory");
if (g_hgemm_control_fail != 0) {
if (tid == 0) {
g_hgemm_return_code = 1;
wu_case_fail(g_hgemm_control_fail);
}
asm volatile("fence rw, rw" ::: "memory");
return static_cast<int>(g_hgemm_return_code);
} }
wu_case_pass(); if (tid == 0) {
return 0; uint32_t bad_index = WU_HGEMM_NO_FAIL;
uint32_t bad_actual = 0;
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
const uint32_t actual = g_hgemm_out[i];
if (actual != WU_HGEMM_EXPECTED) {
bad_index = i;
bad_actual = actual;
break;
}
}
if (bad_index != WU_HGEMM_NO_FAIL) {
g_aux[0] = bad_index;
g_aux[1] = bad_actual;
g_hgemm_return_code = 1;
wu_case_fail(0x20u);
} else {
g_hgemm_return_code = 0;
wu_case_pass();
}
}
asm volatile("fence rw, rw" ::: "memory");
return static_cast<int>(g_hgemm_return_code);
} }