Add Wu TMEM FlashAttention validation cases
This commit is contained in:
7
kernels/blackwell_multi_tc/Makefile
Normal file
7
kernels/blackwell_multi_tc/Makefile
Normal file
@@ -0,0 +1,7 @@
|
||||
PROJECT = blackwell_multi_tc
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
188
kernels/blackwell_multi_tc/kernel.cpp
Normal file
@@ -0,0 +1,188 @@
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
#include <vx_spawn.h>
|
||||
|
||||
#define RISCV_CUSTOM3 0x7B
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define MAX_CORES 4
|
||||
#define MAX_WARPS 8
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_a_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3c003c00u)}; // two fp16 1.0 values per word
|
||||
volatile uint32_t g_b_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x40004000u)}; // two fp16 2.0 values per word
|
||||
volatile uint32_t g_c_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3f800000u)}; // one fp32 1.0 value per word
|
||||
volatile uint32_t g_result[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||
volatile uint32_t g_status[MAX_CORES * MAX_WARPS] __attribute__((aligned(32)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
#undef BW_REP8
|
||||
|
||||
static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void tcgen05_cp_wait() {
|
||||
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void bwgmma(uint32_t addr_tmem_c,
|
||||
uint32_t addr_tmem_a,
|
||||
uint32_t addr_smem_b) {
|
||||
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||
:
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||
"r"(addr_smem_b)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline void bwgmma_wait() {
|
||||
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline float tcgen05_ld_f32(uint32_t addr_tmem) {
|
||||
float value;
|
||||
asm volatile(".insn r %1, 4, 0, %0, %2, x0"
|
||||
: "=f"(value)
|
||||
: "i"(RISCV_CUSTOM3), "r"(addr_tmem)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
// Returns the raw 32-bit object representation of `value`.
//
// The previous implementation wrote a union's `float` member and read its
// `uint32_t` member. Reading the inactive member of a union is undefined
// behaviour in C++, so the bits are copied byte-wise instead. The compiler
// builtin is used so that no additional header is required; it is lowered to
// a plain register move at any optimisation level above -O0.
static inline uint32_t f32_bits(float value) {
  uint32_t bits = 0;
  __builtin_memcpy(&bits, &value, sizeof bits);
  return bits;
}
|
||||
|
||||
// Intentionally empty definition of `vx_perf_dump` with C linkage.
// NOTE(review): presumably overrides the runtime's performance-counter dump so
// this bare-metal test does not emit it; confirm against the runtime library.
extern "C" void vx_perf_dump() {
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_kernel_entry() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x1, x5, 11\n\t" // tmem_a = wid * 0x800
|
||||
"addi x2, x1, 1024\n\t" // tmem_c = tmem_a + 0x400
|
||||
"la x6, g_a_row\n\t"
|
||||
"la x3, g_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 32\n\t"
|
||||
"li x4, 1024\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x5, x5, 2\n\t"
|
||||
"la x6, g_status\n\t"
|
||||
"add x6, x6, x5\n\t"
|
||||
"li x7, 0x600d\n\t"
|
||||
"sw x7, 0(x6)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
// Argument bundle consumed by kernel_body(). Every member is a pointer to
// volatile 32-bit words; the member order defines the layout and must not
// change.
struct kernel_arg_t {
  uint32_t volatile *a_row;   // source address passed to tcgen05_cp for the TMEM A region
  uint32_t volatile *b_row;   // not read by kernel_body()
  uint32_t volatile *c_row;   // source address passed to tcgen05_cp for the TMEM C region
  uint32_t volatile *result;  // receives the raw bits returned by tcgen05_ld_f32
  uint32_t volatile *status;  // receives progress and pass/fail markers
};
|
||||
|
||||
static void __attribute__((convergent)) kernel_body(int task_id,
|
||||
kernel_arg_t *__UNIFORM__ arg) {
|
||||
const int warp_id = task_id / vx_num_threads();
|
||||
const int lane_id = task_id % vx_num_threads();
|
||||
if (lane_id != 0)
|
||||
return;
|
||||
|
||||
const int num_warps = vx_num_warps();
|
||||
if (warp_id >= num_warps)
|
||||
return;
|
||||
|
||||
volatile uint32_t *a_row = arg->a_row;
|
||||
volatile uint32_t *c_row = arg->c_row;
|
||||
volatile uint32_t *result = arg->result;
|
||||
volatile uint32_t *status = arg->status;
|
||||
const int core_id = vx_core_id();
|
||||
const int status_idx = core_id * MAX_WARPS + warp_id;
|
||||
|
||||
const uint32_t smem_b = DEV_SMEM_START_ADDR;
|
||||
const uint32_t base = static_cast<uint32_t>(warp_id) * 0x800u;
|
||||
const uint32_t tmem_a = base + 0x000u;
|
||||
const uint32_t tmem_c = base + 0x400u;
|
||||
|
||||
status[status_idx] = 0x100u;
|
||||
|
||||
for (int frag = 0; frag < 32; ++frag) {
|
||||
const uint32_t offset = static_cast<uint32_t>(frag * 32);
|
||||
tcgen05_cp(tmem_a + offset, reinterpret_cast<uint32_t>(a_row));
|
||||
tcgen05_cp(tmem_c + offset, reinterpret_cast<uint32_t>(c_row));
|
||||
}
|
||||
tcgen05_cp_wait();
|
||||
|
||||
status[status_idx] = 0x200u;
|
||||
bwgmma(tmem_c, tmem_a, smem_b);
|
||||
bwgmma_wait();
|
||||
|
||||
result[status_idx] = f32_bits(tcgen05_ld_f32(tmem_c));
|
||||
status[status_idx] = (result[status_idx] == 0x42820000u) ? 0x600du : 0xe000u;
|
||||
}
|
||||
|
||||
int main() {
|
||||
const int core_id = vx_core_id();
|
||||
if (core_id != 0 || vx_warp_id() != 0 || vx_thread_id() != 0)
|
||||
return 0;
|
||||
|
||||
volatile uint32_t *smem_b_ptr =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (int frag = 0; frag < 32; ++frag) {
|
||||
const int row = frag * 8;
|
||||
for (int i = 0; i < 8; ++i)
|
||||
smem_b_ptr[row + i] = g_b_row[i];
|
||||
}
|
||||
|
||||
for (int i = 0; i < MAX_WARPS; ++i)
|
||||
g_status[core_id * MAX_WARPS + i] = 0;
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_kernel_entry);
|
||||
|
||||
for (int spin = 0; spin < 100000; ++spin) {
|
||||
int done = 1;
|
||||
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i)
|
||||
done &= (g_status[core_id * MAX_WARPS + i] == 0x600du);
|
||||
if (done)
|
||||
break;
|
||||
}
|
||||
|
||||
for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) {
|
||||
if (g_status[core_id * MAX_WARPS + i] != 0x600du)
|
||||
return 1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
# Get the directory where this common.mk file is located
|
||||
COMMON_MK_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
|
||||
|
||||
XLEN ?= 32
|
||||
|
||||
TOOLDIR ?= /opt
|
||||
@@ -7,7 +10,7 @@ RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv64-gnu-toolchain
|
||||
VX_CFLAGS += -march=rv64imafd -mabi=lp64d
|
||||
STARTUP_ADDR ?= 0x180000000
|
||||
else
|
||||
RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv-gnu-toolchain
|
||||
RISCV_TOOLCHAIN_PATH ?= $(realpath $(COMMON_MK_DIR)../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
|
||||
VX_CFLAGS += -march=rv32imaf -mabi=ilp32f
|
||||
STARTUP_ADDR ?= 0x80000000
|
||||
endif
|
||||
@@ -18,7 +21,7 @@ RISCV_SYSROOT ?= $(RISCV_TOOLCHAIN_PATH)/$(RISCV_PREFIX)
|
||||
VORTEX_KN_PATH ?= $(realpath ../../lib)
|
||||
GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini)
|
||||
|
||||
LLVM_VORTEX ?= $(TOOLDIR)/llvm-vortex
|
||||
LLVM_VORTEX ?= $(realpath $(COMMON_MK_DIR)../../toolchain/llvm-r8)
|
||||
|
||||
LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT)
|
||||
LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH)
|
||||
|
||||
9
kernels/hgemm_validation/Makefile
Normal file
9
kernels/hgemm_validation/Makefile
Normal file
@@ -0,0 +1,9 @@
|
||||
PROJECT = hgemm_validation
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../common.mk
|
||||
|
||||
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||
cp $< $@
|
||||
14
kernels/hgemm_validation/README.md
Normal file
14
kernels/hgemm_validation/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# hgemm_validation
|
||||
|
||||
Small HGEMM correctness validation for the 4-lane Blackwell tensor-core path.
|
||||
|
||||
The test runs one tensor warp on a single 16x16x32 HGEMM tile:
|
||||
|
||||
- A is fp16 1.0
|
||||
- B is fp16 2.0
|
||||
- C starts as fp32 1.0
|
||||
- expected output is fp32 65.0 for all 16x16 C elements
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B tile, spawns only the first tensor
|
||||
warp, waits for completion, and checks all 256 output words copied back from
|
||||
TMEM. Success reports `WU_CASE_PASS` through `tohost`.
|
||||
119
kernels/hgemm_validation/kernel.cpp
Normal file
119
kernels/hgemm_validation/kernel.cpp
Normal file
@@ -0,0 +1,119 @@
|
||||
#include "../wu_arch_cases/common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define HGEMM_VALIDATION_BASE 0x7700u
|
||||
|
||||
#define HGEMM_M 16u
|
||||
#define HGEMM_N 16u
|
||||
#define HGEMM_K 32u
|
||||
#define HGEMM_TILE_BYTES 1024u
|
||||
#define HGEMM_FRAGMENT_BYTES 16u
|
||||
#define HGEMM_FRAGMENTS (HGEMM_TILE_BYTES / HGEMM_FRAGMENT_BYTES)
|
||||
#define HGEMM_OUT_WORDS (HGEMM_M * HGEMM_N)
|
||||
#define HGEMM_EXPECTED 0x42820000u
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_hgemm_a_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_frag[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_hgemm_out[HGEMM_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) hgemm_validation_worker() {
|
||||
asm volatile(
|
||||
"li x1, 0\n\t"
|
||||
"li x2, %[tile_bytes]\n\t"
|
||||
"la x6, g_hgemm_a_frag\n\t"
|
||||
"la x3, g_hgemm_c_frag\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_hgemm_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x1, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
||||
"addi x7, x7, %[frag_bytes]\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"3: j 3b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[done_base] "i"(HGEMM_VALIDATION_BASE),
|
||||
[tile_bytes] "i"(HGEMM_TILE_BYTES),
|
||||
[frag_bytes] "i"(HGEMM_FRAGMENT_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||
g_hgemm_out[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < HGEMM_FRAGMENTS; ++frag) {
|
||||
const uint32_t row = frag * 4u;
|
||||
for (uint32_t i = 0; i < 4u; ++i) {
|
||||
smem_b[row + i] = g_hgemm_b_frag[i];
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||
vx_spawn_tensor(1u << tensor_wid, hgemm_validation_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << tensor_wid, HGEMM_VALIDATION_BASE) != 0) {
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) {
|
||||
if (g_hgemm_out[i] != HGEMM_EXPECTED) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = g_hgemm_out[i];
|
||||
wu_case_fail(0x20u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
@@ -7,7 +7,19 @@ CASES := \
|
||||
case05_tensor_barrier \
|
||||
case06_masked_barrier \
|
||||
case07_tensor_csr_tmc \
|
||||
case08_tensor_lsu_optional
|
||||
case08_tensor_lsu_optional \
|
||||
case09_scalar_tmem_ldst \
|
||||
case10_tensor_scalar_tmem_handoff \
|
||||
case11_scalar_tmem_softmax_stage \
|
||||
case12_flash_pv_accum \
|
||||
case12_1_scalar_tmem_cb_probe \
|
||||
case12_2_flash_pv_p_probe \
|
||||
case12_3_scalar_tmem_lane_store \
|
||||
case13_flash_pv_two_warps \
|
||||
case14_flash_pv_k64 \
|
||||
case15_flash_softmax_pv_stage \
|
||||
case16_flash_full_pipeline \
|
||||
case17_flash_exp_softmax_probe
|
||||
|
||||
SMOKE_CASES := \
|
||||
case00_boot_scalar \
|
||||
|
||||
@@ -13,9 +13,27 @@ This directory contains small bare-metal kernels for incremental Wu architecture
|
||||
- `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps.
|
||||
- `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior.
|
||||
- `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower.
|
||||
- `case09_scalar_tmem_ldst`: scalar warp direct TMEM store/load path for the banked TMEM softmax mechanism.
|
||||
- `case10_tensor_scalar_tmem_handoff`: tensor BWGMMA result in TMEM C observed by scalar TMEM loads.
|
||||
- `case11_scalar_tmem_softmax_stage`: scalar TMEM transform written back for tensor-side copy-out.
|
||||
- `case12_flash_pv_accum`: one tensor warp consumes scalar-written `P` in TMEM A for `O = O + P @ V`.
|
||||
- `case12_1_scalar_tmem_cb_probe`: scalar-written TMEM A rows copied back by tensor `tcgen05_cb`.
|
||||
- `case12_2_flash_pv_p_probe`: case12 P-write diagnostic using the same scalar fill path.
|
||||
- `case12_3_scalar_tmem_lane_store`: scalar TMEM store lane-coalesced fragment write diagnostic.
|
||||
- `case13_flash_pv_two_warps`: both tensor warps consume scalar-written `P` tiles for two row blocks.
|
||||
- `case14_flash_pv_k64`: two consecutive `K=32` BWGMMA steps accumulate into one PV output.
|
||||
- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV.
|
||||
- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline.
|
||||
- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention.
|
||||
|
||||
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker.
|
||||
|
||||
## Debug Notes
|
||||
|
||||
- `TMC_DEBUG_NOTES.md`: reusable notes for diagnosing lane-mask/TMC operand
|
||||
bugs, including the `case12_3_scalar_tmem_lane_store` failure where a
|
||||
lane0-only register value was consumed under an all-lane mask.
|
||||
|
||||
## Build
|
||||
|
||||
Use the suite Makefile from this directory:
|
||||
|
||||
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
174
kernels/wu_arch_cases/TMC_DEBUG_NOTES.md
Normal file
@@ -0,0 +1,174 @@
|
||||
# TMC Operand Debug Notes
|
||||
|
||||
This note records a failure pattern found while debugging
|
||||
`case12_3_scalar_tmem_lane_store`.
|
||||
|
||||
## Symptom
|
||||
|
||||
The simulator keeps producing trace output, but the kernel makes no forward
|
||||
progress:
|
||||
|
||||
- a tensor worker repeatedly polls a ready flag such as `g_case_mem[0]`;
|
||||
- the poll load always returns the old value;
|
||||
- the scalar warp has already committed the work that precedes the ready flag store;
|
||||
- the expected ready flag store never appears in the LSU trace.
|
||||
|
||||
For `case12_3_scalar_tmem_lane_store`, the tensor worker was looping at
|
||||
`PC=0x80000034..0x80000048`, repeatedly reading `g_case_mem[0]` at
|
||||
`0x20000430` as `0`. The scalar warp committed the scalar TMEM store and the
|
||||
following `TMC`, but never issued the next ready flag store at `PC=0x8000015c`.
|
||||
|
||||
## Root Cause Pattern
|
||||
|
||||
Vortex scalar registers are lane-local. A value computed while only lane 0 is
|
||||
active is only known to be valid in lane 0. If that register is later consumed
|
||||
by an instruction while multiple lanes are active, the inactive lanes may hold
|
||||
stale or zero values.
|
||||
|
||||
This is especially easy to miss around `TMC`:
|
||||
|
||||
```asm
|
||||
# Bad when xN was defined while only lane 0 was active.
|
||||
vx_tmc all_lanes
|
||||
...
|
||||
vx_tmc xN
|
||||
```
|
||||
|
||||
The failed `case12_3` sequence used a C operand for `1u` that the compiler
|
||||
materialized before switching to all lanes. Runtime trace showed the source
|
||||
register for the final `TMC(1)` as:
|
||||
|
||||
```text
|
||||
rs1_data={0x0, 0x0, 0x0, 0x1}
|
||||
```
|
||||
|
||||
After that `TMC`, the scalar warp did not fetch the ready flag store.
|
||||
|
||||
## Correct Pattern
|
||||
|
||||
When a value is consumed under an all-lane mask, define that value under the
|
||||
same all-lane mask unless the instruction semantics explicitly use only lane 0.
|
||||
|
||||
For switching back to lane 0 from an all-lane region, keep the immediate
|
||||
materialization and the `TMC` adjacent inside the all-lane region:
|
||||
|
||||
```asm
|
||||
vx_tmc all_lanes
|
||||
...
|
||||
fence rw, rw
|
||||
li t2, 1
|
||||
vx_tmc t2
|
||||
```
|
||||
|
||||
The expected trace at the final `TMC` is:
|
||||
|
||||
```text
|
||||
rs1_data={0x1, 0x1, 0x1, 0x1}
|
||||
```
|
||||
|
||||
The library helper `vx_tmc_one()` follows this pattern because it emits
|
||||
`li a0, 1` and `vx_tmc a0` in the same volatile asm block. It is safe when
|
||||
called while all lanes are active.
|
||||
|
||||
## Fast Log Checks
|
||||
|
||||
Use the simulation log to distinguish a simulator stall from a kernel-level
|
||||
wait loop:
|
||||
|
||||
```sh
|
||||
stat -c '%s %y' chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
tail -n 80 chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
```
|
||||
|
||||
If the file is still growing but the tail repeatedly shows the same PC range,
|
||||
the simulator is alive and the kernel is probably spinning.
|
||||
|
||||
For a ready flag handoff, check both the polling load and the producer store:
|
||||
|
||||
```sh
|
||||
rg -n "0x20000430|0x9900|wid=0, PC=0x8000014|wid=0, PC=0x8000015" \
|
||||
chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log
|
||||
```
|
||||
|
||||
Interpretation:
|
||||
|
||||
- load repeatedly returns `0` and no later `0x9900` store exists: producer did
|
||||
not reach the ready flag store;
|
||||
- ready flag store exists but consumer still reads old data: investigate memory
|
||||
ordering or address aliasing;
|
||||
- producer stops immediately after a `TMC`: inspect that `TMC` source register
|
||||
across all lanes.
|
||||
|
||||
## Dump Checks
|
||||
|
||||
Check that the final lane-mask narrowing defines `1` immediately before `TMC`
|
||||
inside the all-lane region:
|
||||
|
||||
```sh
|
||||
rg -n "li\\s+.*1|vx_tmc" \
|
||||
kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.radiance.dump
|
||||
```
|
||||
|
||||
The fixed `case12_3` dump has:
|
||||
|
||||
```text
|
||||
li t2, 1
|
||||
vx_tmc t2
|
||||
auipc a3, 1
|
||||
sw a5, ...
|
||||
```
|
||||
|
||||
## Case12-Case15 Audit
|
||||
|
||||
The following cases were checked after fixing `case12_3`:
|
||||
|
||||
- `case12_flash_pv_accum`
|
||||
- `case12_2_flash_pv_p_probe`
|
||||
- `case13_flash_pv_two_warps`
|
||||
- `case14_flash_pv_k64`
|
||||
- `case15_flash_softmax_pv_stage`
|
||||
|
||||
They switch to all lanes with `vx_tmc(wu_bw_all_lanes_mask())` and switch back
|
||||
with `vx_tmc_one()`. Their dumps show the safe adjacent sequence:
|
||||
|
||||
```text
|
||||
li a0, 1
|
||||
vx_tmc a0
|
||||
```
|
||||
|
||||
No analogous source change is needed for those cases.
|
||||
|
||||
## Scalar TMEM Fill Lane Coverage
|
||||
|
||||
`wu_bw_fill_tmem_tile()` must run with all fragment lanes active. Scalar TMEM
|
||||
store writes each active lane's source word into the matching TMEM fragment word:
|
||||
|
||||
```text
|
||||
TMEM[addr].word[lane] = rs2_data[lane]
|
||||
```
|
||||
|
||||
Therefore a fill loop executed with `tmask=0001` only initializes word 0 of
|
||||
each 16-byte fragment. Word 1 and later can retain old TMEM data. In
|
||||
`case12_1_scalar_tmem_cb_probe`, this showed up as a normal completion path
|
||||
followed by verification failure `0x14`; `g_aux[0]` was `1`, and `g_aux[1]`
|
||||
held the stale copied-back word.
|
||||
|
||||
The correct pattern is:
|
||||
|
||||
```c++
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
```
|
||||
|
||||
The dump should show the fill loop bracketed by all-lane and lane-0 masks:
|
||||
|
||||
```text
|
||||
li a5, 15
|
||||
vx_tmc a5
|
||||
...
|
||||
vx_cmov zero, a0, a5, a2
|
||||
...
|
||||
li a0, 1
|
||||
vx_tmc a0
|
||||
```
|
||||
@@ -3,6 +3,8 @@ VX_SRCS = kernel.cpp
|
||||
VX_CFLAGS += -I..
|
||||
VORTEX_KN_PATH ?= $(realpath ../../../lib)
|
||||
GEMMINI_SW_PATH ?= $(realpath ../../../lib/gemmini)
|
||||
LLVM_VORTEX ?= $(realpath ../../../../toolchain/llvm-r8)
|
||||
RISCV_TOOLCHAIN_PATH ?= $(realpath ../../../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain)
|
||||
OPTS ?= -n1
|
||||
|
||||
include ../../common.mk
|
||||
|
||||
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
3
kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case09_scalar_tmem_ldst
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
7
kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case09_scalar_tmem_ldst
|
||||
|
||||
Smoke test for scalar warp direct TMEM access.
|
||||
|
||||
The leader scalar lane writes a lane-distinct word vector to a TMEM row with
|
||||
`CUSTOM1 func7=0x30 func3=1`, reads it back with `CUSTOM1 func7=0x30 func3=0`,
|
||||
and reports pass only if the lane 0 value matches.
|
||||
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
39
kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
#include "common_wu_min.h"
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_scalar_tmem_st(uint32_t addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr), [value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
const uint32_t addr = 0x24u;
|
||||
const uint32_t expected = 0x5a090000u | wu_tid();
|
||||
wu_scalar_tmem_st(addr, expected);
|
||||
const uint32_t observed = wu_scalar_tmem_ld(addr);
|
||||
|
||||
if (observed != expected) {
|
||||
g_aux[0] = observed;
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case10_tensor_scalar_tmem_handoff
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case10_tensor_scalar_tmem_handoff
|
||||
|
||||
Validates the tensor-to-scalar TMEM handoff needed by FlashAttention.
|
||||
|
||||
The tensor warp initializes TMEM A/C, runs one BWGMMA, and leaves the result in
|
||||
TMEM C. The scalar leader waits for the tensor warp and then reads several C
|
||||
fragments through the scalar TMEM load instruction. Each lane is expected to
|
||||
observe the HGEMM value `0x42820000`.
|
||||
@@ -0,0 +1,111 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TMEM_HANDOFF_BASE 0x7700u
|
||||
#define WU_TMEM_TILE_BYTES 1024u
|
||||
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS >= 1, "case10 requires at least one tensor warp");
|
||||
static_assert(NUM_THREADS == 4, "case10 expects the 4-lane Blackwell tensor core");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case10_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_case10_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_case10_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case10_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_case10_a_row\n\t"
|
||||
"la x3, g_case10_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[handoff_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[handoff_base] "i"(WU_CASE_TMEM_HANDOFF_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_TMEM_TILE_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||
smem_b[i] = g_case10_b_row[i & 3u];
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case10_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE_TMEM_HANDOFF_BASE) != 0) {
|
||||
wu_case_fail(0x10u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t i = 0; i < 8; ++i) {
|
||||
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base + i);
|
||||
if (observed != WU_TMEM_EXPECTED_HGEMM) {
|
||||
g_aux[0] = i;
|
||||
g_aux[1] = observed;
|
||||
wu_case_fail(0x11u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = wu_arch_case11_scalar_tmem_softmax_stage
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,10 @@
|
||||
# case11_scalar_tmem_softmax_stage
|
||||
|
||||
Validates the FlashAttention softmax-stage TMEM path.
|
||||
|
||||
The tensor warp writes HGEMM results into TMEM C. The scalar warp reads TMEM C,
|
||||
performs a deterministic lane-wise transform, and writes all four lanes back
|
||||
into one TMEM A-region fragment. The tensor warp then uses TCGEN05_CB to copy
|
||||
that TMEM fragment to global memory so scalar code can verify that tensor-side
|
||||
TMEM reads observe the scalar writeback with the correct lane ordering and
|
||||
write mask.
|
||||
@@ -0,0 +1,200 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TMEM_STAGE_HGEMM_BASE 0x7800u
|
||||
#define WU_CASE_TMEM_STAGE_VERIFY_BASE 0x7900u
|
||||
#define WU_TMEM_TILE_BYTES 1024u
|
||||
#define WU_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_TMEM_C_BYTE_BASE 1024u
|
||||
#define WU_TMEM_SCALAR_WRITE_BYTE_ADDR 512u
|
||||
#define WU_TMEM_EXPECTED_HGEMM 0x42820000u
|
||||
#define WU_TMEM_STAGE_VALUE_BASE 0x42820001u
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS >= 1, "case11 requires at least one tensor warp");
|
||||
static_assert(NUM_THREADS == 4, "case11 expects the 4-lane Blackwell tensor core");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case11_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_case11_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_case11_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_case11_out[4] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case11_scalar_seen[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
|
||||
static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
|
||||
[value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case11_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_case11_a_row\n\t"
|
||||
"la x3, g_case11_c_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[hgemm_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[verify_base]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
"addi x4, x1, %[write_byte_addr]\n\t"
|
||||
"la x6, g_case11_out\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[verify_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[hgemm_base] "i"(WU_CASE_TMEM_STAGE_HGEMM_BASE),
|
||||
[verify_base] "i"(WU_CASE_TMEM_STAGE_VERIFY_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_TMEM_TILE_BYTES),
|
||||
[write_byte_addr] "i"(WU_TMEM_SCALAR_WRITE_BYTE_ADDR)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case11_out[i] = 0;
|
||||
g_case11_scalar_seen[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) {
|
||||
smem_b[i] = g_case11_b_row[i & 3u];
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case11_worker);
|
||||
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||
WU_CASE_TMEM_STAGE_HGEMM_BASE) != 0) {
|
||||
g_case_mem[1] = 0x20u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_scalar_tmem_ld(c_frag_base);
|
||||
g_case11_scalar_seen[tid] = observed;
|
||||
|
||||
const uint32_t write_frag =
|
||||
WU_TMEM_SCALAR_WRITE_BYTE_ADDR / WU_TMEM_FRAGMENT_BYTES;
|
||||
wu_scalar_tmem_st(write_frag, observed + 1u + tid);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case11_scalar_seen[0] != WU_TMEM_EXPECTED_HGEMM) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case11_scalar_seen[0];
|
||||
g_case_mem[1] = 0x21u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE_TMEM_STAGE_VERIFY_BASE;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS,
|
||||
WU_CASE_TMEM_STAGE_VERIFY_BASE) != 0) {
|
||||
g_case_mem[1] = 0x22u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
if (g_case_mem[1] != 0) {
|
||||
if (tid == 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case11_out[0] != WU_TMEM_STAGE_VALUE_BASE) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case11_out[0];
|
||||
wu_case_fail(0x23u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
# Build target for the case12.1 scalar TMEM copy-back probe; the name mirrors
# the case directory, as in the sibling case12_2 / case12_3 Makefiles.
PROJECT = case12_1_scalar_tmem_cb_probe
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,10 @@
|
||||
# case12_1_scalar_tmem_cb_probe
|
||||
|
||||
Validates the scalar-to-tensor TMEM handoff without BWGMMA. Scalar warp 0 fills
|
||||
tensor block 0 TMEM A rows `0..63` with packed fp16 `1.0`. Tensor warp
|
||||
`NUM_SCALAR_WARPS` waits for that fill, then uses `tcgen05_cb` to copy the same
|
||||
TMEM A rows back to global memory.
|
||||
|
||||
The scalar leader verifies all 256 copied words are `0x3c003c00`. A mismatch
|
||||
means scalar TMEM stores are not visible to the tensor TMEM read/copy path at
|
||||
the expected row addresses.
|
||||
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
108
kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp
Normal file
@@ -0,0 +1,108 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags. The tensor worker stores (INIT | warp_id) and later
// (DONE | warp_id) into g_seen[warp_id]; the scalar leader writes COPY_READY
// to g_case_mem[0] to start the copy-back.
// NOTE(review): these three values are identical to WU_CASE13_INIT_BASE /
// WU_CASE13_DONE_BASE / WU_CASE13_P_READY - confirm the overlap is intended.
#define WU_CASE12_1_INIT_BASE 0x8d00u
#define WU_CASE12_1_DONE_BASE 0x8e00u
#define WU_CASE12_1_COPY_READY 0x8f00u

extern "C" {
// Destination of the worker's TMEM A copy-back, verified by the scalar leader.
volatile uint32_t g_case12_1_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_1_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x3, g_case12_1_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"3: j 3b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[init_base] "i"(WU_CASE12_1_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_1_DONE_BASE),
|
||||
[copy_ready] "i"(WU_CASE12_1_COPY_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_1_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_1_worker);
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_INIT_BASE) !=
|
||||
0) {
|
||||
g_case_mem[1] = 0x12u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_1_COPY_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_DONE_BASE) !=
|
||||
0) {
|
||||
g_case_mem[1] = 0x13u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_1_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x14u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
3
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case12_2_flash_pv_p_probe
|
||||
|
||||
include ../case.mk
|
||||
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
11
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md
Normal file
@@ -0,0 +1,11 @@
|
||||
# case12_2_flash_pv_p_probe
|
||||
|
||||
Diagnostic for `case12_flash_pv_accum`.
|
||||
|
||||
Scalar warp 0 fills tensor block 0 TMEM A with packed fp16 `1.0`, using the
|
||||
same `wu_bw_fill_tmem_tile()` path as case12. Tensor warp `NUM_SCALAR_WARPS`
|
||||
then copies the full TMEM A tile back to global memory with `tcgen05_cb`.
|
||||
|
||||
The scalar leader verifies all copied words are `0x3c003c00`. A mismatch means
|
||||
case12's scalar-written P tile is not reaching the tensor TMEM read/copy path as
|
||||
expected.
|
||||
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
109
kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp
Normal file
@@ -0,0 +1,109 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags. The tensor worker stores (INIT | warp_id) and later
// (DONE | warp_id) into g_seen[warp_id]; the scalar leader writes COPY_READY
// to g_case_mem[0] to start the copy-back.
#define WU_CASE12_2_INIT_BASE 0x9600u
#define WU_CASE12_2_DONE_BASE 0x9700u
#define WU_CASE12_2_COPY_READY 0x9800u

extern "C" {
// Destination of the worker's TMEM A copy-back, verified by the scalar leader.
volatile uint32_t g_case12_2_out[WU_BW_OUT_WORDS]
    __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used))
|
||||
tensor_case12_2_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x3, g_case12_2_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"3: j 3b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[init_base] "i"(WU_CASE12_2_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_2_DONE_BASE),
|
||||
[copy_ready] "i"(WU_CASE12_2_COPY_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_2_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(tensor_mask, tensor_case12_2_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_2_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x51u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_2_COPY_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE12_2_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x52u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_2_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP16_ONE_PACKED, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x53u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case12_3_scalar_tmem_lane_store
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,14 @@
|
||||
# case12_3_scalar_tmem_lane_store
|
||||
|
||||
Validates scalar TMEM store lane-coalesced semantics.
|
||||
|
||||
One scalar warp writes a single 16-byte TMEM fragment. Each active scalar lane supplies one 32-bit word:
|
||||
|
||||
```text
|
||||
word0 = lane0.rs2
|
||||
word1 = lane1.rs2
|
||||
word2 = lane2.rs2
|
||||
word3 = lane3.rs2
|
||||
```
|
||||
|
||||
The tensor copy-back path then copies that fragment to memory. This catches implementations that only write lane0 data or broadcast one scalar value.
|
||||
@@ -0,0 +1,91 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags: the scalar warp writes COPY_READY to g_case_mem[0]; the
// tensor worker answers with (DONE | warp_id) in g_seen[warp_id].
#define WU_CASE12_3_COPY_READY 0x9900u
#define WU_CASE12_3_DONE_BASE 0x9a00u

extern "C" {
// Receives the single 16-byte TMEM fragment copied back by the tensor worker.
volatile uint32_t g_case12_3_out[4] __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used))
|
||||
tensor_case12_3_copy_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"1:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[copy_ready]\n\t"
|
||||
"bne x7, x4, 1b\n\t"
|
||||
"la x6, g_case12_3_out\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x1, x6\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"2: j 2b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[copy_ready] "i"(WU_CASE12_3_COPY_READY),
|
||||
[done_base] "i"(WU_CASE12_3_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case12_3_out[i] = 0;
|
||||
}
|
||||
vx_spawn_tensor(tensor_mask, tensor_case12_3_copy_worker);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t frag_addr =
|
||||
wu_bw_tmem_a_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
asm volatile(
|
||||
".insn r %[custom0], 0, 0, x0, %[all_lanes], x0\n\t"
|
||||
"csrr t0, %[csr_tid]\n\t"
|
||||
"lui t1, 0x5a120\n\t"
|
||||
"or t1, t1, t0\n\t"
|
||||
".insn r %[custom1], 1, 0x30, x0, %[addr], t1\n\t"
|
||||
"fence rw, rw\n\t"
|
||||
"li t2, 1\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, t2, x0"
|
||||
:
|
||||
: [custom0] "i"(RISCV_CUSTOM0), [custom1] "i"(RISCV_CUSTOM1),
|
||||
[csr_tid] "i"(VX_CSR_THREAD_ID), [addr] "r"(frag_addr),
|
||||
[all_lanes] "r"(wu_bw_all_lanes_mask())
|
||||
: "t0", "t1", "t2", "memory");
|
||||
|
||||
g_case_mem[0] = WU_CASE12_3_COPY_READY;
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE12_3_DONE_BASE) != 0) {
|
||||
wu_case_fail(0x51u);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (uint32_t lane = 0; lane < NUM_THREADS; ++lane) {
|
||||
const uint32_t expected = 0x5a120000u | lane;
|
||||
if (g_case12_3_out[lane] != expected) {
|
||||
g_aux[0] = lane;
|
||||
g_aux[1] = g_case12_3_out[lane];
|
||||
wu_case_fail(0x52u);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
3
kernels/wu_arch_cases/case12_flash_pv_accum/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case12_flash_pv_accum
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
7
kernels/wu_arch_cases/case12_flash_pv_accum/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case12_flash_pv_accum
|
||||
|
||||
Validates the first Wu Blackwell FlashAttention PV substage.
|
||||
|
||||
Scalar warp 0 writes a full `16x32` fp16 `P` tile into TMEM A. One tensor warp
|
||||
uses BWGMMA with a `32x16` fp16 `V` tile in SMEM and a preloaded fp32 TMEM C
|
||||
tile. With `P=1.0`, `V=2.0`, and `O_init=3.0`, the expected result is `O = 3.0 + P @ V = 67.0` for every output element.
|
||||
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
127
kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp
Normal file
@@ -0,0 +1,127 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags. The tensor worker stores (INIT | warp_id) and later
// (DONE | warp_id) into g_seen[warp_id]; the scalar leader writes P_READY to
// g_case_mem[0] once the P tile is in TMEM A.
#define WU_CASE12_INIT_BASE 0x8a00u
#define WU_CASE12_DONE_BASE 0x8b00u
#define WU_CASE12_P_READY 0x8c00u

extern "C" {
// One 16-byte row of the initial O accumulator; the worker copies this same
// row into every fragment of its TMEM C tile.
volatile uint32_t g_case12_o_row[4] __attribute__((aligned(16))) = {
    WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE};
// Destination of the worker's TMEM C copy-back, verified by the scalar leader.
volatile uint32_t g_case12_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case12_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case12_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case12_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE12_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE12_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE12_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case12_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_worker);
|
||||
if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x12u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE12_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x13u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case12_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x14u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
3
kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case13_flash_pv_two_warps
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case13_flash_pv_two_warps
|
||||
|
||||
Validates the Wu Blackwell FlashAttention PV substage across both tensor warps.
|
||||
|
||||
Scalar warp 0 writes one `P` tile into each tensor warp's TMEM A partition.
|
||||
Both tensor warps consume the same SMEM `V` tile, accumulate into their own TMEM
|
||||
C tiles, copy out contiguous row blocks, and stop. Every fp32 output is expected
|
||||
to be `67.0`.
|
||||
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
135
kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags. Each tensor worker stores (INIT | warp_id) and later
// (DONE | warp_id) into g_seen[warp_id]; the scalar leader writes P_READY to
// g_case_mem[0] once the P tiles are in TMEM A.
#define WU_CASE13_INIT_BASE 0x8d00u
#define WU_CASE13_DONE_BASE 0x8e00u
#define WU_CASE13_P_READY 0x8f00u
// One output tile per tensor warp, stored back to back.
#define WU_CASE13_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)

extern "C" {
// One 16-byte row of the initial O accumulator; each worker copies this same
// row into every fragment of its TMEM C tile.
volatile uint32_t g_case13_o_row[4] __attribute__((aligned(16))) = {
    WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE};
// Destination of all workers' TMEM C copy-backs.
volatile uint32_t g_case13_out[WU_CASE13_OUT_WORDS]
    __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case13_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case13_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_case13_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE13_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE13_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE13_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE13_OUT_WORDS; ++i) {
|
||||
g_case13_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case13_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE13_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x21u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE13_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE13_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x22u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case13_out, WU_CASE13_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_SEVEN, &bad_actual);
|
||||
if (bad != WU_CASE13_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x23u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
3
kernels/wu_arch_cases/case14_flash_pv_k64/Makefile
Normal file
@@ -0,0 +1,3 @@
|
||||
PROJECT = case14_flash_pv_k64
|
||||
|
||||
include ../case.mk
|
||||
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
7
kernels/wu_arch_cases/case14_flash_pv_k64/README.md
Normal file
@@ -0,0 +1,7 @@
|
||||
# case14_flash_pv_k64
|
||||
|
||||
Validates PV accumulation over two Blackwell `K=32` BWGMMA steps.
|
||||
|
||||
Both tensor warps run two consecutive BWGMMA operations against the same TMEM C
|
||||
tile, modeling a `K=64` FlashAttention PV accumulation. With `P=1.0`,
|
||||
`V=2.0`, and `O_init=5.0`, every fp32 output is expected to be `133.0`.
|
||||
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
136
kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp
Normal file
@@ -0,0 +1,136 @@
|
||||
#include "../common_wu_blackwell_fa.h"

// Handshake tags. Each tensor worker stores (INIT | warp_id) and later
// (DONE | warp_id) into g_seen[warp_id]; the scalar leader writes P_READY to
// g_case_mem[0] once the P tiles are in TMEM A.
#define WU_CASE14_INIT_BASE 0x9000u
#define WU_CASE14_DONE_BASE 0x9100u
#define WU_CASE14_P_READY 0x9200u
// One output tile per tensor warp, stored back to back.
#define WU_CASE14_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS)

extern "C" {
// One 16-byte row of the initial O accumulator; each worker copies this same
// row into every fragment of its TMEM C tile.
volatile uint32_t g_case14_o_row[4] __attribute__((aligned(16))) = {
    WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE};
// Destination of all workers' TMEM C copy-backs.
volatile uint32_t g_case14_out[WU_CASE14_OUT_WORDS]
    __attribute__((aligned(16)));
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case14_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case14_o_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_case14_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE14_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE14_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE14_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = vx_tensor_warp_mask();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_CASE14_OUT_WORDS; ++i) {
|
||||
g_case14_out[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case14_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE14_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x31u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE14_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE14_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x32u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case14_out, WU_CASE14_OUT_WORDS,
|
||||
WU_BW_FP32_ONE_THIRTY_THREE, &bad_actual);
|
||||
if (bad != WU_CASE14_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x33u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case15_flash_softmax_pv_stage
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,8 @@
|
||||
# case15_flash_softmax_pv_stage
|
||||
|
||||
Validates the scalar softmax-stage handoff into the Blackwell PV GEMM.
|
||||
|
||||
The tensor warp initializes TMEM C with a score/O value of `1.0`. Scalar warp 0
|
||||
reads that value through the scalar TMEM load path, writes a full fp16 `P=1.0`
|
||||
tile into TMEM A, and releases the tensor warp. The tensor warp then runs PV
|
||||
against `V=2.0`; every fp32 output is expected to be `65.0` (`1.0 + 32 * 1.0 * 2.0`: the initial C value plus a 32-term dot product).
|
||||
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
145
kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp
Normal file
@@ -0,0 +1,145 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE15_INIT_BASE 0x9300u
|
||||
#define WU_CASE15_DONE_BASE 0x9400u
|
||||
#define WU_CASE15_P_READY 0x9500u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case15_score_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE};
|
||||
volatile uint32_t g_case15_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case15_scalar_seen[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case15_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case15_score_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"2:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 2b\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case15_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"4: j 4b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE15_INIT_BASE),
|
||||
[done_base] "i"(WU_CASE15_DONE_BASE),
|
||||
[p_ready] "i"(WU_CASE15_P_READY)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case15_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case15_scalar_seen[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case15_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE15_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x41u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case15_scalar_seen[tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
g_case15_scalar_seen[0] != WU_BW_FP32_ONE) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case15_scalar_seen[0];
|
||||
g_case_mem[1] = 0x42u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
g_case_mem[0] = WU_CASE15_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE15_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x43u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case15_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_SIXTY_FIVE, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x44u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case16_flash_full_pipeline
|
||||
|
||||
include ../case.mk
|
||||
14
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
14
kernels/wu_arch_cases/case16_flash_full_pipeline/README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# case16_flash_full_pipeline
|
||||
|
||||
Validates a compact end-to-end FlashAttention-style pipeline on the current Wu
|
||||
Blackwell path.
|
||||
|
||||
The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`,
|
||||
and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads
|
||||
the score through scalar TMEM load, records it, writes the uniform softmax result
|
||||
`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp.
|
||||
The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`.
|
||||
|
||||
Every output word is expected to be fp32 `2.0` (`32 * (1/32) * 2.0`). This case covers the staged
|
||||
`QK -> scalar softmax handoff -> PV` loop without using the legacy
|
||||
`kernels/flash_attention` implementation.
|
||||
180
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
180
kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp
Normal file
@@ -0,0 +1,180 @@
|
||||
#include "../common_wu_blackwell_fa.h"
|
||||
|
||||
#define WU_CASE16_INIT_BASE 0x9600u
|
||||
#define WU_CASE16_P_READY 0x9700u
|
||||
#define WU_CASE16_DONE_BASE 0x9800u
|
||||
|
||||
#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u
|
||||
#define WU_BW_FP32_ZERO 0x00000000u
|
||||
#define WU_BW_FP32_TWO 0x40000000u
|
||||
#define WU_BW_FP32_THIRTY_TWO 0x42000000u
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED,
|
||||
WU_BW_FP16_ONE_PACKED};
|
||||
volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = {
|
||||
WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO};
|
||||
volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16)));
|
||||
volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, %[c_offset]\n\t"
|
||||
"la x3, g_case16_q_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"1:\n\t"
|
||||
"add x4, x1, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
"la x3, g_case16_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"2:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 2b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[init_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
"3:\n\t"
|
||||
"la x6, g_case_mem\n\t"
|
||||
"lw x7, 0(x6)\n\t"
|
||||
"li x4, %[p_ready]\n\t"
|
||||
"bne x7, x4, 3b\n\t"
|
||||
"la x3, g_case16_zero_row\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"4:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 4b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"li x4, %[smem_base]\n\t"
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"la x3, g_case16_out\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"5:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x6, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x6\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 5b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
"li x6, %[done_base]\n\t"
|
||||
"or x6, x6, x5\n\t"
|
||||
"sw x6, 0(x7)\n\t"
|
||||
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||
"6: j 6b\n\t"
|
||||
:
|
||||
: [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET),
|
||||
[tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES),
|
||||
[smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR),
|
||||
[init_base] "i"(WU_CASE16_INIT_BASE),
|
||||
[p_ready] "i"(WU_CASE16_P_READY),
|
||||
[done_base] "i"(WU_CASE16_DONE_BASE)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS;
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) {
|
||||
g_case16_out[i] = 0;
|
||||
}
|
||||
for (uint32_t i = 0; i < 4; ++i) {
|
||||
g_case16_scalar_seen[i] = 0;
|
||||
}
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_ONE_PACKED);
|
||||
vx_spawn_tensor(tensor_mask, tensor_case16_worker);
|
||||
if (wu_wait_seen_mask(tensor_mask, WU_CASE16_INIT_BASE) != 0) {
|
||||
g_case_mem[1] = 0x61u;
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const uint32_t c_frag =
|
||||
wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag);
|
||||
g_case16_scalar_seen[tid] = observed;
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0 && g_case_mem[1] == 0 &&
|
||||
g_case16_scalar_seen[0] != WU_BW_FP32_THIRTY_TWO) {
|
||||
g_aux[0] = 0;
|
||||
g_aux[1] = g_case16_scalar_seen[0];
|
||||
g_case_mem[1] = 0x62u;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_case_mem[1] == 0) {
|
||||
vx_tmc(wu_bw_all_lanes_mask());
|
||||
wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0),
|
||||
WU_BW_FP16_ONE_OVER_32_PACKED);
|
||||
vx_tmc_one();
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
if (g_case_mem[1] == 0) {
|
||||
wu_bw_fill_smem_tile(
|
||||
reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR),
|
||||
WU_BW_FP16_TWO_PACKED);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
g_case_mem[0] = WU_CASE16_P_READY;
|
||||
if (g_case_mem[1] == 0 &&
|
||||
wu_wait_seen_mask(tensor_mask, WU_CASE16_DONE_BASE) != 0) {
|
||||
g_case_mem[1] = 0x63u;
|
||||
}
|
||||
if (g_case_mem[1] == 0) {
|
||||
volatile uint32_t bad_actual = 0;
|
||||
const uint32_t bad =
|
||||
wu_bw_verify_constant(g_case16_out, WU_BW_OUT_WORDS,
|
||||
WU_BW_FP32_TWO, &bad_actual);
|
||||
if (bad != WU_BW_OUT_WORDS) {
|
||||
g_aux[0] = bad;
|
||||
g_aux[1] = bad_actual;
|
||||
g_case_mem[1] = 0x64u;
|
||||
}
|
||||
}
|
||||
if (g_case_mem[1] != 0) {
|
||||
wu_case_fail(g_case_mem[1]);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
PROJECT = case17_flash_exp_softmax_probe
|
||||
|
||||
include ../case.mk
|
||||
@@ -0,0 +1,16 @@
|
||||
# case17_flash_exp_softmax_probe
|
||||
|
||||
Validates that the current scalar Wu path can execute the `e^x` work needed by
|
||||
non-uniform FlashAttention softmax. The current ISA/RTL does not expose a
|
||||
dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to
|
||||
approximate exp with a fifth-order Taylor polynomial around zero.
|
||||
|
||||
The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically
|
||||
stable softmax computes `exp(score - row_max)`, so the exp inputs are
|
||||
`{-ln(2), 0}` and the probabilities should be close to `{1/3, 2/3}`. The
|
||||
inputs are loaded from volatile memory so the compiler cannot fold the result
|
||||
into constants.
|
||||
|
||||
This is intentionally separate from the tensor PV path. If this case fails, the
|
||||
problem is in scalar fp32 arithmetic, exp approximation, or normalization rather
|
||||
than TMEM handoff or BWGMMA.
|
||||
@@ -0,0 +1,81 @@
|
||||
#include "../common_wu_min.h"
|
||||
|
||||
extern "C" {
// Input scores as raw fp32 bit patterns. They are volatile so the loads in
// wu_main cannot be folded into compile-time constants.
volatile uint32_t g_case17_scores_bits[2] __attribute__((aligned(16))) = {
    0x00000000u,  // 0.0
    0x3f317218u,  // ln(2)
};

// Softmax probabilities written by lane 0, as raw fp32 bit patterns.
volatile uint32_t g_case17_out_bits[2] __attribute__((aligned(16)));
}
|
||||
|
||||
// Reinterprets a 32-bit pattern as an fp32 value.
//
// The bytes are copied instead of being read back through a union: this file
// is compiled as C++, where reading a union member other than the one last
// written is undefined behaviour. The builtin is used so no extra header is
// needed; a fixed 4-byte copy like this is normally expanded inline.
//
// bits: raw IEEE-754 binary32 encoding.
// Returns the float with exactly that encoding.
static inline float wu_case17_bits_to_f32(uint32_t bits) {
  float value;
  __builtin_memcpy(&value, &bits, sizeof(value));
  return value;
}
|
||||
|
||||
// Returns the raw 32-bit encoding of an fp32 value.
//
// The bytes are copied instead of being read back through a union: this file
// is compiled as C++, where reading a union member other than the one last
// written is undefined behaviour. The builtin is used so no extra header is
// needed; a fixed 4-byte copy like this is normally expanded inline.
//
// value: float to inspect.
// Returns its IEEE-754 binary32 bit pattern.
static inline uint32_t wu_case17_f32_to_bits(float value) {
  uint32_t bits;
  __builtin_memcpy(&bits, &value, sizeof(bits));
  return bits;
}
|
||||
|
||||
// Magnitude of `value`, computed with a compare and a negate so no math
// library call is needed.
static inline float wu_case17_absf(float value) {
  if (value < 0.0f) {
    return -value;
  }
  return value;
}
|
||||
|
||||
// Fifth-order Taylor approximation of exp(x) around zero:
//   1 + x + x^2/2 + x^3/6 + x^4/24 + x^5/120
// The name records the input range it is used for in this case, [-ln(2), 0].
// Terms are accumulated in ascending order of degree.
static inline float wu_case17_exp_neg_ln2_to_0(float x) {
  const float sq = x * x;
  const float cube = sq * x;
  const float quad = sq * sq;
  const float quint = quad * x;

  float acc = 1.0f + x;
  acc = acc + (0.5f * sq);
  acc = acc + (0.1666666716f * cube);
  acc = acc + (0.0416666679f * quad);
  acc = acc + (0.0083333338f * quint);
  return acc;
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
const uint32_t tid = wu_tid();
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
g_case17_out_bits[0] = 0;
|
||||
g_case17_out_bits[1] = 0;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]);
|
||||
const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]);
|
||||
const float row_max = score0 > score1 ? score0 : score1;
|
||||
const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max);
|
||||
const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max);
|
||||
const float inv_sum = 1.0f / (e0 + e1);
|
||||
const float p0 = e0 * inv_sum;
|
||||
const float p1 = e1 * inv_sum;
|
||||
|
||||
if (tid == 0) {
|
||||
g_case17_out_bits[0] = wu_case17_f32_to_bits(p0);
|
||||
g_case17_out_bits[1] = wu_case17_f32_to_bits(p1);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
const float expected0 = 0.3333333433f;
|
||||
const float expected1 = 0.6666666865f;
|
||||
const float tolerance = 0.0015f;
|
||||
const float err0 = wu_case17_absf(p0 - expected0);
|
||||
const float err1 = wu_case17_absf(p1 - expected1);
|
||||
if (err0 > tolerance || err1 > tolerance) {
|
||||
g_aux[0] = g_case17_out_bits[0];
|
||||
g_aux[1] = g_case17_out_bits[1];
|
||||
wu_case_fail(0x71u);
|
||||
return 1;
|
||||
}
|
||||
wu_case_pass();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
101
kernels/wu_arch_cases/common_wu_blackwell_fa.h
Normal file
101
kernels/wu_arch_cases/common_wu_blackwell_fa.h
Normal file
@@ -0,0 +1,101 @@
|
||||
#ifndef WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
|
||||
#define WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H
|
||||
|
||||
#include "common_wu_min.h"
|
||||
|
||||
#define WU_BW_DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_BW_TMEM_TILE_BYTES 1024u
|
||||
#define WU_BW_TMEM_FRAGMENT_BYTES 16u
|
||||
#define WU_BW_TMEM_WARP_BYTES 2048u
|
||||
#define WU_BW_TMEM_C_BYTE_OFFSET 1024u
|
||||
#define WU_BW_TMEM_FRAGMENTS \
|
||||
(WU_BW_TMEM_TILE_BYTES / WU_BW_TMEM_FRAGMENT_BYTES)
|
||||
#define WU_BW_TMEM_WORDS (WU_BW_TMEM_TILE_BYTES / sizeof(uint32_t))
|
||||
#define WU_BW_OUT_WORDS WU_BW_TMEM_WORDS
|
||||
|
||||
#define WU_BW_FP16_ONE_PACKED 0x3c003c00u
|
||||
#define WU_BW_FP16_TWO_PACKED 0x40004000u
|
||||
#define WU_BW_FP32_ONE 0x3f800000u
|
||||
#define WU_BW_FP32_THREE 0x40400000u
|
||||
#define WU_BW_FP32_FIVE 0x40a00000u
|
||||
#define WU_BW_FP32_SIXTY_FIVE 0x42820000u
|
||||
#define WU_BW_FP32_SIXTY_SEVEN 0x42860000u
|
||||
#define WU_BW_FP32_ONE_THIRTY_THREE 0x43050000u
|
||||
|
||||
static_assert(NUM_SCALAR_WARPS == 2,
|
||||
"Wu Blackwell FA staged cases expect two scalar warps");
|
||||
static_assert(NUM_TENSOR_WARPS == 2,
|
||||
"Wu Blackwell FA staged cases expect two tensor warps");
|
||||
static_assert(NUM_THREADS == 4,
|
||||
"Wu Blackwell FA staged cases expect the 4-lane tensor core");
|
||||
|
||||
static inline uint32_t wu_bw_all_lanes_mask() {
|
||||
return (1u << NUM_THREADS) - 1u;
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_scalar_tmem_ld(uint32_t frag_addr) {
|
||||
uint32_t value;
|
||||
asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0"
|
||||
: [value] "=r"(value)
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr)
|
||||
: "memory");
|
||||
return value;
|
||||
}
|
||||
|
||||
static inline void wu_bw_scalar_tmem_st(uint32_t frag_addr, uint32_t value) {
|
||||
asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]"
|
||||
:
|
||||
: [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr),
|
||||
[value] "r"(value)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_tmem_a_byte_base(uint32_t tensor_block) {
|
||||
return tensor_block * WU_BW_TMEM_WARP_BYTES;
|
||||
}
|
||||
|
||||
static inline uint32_t wu_bw_tmem_c_byte_base(uint32_t tensor_block) {
|
||||
return wu_bw_tmem_a_byte_base(tensor_block) + WU_BW_TMEM_C_BYTE_OFFSET;
|
||||
}
|
||||
|
||||
static inline void wu_bw_fill_tmem_tile(uint32_t byte_base, uint32_t value) {
|
||||
const uint32_t first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) {
|
||||
// Caller must run this with all TMEM fragment lanes active.
|
||||
wu_bw_scalar_tmem_st(first_frag + frag, value);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
static inline void wu_bw_store_tmem_fragment_lane_values(uint32_t byte_base,
|
||||
uint32_t frag,
|
||||
uint32_t lane_value) {
|
||||
const uint32_t first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES;
|
||||
wu_bw_scalar_tmem_st(first_frag + frag, lane_value);
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
}
|
||||
|
||||
// Writes `value` to the first WU_BW_TMEM_WORDS 32-bit words at `smem`, in
// ascending address order, then issues a full read/write fence.
//
// smem:  start of the tile to fill.
// value: 32-bit word written to every element.
static inline void wu_bw_fill_smem_tile(volatile uint32_t *smem,
                                        uint32_t value) {
  volatile uint32_t *cursor = smem;
  volatile uint32_t *const tile_end = smem + WU_BW_TMEM_WORDS;
  while (cursor != tile_end) {
    *cursor = value;
    ++cursor;
  }
  asm volatile("fence rw, rw" ::: "memory");
}
|
||||
|
||||
// Scans `data[0..words)` for the first word that differs from `expected`.
//
// data:       buffer to check; each element is read at most once, in order.
// words:      number of 32-bit words to check.
// expected:   value every word must hold.
// bad_actual: receives the first mismatching value, or 0 when all words match.
//
// Returns the index of the first mismatch, or `words` when there is none.
static inline uint32_t wu_bw_verify_constant(const volatile uint32_t *data,
                                             uint32_t words,
                                             uint32_t expected,
                                             volatile uint32_t *bad_actual) {
  uint32_t index = 0;
  while (index < words) {
    const uint32_t seen = data[index];
    if (seen != expected) {
      *bad_actual = seen;
      return index;
    }
    ++index;
  }
  *bad_actual = 0;
  return words;
}
|
||||
|
||||
#endif
|
||||
@@ -1,9 +1,15 @@
|
||||
# wu_arch_hgemm
|
||||
|
||||
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration
|
||||
with the 4-lane Blackwell tensor-core path.
|
||||
Two-tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp
|
||||
configuration with the 4-lane Blackwell tensor-core path.
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
|
||||
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
|
||||
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
|
||||
instruction sequence using 16-byte fragments and then stop themselves.
|
||||
Scalar warp 0 initializes the shared-memory B operand for a 32x16x32 GEMM,
|
||||
spawns only the tensor warp mask, waits for tensor warps
|
||||
`NUM_SCALAR_WARPS..NUM_WARPS-1`, verifies the combined 32x16 fp32 output, and
|
||||
reports completion through `tohost`.
|
||||
|
||||
Each tensor warp computes one 16x16x32 BWGMMA M-block using 16-byte fragments:
|
||||
warp `NUM_SCALAR_WARPS` produces rows 0..15 and warp `NUM_SCALAR_WARPS + 1`
|
||||
produces rows 16..31. Both tensor warps share the same B tile, copy their C
|
||||
tiles back into one contiguous output matrix, mark their block IDs, and then
|
||||
stop themselves.
|
||||
|
||||
@@ -2,6 +2,24 @@
|
||||
|
||||
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||
#define WU_CASE_TENSOR_HGEMM_BASE 0x7500u
|
||||
#define WU_CASE_TENSOR_HGEMM_BLOCK_BASE 0x7600u
|
||||
|
||||
#define WU_HGEMM_M_BLOCKS 2u
|
||||
#define WU_HGEMM_M_PER_BLOCK 16u
|
||||
#define WU_HGEMM_N 16u
|
||||
#define WU_HGEMM_K 32u
|
||||
#define WU_HGEMM_TILE_BYTES 1024u
|
||||
#define WU_HGEMM_FRAGMENT_BYTES 16u
|
||||
#define WU_HGEMM_B_WORDS (WU_HGEMM_TILE_BYTES / sizeof(uint32_t))
|
||||
#define WU_HGEMM_OUT_WORDS \
|
||||
(WU_HGEMM_M_BLOCKS * WU_HGEMM_M_PER_BLOCK * WU_HGEMM_N)
|
||||
#define WU_HGEMM_EXPECTED 0x42820000u
|
||||
#define WU_HGEMM_NO_FAIL WU_HGEMM_OUT_WORDS
|
||||
|
||||
static_assert(NUM_TENSOR_WARPS == WU_HGEMM_M_BLOCKS,
|
||||
"wu_arch_hgemm expects two tensor warps");
|
||||
static_assert((NUM_THREADS * 2u) <= WU_CASE_MAX_WARPS,
|
||||
"wu_arch_hgemm uses g_aux[2 * tid + {0,1}] for check scratch");
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
@@ -13,6 +31,9 @@ volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
volatile uint32_t g_hgemm_control_fail __attribute__((aligned(16)));
|
||||
volatile uint32_t g_hgemm_return_code __attribute__((aligned(16)));
|
||||
volatile uint32_t g_hgemm_out[WU_HGEMM_OUT_WORDS] __attribute__((aligned(16)));
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
@@ -21,7 +42,8 @@ volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
asm volatile(
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"slli x1, x5, 11\n\t"
|
||||
"addi x1, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x1, x1, 11\n\t"
|
||||
"addi x2, x1, 1024\n\t"
|
||||
"la x6, g_hgemm_a_row\n\t"
|
||||
"la x3, g_hgemm_c_row\n\t"
|
||||
@@ -39,6 +61,25 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||
"csrr x5, %[csr_wid]\n\t"
|
||||
"addi x6, x5, -%[num_scalar_warps]\n\t"
|
||||
"slli x7, x6, 10\n\t"
|
||||
"la x3, g_hgemm_out\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, 0\n\t"
|
||||
"3:\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
"add x1, x3, x7\n\t"
|
||||
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, %[tile_bytes]\n\t"
|
||||
"blt x7, x4, 3b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
"slli x7, x6, 2\n\t"
|
||||
"la x3, g_case_mem\n\t"
|
||||
"add x3, x3, x7\n\t"
|
||||
"li x7, %[hgemm_block_base]\n\t"
|
||||
"or x7, x7, x6\n\t"
|
||||
"sw x7, 0(x3)\n\t"
|
||||
"slli x6, x5, 2\n\t"
|
||||
"la x7, g_seen\n\t"
|
||||
"add x7, x7, x6\n\t"
|
||||
@@ -52,34 +93,91 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
[custom0] "i"(RISCV_CUSTOM0),
|
||||
[custom3] "i"(RISCV_CUSTOM3),
|
||||
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE)
|
||||
[hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE),
|
||||
[hgemm_block_base] "i"(WU_CASE_TENSOR_HGEMM_BLOCK_BASE),
|
||||
[num_scalar_warps] "i"(NUM_SCALAR_WARPS),
|
||||
[tile_bytes] "i"(WU_HGEMM_TILE_BYTES)
|
||||
: "memory");
|
||||
}
|
||||
|
||||
extern "C" int wu_main() {
|
||||
if (!wu_is_leader()) {
|
||||
if (vx_core_id() != 0 || vx_warp_id() != 0) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
wu_case_reset();
|
||||
const uint32_t tid = wu_tid();
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < 64u; ++frag) {
|
||||
const uint32_t row = frag * 4u;
|
||||
for (uint32_t i = 0; i < 4u; ++i) {
|
||||
smem_b[row + i] = g_hgemm_b_row[i];
|
||||
if (tid == 0) {
|
||||
wu_case_reset();
|
||||
g_hgemm_control_fail = 0;
|
||||
g_hgemm_return_code = 0;
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (tid == 0) {
|
||||
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
|
||||
g_hgemm_out[i] = 0;
|
||||
}
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t i = 0; i < WU_HGEMM_B_WORDS; ++i) {
|
||||
smem_b[i] = g_hgemm_b_row[i & 3u];
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
|
||||
if (tid == 0) {
|
||||
vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker);
|
||||
|
||||
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
|
||||
WU_CASE_TENSOR_HGEMM_BASE) != 0) {
|
||||
wu_case_fail(0x09u);
|
||||
return 1;
|
||||
if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS,
|
||||
WU_CASE_TENSOR_HGEMM_BASE) != 0) {
|
||||
g_hgemm_control_fail = 0x09u;
|
||||
}
|
||||
|
||||
if (g_hgemm_control_fail == 0) {
|
||||
for (uint32_t block = 0; block < WU_HGEMM_M_BLOCKS; ++block) {
|
||||
if (g_case_mem[block] != (WU_CASE_TENSOR_HGEMM_BLOCK_BASE | block)) {
|
||||
g_hgemm_control_fail = 0x0au + block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
|
||||
if (g_hgemm_control_fail != 0) {
|
||||
if (tid == 0) {
|
||||
g_hgemm_return_code = 1;
|
||||
wu_case_fail(g_hgemm_control_fail);
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
return static_cast<int>(g_hgemm_return_code);
|
||||
}
|
||||
|
||||
wu_case_pass();
|
||||
return 0;
|
||||
if (tid == 0) {
|
||||
uint32_t bad_index = WU_HGEMM_NO_FAIL;
|
||||
uint32_t bad_actual = 0;
|
||||
|
||||
for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) {
|
||||
const uint32_t actual = g_hgemm_out[i];
|
||||
if (actual != WU_HGEMM_EXPECTED) {
|
||||
bad_index = i;
|
||||
bad_actual = actual;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (bad_index != WU_HGEMM_NO_FAIL) {
|
||||
g_aux[0] = bad_index;
|
||||
g_aux[1] = bad_actual;
|
||||
g_hgemm_return_code = 1;
|
||||
wu_case_fail(0x20u);
|
||||
} else {
|
||||
g_hgemm_return_code = 0;
|
||||
wu_case_pass();
|
||||
}
|
||||
}
|
||||
asm volatile("fence rw, rw" ::: "memory");
|
||||
return static_cast<int>(g_hgemm_return_code);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user