From d6fbd447c30a91174f16de8db4bc873e4f1d68d6 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Wed, 24 Jun 2026 06:26:30 +0000 Subject: [PATCH] Add Wu TMEM FlashAttention validation cases --- kernels/blackwell_multi_tc/Makefile | 7 + kernels/blackwell_multi_tc/kernel.cpp | 188 ++++++++++++++++ kernels/common.mk | 7 +- kernels/hgemm_validation/Makefile | 9 + kernels/hgemm_validation/README.md | 14 ++ kernels/hgemm_validation/kernel.cpp | 119 +++++++++++ kernels/wu_arch_cases/Makefile | 14 +- kernels/wu_arch_cases/README.md | 18 ++ kernels/wu_arch_cases/TMC_DEBUG_NOTES.md | 174 +++++++++++++++ kernels/wu_arch_cases/case.mk | 2 + .../case09_scalar_tmem_ldst/Makefile | 3 + .../case09_scalar_tmem_ldst/README.md | 7 + .../case09_scalar_tmem_ldst/kernel.cpp | 39 ++++ .../Makefile | 3 + .../README.md | 8 + .../kernel.cpp | 111 ++++++++++ .../case11_scalar_tmem_softmax_stage/Makefile | 3 + .../README.md | 10 + .../kernel.cpp | 200 ++++++++++++++++++ .../case12_1_scalar_tmem_cb_probe/Makefile | 3 + .../case12_1_scalar_tmem_cb_probe/README.md | 10 + .../case12_1_scalar_tmem_cb_probe/kernel.cpp | 108 ++++++++++ .../case12_2_flash_pv_p_probe/Makefile | 3 + .../case12_2_flash_pv_p_probe/README.md | 11 + .../case12_2_flash_pv_p_probe/kernel.cpp | 109 ++++++++++ .../case12_3_scalar_tmem_lane_store/Makefile | 3 + .../case12_3_scalar_tmem_lane_store/README.md | 14 ++ .../kernel.cpp | 91 ++++++++ .../case12_flash_pv_accum/Makefile | 3 + .../case12_flash_pv_accum/README.md | 7 + .../case12_flash_pv_accum/kernel.cpp | 127 +++++++++++ .../case13_flash_pv_two_warps/Makefile | 3 + .../case13_flash_pv_two_warps/README.md | 8 + .../case13_flash_pv_two_warps/kernel.cpp | 135 ++++++++++++ .../case14_flash_pv_k64/Makefile | 3 + .../case14_flash_pv_k64/README.md | 7 + .../case14_flash_pv_k64/kernel.cpp | 136 ++++++++++++ .../case15_flash_softmax_pv_stage/Makefile | 3 + .../case15_flash_softmax_pv_stage/README.md | 8 + .../case15_flash_softmax_pv_stage/kernel.cpp | 145 +++++++++++++ 
.../case16_flash_full_pipeline/Makefile | 3 + .../case16_flash_full_pipeline/README.md | 14 ++ .../case16_flash_full_pipeline/kernel.cpp | 180 ++++++++++++++++ .../case17_flash_exp_softmax_probe/Makefile | 3 + .../case17_flash_exp_softmax_probe/README.md | 16 ++ .../case17_flash_exp_softmax_probe/kernel.cpp | 81 +++++++ .../wu_arch_cases/common_wu_blackwell_fa.h | 101 +++++++++ kernels/wu_arch_hgemm/README.md | 18 +- kernels/wu_arch_hgemm/kernel.cpp | 132 ++++++++++-- 49 files changed, 2395 insertions(+), 26 deletions(-) create mode 100644 kernels/blackwell_multi_tc/Makefile create mode 100644 kernels/blackwell_multi_tc/kernel.cpp create mode 100644 kernels/hgemm_validation/Makefile create mode 100644 kernels/hgemm_validation/README.md create mode 100644 kernels/hgemm_validation/kernel.cpp create mode 100644 kernels/wu_arch_cases/TMC_DEBUG_NOTES.md create mode 100644 kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile create mode 100644 kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md create mode 100644 kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp create mode 100644 kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/Makefile create mode 100644 kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/README.md create mode 100644 kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/kernel.cpp create mode 100644 kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/Makefile create mode 100644 kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/README.md create mode 100644 kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/kernel.cpp create mode 100644 kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/Makefile create mode 100644 kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/README.md create mode 100644 kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp create mode 100644 kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile create mode 100644 kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md create 
mode 100644 kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp create mode 100644 kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/Makefile create mode 100644 kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/README.md create mode 100644 kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.cpp create mode 100644 kernels/wu_arch_cases/case12_flash_pv_accum/Makefile create mode 100644 kernels/wu_arch_cases/case12_flash_pv_accum/README.md create mode 100644 kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp create mode 100644 kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile create mode 100644 kernels/wu_arch_cases/case13_flash_pv_two_warps/README.md create mode 100644 kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp create mode 100644 kernels/wu_arch_cases/case14_flash_pv_k64/Makefile create mode 100644 kernels/wu_arch_cases/case14_flash_pv_k64/README.md create mode 100644 kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp create mode 100644 kernels/wu_arch_cases/case15_flash_softmax_pv_stage/Makefile create mode 100644 kernels/wu_arch_cases/case15_flash_softmax_pv_stage/README.md create mode 100644 kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp create mode 100644 kernels/wu_arch_cases/case16_flash_full_pipeline/Makefile create mode 100644 kernels/wu_arch_cases/case16_flash_full_pipeline/README.md create mode 100644 kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp create mode 100644 kernels/wu_arch_cases/case17_flash_exp_softmax_probe/Makefile create mode 100644 kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md create mode 100644 kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp create mode 100644 kernels/wu_arch_cases/common_wu_blackwell_fa.h diff --git a/kernels/blackwell_multi_tc/Makefile b/kernels/blackwell_multi_tc/Makefile new file mode 100644 index 00000000..f37ec73f --- /dev/null +++ b/kernels/blackwell_multi_tc/Makefile @@ -0,0 +1,7 @@ +PROJECT = 
blackwell_multi_tc + +VX_SRCS = kernel.cpp + +OPTS ?= -n1 + +include ../common.mk diff --git a/kernels/blackwell_multi_tc/kernel.cpp b/kernels/blackwell_multi_tc/kernel.cpp new file mode 100644 index 00000000..59bdb7bc --- /dev/null +++ b/kernels/blackwell_multi_tc/kernel.cpp @@ -0,0 +1,188 @@ +#include +#include +#include + +#define RISCV_CUSTOM3 0x7B +#define DEV_SMEM_START_ADDR 0xff000000u +#define MAX_CORES 4 +#define MAX_WARPS 8 + +#define BW_REP2(x) x, x +#define BW_REP4(x) BW_REP2(x), BW_REP2(x) +#define BW_REP8(x) BW_REP4(x), BW_REP4(x) + +extern "C" { +volatile uint32_t g_a_row[8] __attribute__((aligned(32))) = { + BW_REP8(0x3c003c00u)}; // two fp16 1.0 values per word +volatile uint32_t g_b_row[8] __attribute__((aligned(32))) = { + BW_REP8(0x40004000u)}; // two fp16 2.0 values per word +volatile uint32_t g_c_row[8] __attribute__((aligned(32))) = { + BW_REP8(0x3f800000u)}; // one fp32 1.0 value per word +volatile uint32_t g_result[MAX_CORES * MAX_WARPS] __attribute__((aligned(32))); +volatile uint32_t g_status[MAX_CORES * MAX_WARPS] __attribute__((aligned(32))); +} + +#undef BW_REP2 +#undef BW_REP4 +#undef BW_REP8 + +static inline void tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) { + asm volatile(".insn r %0, 2, 0, x0, %1, %2" + : + : "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem) + : "memory"); +} + +static inline void tcgen05_cp_wait() { + asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3) + : "memory"); +} + +static inline void bwgmma(uint32_t addr_tmem_c, + uint32_t addr_tmem_a, + uint32_t addr_smem_b) { + asm volatile(".insn r %0, 0, 0, %1, %2, %3" + : + : "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a), + "r"(addr_smem_b) + : "memory"); +} + +static inline void bwgmma_wait() { + asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3) + : "memory"); +} + +static inline float tcgen05_ld_f32(uint32_t addr_tmem) { + float value; + asm volatile(".insn r %1, 4, 0, %0, %2, x0" + : "=f"(value) + : "i"(RISCV_CUSTOM3), 
"r"(addr_tmem) + : "memory"); + return value; +} + +static inline uint32_t f32_bits(float value) { + union { + float f; + uint32_t u; + } bits = {value}; + return bits.u; +} + +extern "C" void vx_perf_dump() {} + +extern "C" void __attribute__((naked, noinline, used)) tensor_kernel_entry() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "slli x1, x5, 11\n\t" // tmem_a = wid * 0x800 + "addi x2, x1, 1024\n\t" // tmem_c = tmem_a + 0x400 + "la x6, g_a_row\n\t" + "la x3, g_c_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x6\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 32\n\t" + "li x4, 1024\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x5, x5, 2\n\t" + "la x6, g_status\n\t" + "add x6, x6, x5\n\t" + "li x7, 0x600d\n\t" + "sw x7, 0(x6)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "2: j 2b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR) + : "memory"); +} + +struct kernel_arg_t { + volatile uint32_t *a_row; + volatile uint32_t *b_row; + volatile uint32_t *c_row; + volatile uint32_t *result; + volatile uint32_t *status; +}; + +static void __attribute__((convergent)) kernel_body(int task_id, + kernel_arg_t *__UNIFORM__ arg) { + const int warp_id = task_id / vx_num_threads(); + const int lane_id = task_id % vx_num_threads(); + if (lane_id != 0) + return; + + const int num_warps = vx_num_warps(); + if (warp_id >= num_warps) + return; + + volatile uint32_t *a_row = arg->a_row; + volatile uint32_t *c_row = arg->c_row; + volatile uint32_t *result = arg->result; + volatile uint32_t *status = arg->status; + const int core_id = vx_core_id(); + const int status_idx = core_id * MAX_WARPS + warp_id; + + const 
uint32_t smem_b = DEV_SMEM_START_ADDR; + const uint32_t base = static_cast(warp_id) * 0x800u; + const uint32_t tmem_a = base + 0x000u; + const uint32_t tmem_c = base + 0x400u; + + status[status_idx] = 0x100u; + + for (int frag = 0; frag < 32; ++frag) { + const uint32_t offset = static_cast(frag * 32); + tcgen05_cp(tmem_a + offset, reinterpret_cast(a_row)); + tcgen05_cp(tmem_c + offset, reinterpret_cast(c_row)); + } + tcgen05_cp_wait(); + + status[status_idx] = 0x200u; + bwgmma(tmem_c, tmem_a, smem_b); + bwgmma_wait(); + + result[status_idx] = f32_bits(tcgen05_ld_f32(tmem_c)); + status[status_idx] = (result[status_idx] == 0x42820000u) ? 0x600du : 0xe000u; +} + +int main() { + const int core_id = vx_core_id(); + if (core_id != 0 || vx_warp_id() != 0 || vx_thread_id() != 0) + return 0; + + volatile uint32_t *smem_b_ptr = + reinterpret_cast(DEV_SMEM_START_ADDR); + for (int frag = 0; frag < 32; ++frag) { + const int row = frag * 8; + for (int i = 0; i < 8; ++i) + smem_b_ptr[row + i] = g_b_row[i]; + } + + for (int i = 0; i < MAX_WARPS; ++i) + g_status[core_id * MAX_WARPS + i] = 0; + + vx_spawn_tensor(vx_tensor_warp_mask(), tensor_kernel_entry); + + for (int spin = 0; spin < 100000; ++spin) { + int done = 1; + for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) + done &= (g_status[core_id * MAX_WARPS + i] == 0x600du); + if (done) + break; + } + + for (int i = NUM_SCALAR_WARPS; i < vx_num_warps(); ++i) { + if (g_status[core_id * MAX_WARPS + i] != 0x600du) + return 1; + } + return 0; +} diff --git a/kernels/common.mk b/kernels/common.mk index 63b9b8f6..54cabf01 100644 --- a/kernels/common.mk +++ b/kernels/common.mk @@ -1,3 +1,6 @@ +# Get the directory where this common.mk file is located +COMMON_MK_DIR := $(dir $(lastword $(MAKEFILE_LIST))) + XLEN ?= 32 TOOLDIR ?= /opt @@ -7,7 +10,7 @@ RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv64-gnu-toolchain VX_CFLAGS += -march=rv64imafd -mabi=lp64d STARTUP_ADDR ?= 0x180000000 else -RISCV_TOOLCHAIN_PATH ?= $(TOOLDIR)/riscv-gnu-toolchain 
+RISCV_TOOLCHAIN_PATH ?= $(realpath $(COMMON_MK_DIR)../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain) VX_CFLAGS += -march=rv32imaf -mabi=ilp32f STARTUP_ADDR ?= 0x80000000 endif @@ -18,7 +21,7 @@ RISCV_SYSROOT ?= $(RISCV_TOOLCHAIN_PATH)/$(RISCV_PREFIX) VORTEX_KN_PATH ?= $(realpath ../../lib) GEMMINI_SW_PATH ?= $(realpath ../../lib/gemmini) -LLVM_VORTEX ?= $(TOOLDIR)/llvm-vortex +LLVM_VORTEX ?= $(realpath $(COMMON_MK_DIR)../../toolchain/llvm-r8) LLVM_CFLAGS += --sysroot=$(RISCV_SYSROOT) LLVM_CFLAGS += --gcc-toolchain=$(RISCV_TOOLCHAIN_PATH) diff --git a/kernels/hgemm_validation/Makefile b/kernels/hgemm_validation/Makefile new file mode 100644 index 00000000..f1aff7e2 --- /dev/null +++ b/kernels/hgemm_validation/Makefile @@ -0,0 +1,9 @@ +PROJECT = hgemm_validation + +VX_SRCS = kernel.cpp +OPTS ?= -n1 + +include ../common.mk + +args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin + cp $< $@ diff --git a/kernels/hgemm_validation/README.md b/kernels/hgemm_validation/README.md new file mode 100644 index 00000000..2a902da0 --- /dev/null +++ b/kernels/hgemm_validation/README.md @@ -0,0 +1,14 @@ +# hgemm_validation + +Small HGEMM correctness validation for the 4-lane Blackwell tensor-core path. + +The test runs one tensor warp on a single 16x16x32 HGEMM tile: + +- A is fp16 1.0 +- B is fp16 2.0 +- C starts as fp32 1.0 +- expected output is fp32 65.0 (1.0 + 32 x 1.0 x 2.0, bit pattern `0x42820000`) for all 16x16 C elements + +Scalar warp 0 initializes the shared-memory B tile, spawns only the first tensor +warp, waits for completion, and checks all 256 output words copied back from +TMEM. Success reports `WU_CASE_PASS` through `tohost`. 
diff --git a/kernels/hgemm_validation/kernel.cpp b/kernels/hgemm_validation/kernel.cpp new file mode 100644 index 00000000..b77466aa --- /dev/null +++ b/kernels/hgemm_validation/kernel.cpp @@ -0,0 +1,119 @@ +#include "../wu_arch_cases/common_wu_min.h" + +#define DEV_SMEM_START_ADDR 0xff000000u +#define HGEMM_VALIDATION_BASE 0x7700u + +#define HGEMM_M 16u +#define HGEMM_N 16u +#define HGEMM_K 32u +#define HGEMM_TILE_BYTES 1024u +#define HGEMM_FRAGMENT_BYTES 16u +#define HGEMM_FRAGMENTS (HGEMM_TILE_BYTES / HGEMM_FRAGMENT_BYTES) +#define HGEMM_OUT_WORDS (HGEMM_M * HGEMM_N) +#define HGEMM_EXPECTED 0x42820000u + +#define BW_REP2(x) x, x +#define BW_REP4(x) BW_REP2(x), BW_REP2(x) + +extern "C" { +volatile uint32_t g_hgemm_a_frag[4] __attribute__((aligned(16))) = { + BW_REP4(0x3c003c00u)}; +volatile uint32_t g_hgemm_b_frag[4] __attribute__((aligned(16))) = { + BW_REP4(0x40004000u)}; +volatile uint32_t g_hgemm_c_frag[4] __attribute__((aligned(16))) = { + BW_REP4(0x3f800000u)}; +volatile uint32_t g_hgemm_out[HGEMM_OUT_WORDS] __attribute__((aligned(16))); +} + +#undef BW_REP2 +#undef BW_REP4 + +extern "C" void __attribute__((naked, noinline, used)) hgemm_validation_worker() { + asm volatile( + "li x1, 0\n\t" + "li x2, %[tile_bytes]\n\t" + "la x6, g_hgemm_a_frag\n\t" + "la x3, g_hgemm_c_frag\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x6\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, %[frag_bytes]\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_hgemm_out\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x2, x7\n\t" + "add x1, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x1\n\t" + "addi x7, x7, %[frag_bytes]\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" 
+ "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "3: j 3b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), + [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [smem_base] "i"(DEV_SMEM_START_ADDR), + [done_base] "i"(HGEMM_VALIDATION_BASE), + [tile_bytes] "i"(HGEMM_TILE_BYTES), + [frag_bytes] "i"(HGEMM_FRAGMENT_BYTES) + : "memory"); +} + +extern "C" int wu_main() { + if (!wu_is_leader()) { + return 0; + } + + wu_case_reset(); + + for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) { + g_hgemm_out[i] = 0; + } + + volatile uint32_t *smem_b = + reinterpret_cast(DEV_SMEM_START_ADDR); + for (uint32_t frag = 0; frag < HGEMM_FRAGMENTS; ++frag) { + const uint32_t row = frag * 4u; + for (uint32_t i = 0; i < 4u; ++i) { + smem_b[row + i] = g_hgemm_b_frag[i]; + } + } + + const uint32_t tensor_wid = NUM_SCALAR_WARPS; + vx_spawn_tensor(1u << tensor_wid, hgemm_validation_worker); + + if (wu_wait_seen_mask(1u << tensor_wid, HGEMM_VALIDATION_BASE) != 0) { + wu_case_fail(0x09u); + return 1; + } + + for (uint32_t i = 0; i < HGEMM_OUT_WORDS; ++i) { + if (g_hgemm_out[i] != HGEMM_EXPECTED) { + g_aux[0] = i; + g_aux[1] = g_hgemm_out[i]; + wu_case_fail(0x20u); + return 1; + } + } + + wu_case_pass(); + return 0; +} diff --git a/kernels/wu_arch_cases/Makefile b/kernels/wu_arch_cases/Makefile index df297fd0..946c33b4 100644 --- a/kernels/wu_arch_cases/Makefile +++ b/kernels/wu_arch_cases/Makefile @@ -7,7 +7,19 @@ CASES := \ case05_tensor_barrier \ case06_masked_barrier \ case07_tensor_csr_tmc \ - case08_tensor_lsu_optional + case08_tensor_lsu_optional \ + case09_scalar_tmem_ldst \ + case10_tensor_scalar_tmem_handoff \ + case11_scalar_tmem_softmax_stage \ + case12_flash_pv_accum \ + case12_1_scalar_tmem_cb_probe \ + case12_2_flash_pv_p_probe \ + case12_3_scalar_tmem_lane_store \ + case13_flash_pv_two_warps \ + 
case14_flash_pv_k64 \ + case15_flash_softmax_pv_stage \ + case16_flash_full_pipeline \ + case17_flash_exp_softmax_probe SMOKE_CASES := \ case00_boot_scalar \ diff --git a/kernels/wu_arch_cases/README.md b/kernels/wu_arch_cases/README.md index bfc106ca..3b182bd7 100644 --- a/kernels/wu_arch_cases/README.md +++ b/kernels/wu_arch_cases/README.md @@ -13,9 +13,27 @@ This directory contains small bare-metal kernels for incremental Wu architecture - `case06_masked_barrier`: explicit mixed `BAR_MASK` with scalar warp 0 and tensor warps. - `case07_tensor_csr_tmc`: tensor CSR/TMC path without barrier behavior. - `case08_tensor_lsu_optional`: tensor LSU store/load marker path; keep last because memory interaction is broader and slower. +- `case09_scalar_tmem_ldst`: scalar warp direct TMEM store/load path for the banked TMEM softmax mechanism. +- `case10_tensor_scalar_tmem_handoff`: tensor BWGMMA result in TMEM C observed by scalar TMEM loads. +- `case11_scalar_tmem_softmax_stage`: scalar TMEM transform written back for tensor-side copy-out. +- `case12_flash_pv_accum`: one tensor warp consumes scalar-written `P` in TMEM A for `O = O + P @ V`. +- `case12_1_scalar_tmem_cb_probe`: scalar-written TMEM A rows copied back by tensor `tcgen05_cb`. +- `case12_2_flash_pv_p_probe`: case12 P-write diagnostic using the same scalar fill path. +- `case12_3_scalar_tmem_lane_store`: scalar TMEM store lane-coalesced fragment write diagnostic. +- `case13_flash_pv_two_warps`: both tensor warps consume scalar-written `P` tiles for two row blocks. +- `case14_flash_pv_k64`: two consecutive `K=32` BWGMMA steps accumulate into one PV output. +- `case15_flash_softmax_pv_stage`: scalar reads TMEM C, writes softmax-like `P`, and tensor consumes it in PV. +- `case16_flash_full_pipeline`: compact `QK -> scalar softmax handoff -> PV` end-to-end FlashAttention-style pipeline. +- `case17_flash_exp_softmax_probe`: scalar non-uniform `e^x` softmax probe for generalized FlashAttention. 
Each case has its own `README.md` describing the test objective, RTL surface, and expected pass marker. +## Debug Notes + +- `TMC_DEBUG_NOTES.md`: reusable notes for diagnosing lane-mask/TMC operand + bugs, including the `case12_3_scalar_tmem_lane_store` failure where a + lane0-only register value was consumed under an all-lane mask. + ## Build Use the suite Makefile from this directory: diff --git a/kernels/wu_arch_cases/TMC_DEBUG_NOTES.md b/kernels/wu_arch_cases/TMC_DEBUG_NOTES.md new file mode 100644 index 00000000..e654ba65 --- /dev/null +++ b/kernels/wu_arch_cases/TMC_DEBUG_NOTES.md @@ -0,0 +1,174 @@ +# TMC Operand Debug Notes + +This note records a failure pattern found while debugging +`case12_3_scalar_tmem_lane_store`. + +## Symptom + +The simulator keeps producing trace output, but the kernel makes no forward +progress: + +- a tensor worker repeatedly polls a ready flag such as `g_case_mem[0]`; +- the poll load always returns the old value; +- the scalar warp has already committed the work before the ready flag store; +- the expected ready flag store never appears in the LSU trace. + +For `case12_3_scalar_tmem_lane_store`, the tensor worker was looping at +`PC=0x80000034..0x80000048`, repeatedly reading `g_case_mem[0]` at +`0x20000430` as `0`. The scalar warp committed the scalar TMEM store and the +following `TMC`, but never issued the next ready flag store at `PC=0x8000015c`. + +## Root Cause Pattern + +Vortex scalar registers are lane-local. A value computed while only lane 0 is +active is only known to be valid in lane 0. If that register is later consumed +by an instruction while multiple lanes are active, the inactive lanes may hold +stale or zero values. + +This is especially easy to miss around `TMC`: + +```asm +# Bad when xN was defined while only lane 0 was active. +vx_tmc all_lanes +... +vx_tmc xN +``` + +The failed `case12_3` sequence used a C operand for `1u` that the compiler +materialized before switching to all lanes. 
Runtime trace showed the source +register for the final `TMC(1)` as: + +```text +rs1_data={0x0, 0x0, 0x0, 0x1} +``` + +After that `TMC`, the scalar warp did not fetch the ready flag store. + +## Correct Pattern + +When a value is consumed under an all-lane mask, define that value under the +same all-lane mask unless the instruction semantics explicitly use only lane 0. + +For switching back to lane 0 from an all-lane region, keep the immediate +materialization and the `TMC` adjacent inside the all-lane region: + +```asm +vx_tmc all_lanes +... +fence rw, rw +li t2, 1 +vx_tmc t2 +``` + +The expected trace at the final `TMC` is: + +```text +rs1_data={0x1, 0x1, 0x1, 0x1} +``` + +The library helper `vx_tmc_one()` follows this pattern because it emits +`li a0, 1` and `vx_tmc a0` in the same volatile asm block. It is safe when +called while all lanes are active. + +## Fast Log Checks + +Use the simulation log to distinguish a simulator stall from a kernel-level +wait loop: + +```sh +stat -c '%s %y' chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log +tail -n 80 chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log +``` + +If the file is still growing but the tail repeatedly shows the same PC range, +the simulator is alive and the kernel is probably spinning. + +For a ready flag handoff, check both the polling load and the producer store: + +```sh +rg -n "0x20000430|0x9900|wid=0, PC=0x8000014|wid=0, PC=0x8000015" \ + chipyard/sims/verilator/output/chipyard.harness.TestHarness.VirgoBlackwellConfig/kernel.radiance.log +``` + +Interpretation: + +- load repeatedly returns `0` and no later `0x9900` store exists: producer did + not reach the ready flag store; +- ready flag store exists but consumer still reads old data: investigate memory + ordering or address aliasing; +- producer stops immediately after a `TMC`: inspect that `TMC` source register + across all lanes. 
+ +## Dump Checks + +Check that the final lane-mask narrowing defines `1` immediately before `TMC` +inside the all-lane region: + +```sh +rg -n "li\\s+.*1|vx_tmc" \ + kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.radiance.dump +``` + +The fixed `case12_3` dump has: + +```text +li t2, 1 +vx_tmc t2 +auipc a3, 1 +sw a5, ... +``` + +## Case12-Case15 Audit + +The following cases were checked after fixing `case12_3`: + +- `case12_flash_pv_accum` +- `case12_2_flash_pv_p_probe` +- `case13_flash_pv_two_warps` +- `case14_flash_pv_k64` +- `case15_flash_softmax_pv_stage` + +They switch to all lanes with `vx_tmc(wu_bw_all_lanes_mask())` and switch back +with `vx_tmc_one()`. Their dumps show the safe adjacent sequence: + +```text +li a0, 1 +vx_tmc a0 +``` + +No analogous source change is needed for those cases. + +## Scalar TMEM Fill Lane Coverage + +`wu_bw_fill_tmem_tile()` must run with all fragment lanes active. Scalar TMEM +store writes the active lane's source word into the matching TMEM fragment word: + +```text +TMEM[addr].word[lane] = rs2_data[lane] +``` + +Therefore a fill loop executed with `tmask=0001` only initializes word 0 of +each 16-byte fragment. Word 1 and later can retain old TMEM data. In +`case12_1_scalar_tmem_cb_probe`, this showed up as a normal completion path +followed by verification failure `0x14`; `g_aux[0]` was `1`, and `g_aux[1]` +held the stale copied-back word. + +The correct pattern is: + +```c++ +vx_tmc(wu_bw_all_lanes_mask()); +wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); +vx_tmc_one(); +``` + +The dump should show the fill loop bracketed by all-lane and lane-0 masks: + +```text +li a5, 15 +vx_tmc a5 +... +vx_cmov zero, a0, a5, a2 +... +li a0, 1 +vx_tmc a0 +``` diff --git a/kernels/wu_arch_cases/case.mk b/kernels/wu_arch_cases/case.mk index 7035ee50..ba897d25 100644 --- a/kernels/wu_arch_cases/case.mk +++ b/kernels/wu_arch_cases/case.mk @@ -3,6 +3,8 @@ VX_SRCS = kernel.cpp VX_CFLAGS += -I.. 
VORTEX_KN_PATH ?= $(realpath ../../../lib) GEMMINI_SW_PATH ?= $(realpath ../../../lib/gemmini) +LLVM_VORTEX ?= $(realpath ../../../../toolchain/llvm-r8) +RISCV_TOOLCHAIN_PATH ?= $(realpath ../../../../toolchain/vortex-toolchain-prebuilt/riscv-gnu-toolchain) OPTS ?= -n1 include ../../common.mk diff --git a/kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile new file mode 100644 index 00000000..c7e9b3d6 --- /dev/null +++ b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/Makefile @@ -0,0 +1,3 @@ +PROJECT = wu_arch_case09_scalar_tmem_ldst + +include ../case.mk diff --git a/kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md new file mode 100644 index 00000000..7f806a7e --- /dev/null +++ b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/README.md @@ -0,0 +1,7 @@ +# case09_scalar_tmem_ldst + +Smoke test for scalar warp direct TMEM access. + +The leader scalar lane writes a lane-tagged word (`0x5a090000u | wu_tid()`) to TMEM address operand `0x24` with +`CUSTOM1 func7=0x30 func3=1`, reads it back with `CUSTOM1 func7=0x30 func3=0`, +and reports pass only if the value the leader lane reads back matches what it wrote. 
diff --git a/kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp new file mode 100644 index 00000000..5265a99e --- /dev/null +++ b/kernels/wu_arch_cases/case09_scalar_tmem_ldst/kernel.cpp @@ -0,0 +1,39 @@ +#include "common_wu_min.h" + +static inline uint32_t wu_scalar_tmem_ld(uint32_t addr) { + uint32_t value; + asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0" + : [value] "=r"(value) + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr) + : "memory"); + return value; +} + +static inline void wu_scalar_tmem_st(uint32_t addr, uint32_t value) { + asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]" + : + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(addr), [value] "r"(value) + : "memory"); +} + +extern "C" int wu_main() { + if (!wu_is_leader()) { + return 0; + } + + wu_case_reset(); + + const uint32_t addr = 0x24u; + const uint32_t expected = 0x5a090000u | wu_tid(); + wu_scalar_tmem_st(addr, expected); + const uint32_t observed = wu_scalar_tmem_ld(addr); + + if (observed != expected) { + g_aux[0] = observed; + wu_case_fail(0x09u); + return 1; + } + + wu_case_pass(); + return 0; +} diff --git a/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/Makefile b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/Makefile new file mode 100644 index 00000000..1c8918ae --- /dev/null +++ b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/Makefile @@ -0,0 +1,3 @@ +PROJECT = wu_arch_case10_tensor_scalar_tmem_handoff + +include ../case.mk diff --git a/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/README.md b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/README.md new file mode 100644 index 00000000..6a0815a9 --- /dev/null +++ b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/README.md @@ -0,0 +1,8 @@ +# case10_tensor_scalar_tmem_handoff + +Validates the tensor-to-scalar TMEM handoff needed by FlashAttention. 
+ +The tensor warp initializes TMEM A/C, runs one BWGMMA, and leaves the result in +TMEM C. The scalar leader waits for the tensor warp and then reads several C +fragments through the scalar TMEM load instruction. Each lane is expected to +observe the HGEMM value `0x42820000`. diff --git a/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/kernel.cpp b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/kernel.cpp new file mode 100644 index 00000000..d5393807 --- /dev/null +++ b/kernels/wu_arch_cases/case10_tensor_scalar_tmem_handoff/kernel.cpp @@ -0,0 +1,111 @@ +#include "../common_wu_min.h" + +#define DEV_SMEM_START_ADDR 0xff000000u +#define WU_CASE_TMEM_HANDOFF_BASE 0x7700u +#define WU_TMEM_TILE_BYTES 1024u +#define WU_TMEM_FRAGMENT_BYTES 16u +#define WU_TMEM_C_BYTE_BASE 1024u +#define WU_TMEM_EXPECTED_HGEMM 0x42820000u + +static_assert(NUM_TENSOR_WARPS >= 1, "case10 requires at least one tensor warp"); +static_assert(NUM_THREADS == 4, "case10 expects the 4-lane Blackwell tensor core"); + +#define BW_REP2(x) x, x +#define BW_REP4(x) BW_REP2(x), BW_REP2(x) + +extern "C" { +volatile uint32_t g_case10_a_row[4] __attribute__((aligned(16))) = { + BW_REP4(0x3c003c00u)}; +volatile uint32_t g_case10_b_row[4] __attribute__((aligned(16))) = { + BW_REP4(0x40004000u)}; +volatile uint32_t g_case10_c_row[4] __attribute__((aligned(16))) = { + BW_REP4(0x3f800000u)}; +} + +#undef BW_REP2 +#undef BW_REP4 + +static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) { + uint32_t value; + asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0" + : [value] "=r"(value) + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr) + : "memory"); + return value; +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case10_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, 1024\n\t" + "la x6, g_case10_a_row\n\t" + "la x3, g_case10_c_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + 
"add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x6\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[handoff_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "2: j 2b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), + [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [smem_base] "i"(DEV_SMEM_START_ADDR), + [handoff_base] "i"(WU_CASE_TMEM_HANDOFF_BASE), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [tile_bytes] "i"(WU_TMEM_TILE_BYTES) + : "memory"); +} + +extern "C" int wu_main() { + if (!wu_is_leader()) { + return 0; + } + + wu_case_reset(); + + volatile uint32_t *smem_b = + reinterpret_cast(DEV_SMEM_START_ADDR); + for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) { + smem_b[i] = g_case10_b_row[i & 3u]; + } + asm volatile("fence rw, rw" ::: "memory"); + + vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case10_worker); + + if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE_TMEM_HANDOFF_BASE) != 0) { + wu_case_fail(0x10u); + return 1; + } + + const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES; + for (uint32_t i = 0; i < 8; ++i) { + const uint32_t observed = wu_scalar_tmem_ld(c_frag_base + i); + if (observed != WU_TMEM_EXPECTED_HGEMM) { + g_aux[0] = i; + g_aux[1] = observed; + wu_case_fail(0x11u); + return 1; + } + } + + wu_case_pass(); + return 0; +} diff --git a/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/Makefile b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/Makefile new file mode 100644 index 00000000..c973cc15 --- /dev/null +++ b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/Makefile 
@@ -0,0 +1,3 @@ +PROJECT = wu_arch_case11_scalar_tmem_softmax_stage + +include ../case.mk diff --git a/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/README.md b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/README.md new file mode 100644 index 00000000..bceece07 --- /dev/null +++ b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/README.md @@ -0,0 +1,10 @@ +# case11_scalar_tmem_softmax_stage + +Validates the FlashAttention softmax-stage TMEM path. + +The tensor warp writes HGEMM results into TMEM C. The scalar warp reads TMEM C, +performs a deterministic lane-wise transform, and writes all four lanes back +into one TMEM A-region fragment. The tensor warp then uses TCGEN05_CB to copy +that TMEM fragment to global memory so scalar code can verify that tensor-side +TMEM reads observe the scalar writeback with the correct lane ordering and +write mask. diff --git a/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/kernel.cpp b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/kernel.cpp new file mode 100644 index 00000000..bdc594d2 --- /dev/null +++ b/kernels/wu_arch_cases/case11_scalar_tmem_softmax_stage/kernel.cpp @@ -0,0 +1,200 @@ +#include "../common_wu_min.h" + +#define DEV_SMEM_START_ADDR 0xff000000u +#define WU_CASE_TMEM_STAGE_HGEMM_BASE 0x7800u +#define WU_CASE_TMEM_STAGE_VERIFY_BASE 0x7900u +#define WU_TMEM_TILE_BYTES 1024u +#define WU_TMEM_FRAGMENT_BYTES 16u +#define WU_TMEM_C_BYTE_BASE 1024u +#define WU_TMEM_SCALAR_WRITE_BYTE_ADDR 512u +#define WU_TMEM_EXPECTED_HGEMM 0x42820000u +#define WU_TMEM_STAGE_VALUE_BASE 0x42820001u + +static_assert(NUM_TENSOR_WARPS >= 1, "case11 requires at least one tensor warp"); +static_assert(NUM_THREADS == 4, "case11 expects the 4-lane Blackwell tensor core"); + +#define BW_REP2(x) x, x +#define BW_REP4(x) BW_REP2(x), BW_REP2(x) + +extern "C" { +volatile uint32_t g_case11_a_row[4] __attribute__((aligned(16))) = { + BW_REP4(0x3c003c00u)}; +volatile uint32_t g_case11_b_row[4] 
__attribute__((aligned(16))) = { + BW_REP4(0x40004000u)}; +volatile uint32_t g_case11_c_row[4] __attribute__((aligned(16))) = { + BW_REP4(0x3f800000u)}; +volatile uint32_t g_case11_out[4] __attribute__((aligned(16))); +volatile uint32_t g_case11_scalar_seen[4] __attribute__((aligned(16))); +} + +#undef BW_REP2 +#undef BW_REP4 + +static inline uint32_t wu_scalar_tmem_ld(uint32_t frag_addr) { + uint32_t value; + asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0" + : [value] "=r"(value) + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr) + : "memory"); + return value; +} + +static inline void wu_scalar_tmem_st(uint32_t frag_addr, uint32_t value) { + asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]" + : + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr), + [value] "r"(value) + : "memory"); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case11_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, 1024\n\t" + "la x6, g_case11_a_row\n\t" + "la x3, g_case11_c_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x6\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[hgemm_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "3:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[verify_base]\n\t" + "bne x7, x4, 3b\n\t" + "addi x4, x1, %[write_byte_addr]\n\t" + "la x6, g_case11_out\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, 
g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[verify_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "2: j 2b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), + [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [smem_base] "i"(DEV_SMEM_START_ADDR), + [hgemm_base] "i"(WU_CASE_TMEM_STAGE_HGEMM_BASE), + [verify_base] "i"(WU_CASE_TMEM_STAGE_VERIFY_BASE), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [tile_bytes] "i"(WU_TMEM_TILE_BYTES), + [write_byte_addr] "i"(WU_TMEM_SCALAR_WRITE_BYTE_ADDR) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < 4; ++i) { + g_case11_out[i] = 0; + g_case11_scalar_seen[i] = 0; + } + + volatile uint32_t *smem_b = + reinterpret_cast(DEV_SMEM_START_ADDR); + for (uint32_t i = 0; i < WU_TMEM_TILE_BYTES / sizeof(uint32_t); ++i) { + smem_b[i] = g_case11_b_row[i & 3u]; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case11_worker); + + if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, + WU_CASE_TMEM_STAGE_HGEMM_BASE) != 0) { + g_case_mem[1] = 0x20u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + if (g_case_mem[1] != 0) { + if (tid == 0) { + wu_case_fail(g_case_mem[1]); + } + return 1; + } + + const uint32_t c_frag_base = WU_TMEM_C_BYTE_BASE / WU_TMEM_FRAGMENT_BYTES; + const uint32_t observed = wu_scalar_tmem_ld(c_frag_base); + g_case11_scalar_seen[tid] = observed; + + const uint32_t write_frag = + WU_TMEM_SCALAR_WRITE_BYTE_ADDR / WU_TMEM_FRAGMENT_BYTES; + wu_scalar_tmem_st(write_frag, observed + 1u + tid); + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + if (g_case11_scalar_seen[0] != WU_TMEM_EXPECTED_HGEMM) { + g_aux[0] = 0; + g_aux[1] = g_case11_scalar_seen[0]; + g_case_mem[1] = 0x21u; + } + } + asm volatile("fence rw, rw" ::: 
"memory"); + if (g_case_mem[1] != 0) { + if (tid == 0) { + wu_case_fail(g_case_mem[1]); + } + return 1; + } + + if (tid == 0) { + g_case_mem[0] = WU_CASE_TMEM_STAGE_VERIFY_BASE; + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, + WU_CASE_TMEM_STAGE_VERIFY_BASE) != 0) { + g_case_mem[1] = 0x22u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + if (g_case_mem[1] != 0) { + if (tid == 0) { + wu_case_fail(g_case_mem[1]); + } + return 1; + } + + if (tid == 0) { + if (g_case11_out[0] != WU_TMEM_STAGE_VALUE_BASE) { + g_aux[0] = 0; + g_aux[1] = g_case11_out[0]; + wu_case_fail(0x23u); + return 1; + } + + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/Makefile b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/Makefile new file mode 100644 index 00000000..86d50507 --- /dev/null +++ b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/Makefile @@ -0,0 +1,3 @@ +PROJECT = case12_1_scalar_tmem_cb_probe + +include ../case.mk diff --git a/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/README.md b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/README.md new file mode 100644 index 00000000..e65b60cc --- /dev/null +++ b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/README.md @@ -0,0 +1,10 @@ +# case12.1 scalar TMEM CB probe + +Validates the scalar-to-tensor TMEM handoff without BWGMMA. Scalar warp 0 fills +tensor block 0 TMEM A rows `0..63` with packed fp16 `1.0`. Tensor warp +`NUM_SCALAR_WARPS` waits for that fill, then uses `tcgen05_cb` to copy the same +TMEM A rows back to global memory. + +The scalar leader verifies all 256 copied words are `0x3c003c00`. A mismatch +means scalar TMEM stores are not visible to the tensor TMEM read/copy path at +the expected row addresses. 
diff --git a/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp new file mode 100644 index 00000000..aa02a54d --- /dev/null +++ b/kernels/wu_arch_cases/case12_1_scalar_tmem_cb_probe/kernel.cpp @@ -0,0 +1,108 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE12_1_INIT_BASE 0x8d00u +#define WU_CASE12_1_DONE_BASE 0x8e00u +#define WU_CASE12_1_COPY_READY 0x8f00u + +extern "C" { +volatile uint32_t g_case12_1_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case12_1_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "1:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[copy_ready]\n\t" + "bne x7, x4, 1b\n\t" + "la x3, g_case12_1_out\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x1, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "3: j 3b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [init_base] "i"(WU_CASE12_1_INIT_BASE), + [done_base] "i"(WU_CASE12_1_DONE_BASE), + [copy_ready] "i"(WU_CASE12_1_COPY_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + for 
(uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case12_1_out[i] = 0; + } + vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_1_worker); + if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_INIT_BASE) != + 0) { + g_case_mem[1] = 0x12u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE12_1_COPY_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_1_DONE_BASE) != + 0) { + g_case_mem[1] = 0x13u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case12_1_out, WU_BW_OUT_WORDS, + WU_BW_FP16_ONE_PACKED, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x14u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile new file mode 100644 index 00000000..0c4ccda8 --- /dev/null +++ b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/Makefile @@ -0,0 +1,3 @@ +PROJECT = case12_2_flash_pv_p_probe + +include ../case.mk diff --git a/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md new file mode 100644 index 00000000..70b361ae --- /dev/null +++ b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/README.md @@ -0,0 +1,11 @@ +# case12_2_flash_pv_p_probe + +Diagnostic for `case12_flash_pv_accum`. + +Scalar warp 0 fills tensor block 0 TMEM A with packed fp16 `1.0`, using the +same `wu_bw_fill_tmem_tile()` path as case12. 
Tensor warp `NUM_SCALAR_WARPS` +then copies the full TMEM A tile back to global memory with `tcgen05_cb`. + +The scalar leader verifies all copied words are `0x3c003c00`. A mismatch means +case12's scalar-written P tile is not reaching the tensor TMEM read/copy path as +expected. diff --git a/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp new file mode 100644 index 00000000..6eefdf1f --- /dev/null +++ b/kernels/wu_arch_cases/case12_2_flash_pv_p_probe/kernel.cpp @@ -0,0 +1,109 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE12_2_INIT_BASE 0x9600u +#define WU_CASE12_2_DONE_BASE 0x9700u +#define WU_CASE12_2_COPY_READY 0x9800u + +extern "C" { +volatile uint32_t g_case12_2_out[WU_BW_OUT_WORDS] + __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) +tensor_case12_2_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "1:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[copy_ready]\n\t" + "bne x7, x4, 1b\n\t" + "la x3, g_case12_2_out\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x1, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "3: j 3b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [init_base] "i"(WU_CASE12_2_INIT_BASE), + 
[done_base] "i"(WU_CASE12_2_DONE_BASE), + [copy_ready] "i"(WU_CASE12_2_COPY_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case12_2_out[i] = 0; + } + vx_spawn_tensor(tensor_mask, tensor_case12_2_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE12_2_INIT_BASE) != 0) { + g_case_mem[1] = 0x51u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE12_2_COPY_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE12_2_DONE_BASE) != 0) { + g_case_mem[1] = 0x52u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case12_2_out, WU_BW_OUT_WORDS, + WU_BW_FP16_ONE_PACKED, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x53u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/Makefile b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/Makefile new file mode 100644 index 00000000..13faf673 --- /dev/null +++ b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/Makefile @@ -0,0 +1,3 @@ +PROJECT = case12_3_scalar_tmem_lane_store + +include ../case.mk diff --git a/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/README.md b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/README.md new file mode 100644 index 00000000..6fbb24a4 --- /dev/null +++ 
b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/README.md @@ -0,0 +1,14 @@ +# case12_3_scalar_tmem_lane_store + +Validates scalar TMEM store lane-coalesced semantics. + +One scalar warp writes a single 16-byte TMEM fragment. Each active scalar lane supplies one 32-bit word: + +```text +word0 = lane0.rs2 +word1 = lane1.rs2 +word2 = lane2.rs2 +word3 = lane3.rs2 +``` + +The tensor copy-back path then copies that fragment to memory. This catches implementations that only write lane0 data or broadcast one scalar value. diff --git a/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.cpp b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.cpp new file mode 100644 index 00000000..f20bbcbe --- /dev/null +++ b/kernels/wu_arch_cases/case12_3_scalar_tmem_lane_store/kernel.cpp @@ -0,0 +1,91 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE12_3_COPY_READY 0x9900u +#define WU_CASE12_3_DONE_BASE 0x9a00u + +extern "C" { +volatile uint32_t g_case12_3_out[4] __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) +tensor_case12_3_copy_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "1:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[copy_ready]\n\t" + "bne x7, x4, 1b\n\t" + "la x6, g_case12_3_out\n\t" + ".insn r %[custom3], 6, 0, x0, x1, x6\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "2: j 2b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [copy_ready] "i"(WU_CASE12_3_COPY_READY), + [done_base] "i"(WU_CASE12_3_DONE_BASE) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || 
vx_warp_id() != 0) { + return 0; + } + + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + + wu_case_reset(); + for (uint32_t i = 0; i < 4; ++i) { + g_case12_3_out[i] = 0; + } + vx_spawn_tensor(tensor_mask, tensor_case12_3_copy_worker); + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t frag_addr = + wu_bw_tmem_a_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES; + asm volatile( + ".insn r %[custom0], 0, 0, x0, %[all_lanes], x0\n\t" + "csrr t0, %[csr_tid]\n\t" + "lui t1, 0x5a120\n\t" + "or t1, t1, t0\n\t" + ".insn r %[custom1], 1, 0x30, x0, %[addr], t1\n\t" + "fence rw, rw\n\t" + "li t2, 1\n\t" + ".insn r %[custom0], 0, 0, x0, t2, x0" + : + : [custom0] "i"(RISCV_CUSTOM0), [custom1] "i"(RISCV_CUSTOM1), + [csr_tid] "i"(VX_CSR_THREAD_ID), [addr] "r"(frag_addr), + [all_lanes] "r"(wu_bw_all_lanes_mask()) + : "t0", "t1", "t2", "memory"); + + g_case_mem[0] = WU_CASE12_3_COPY_READY; + if (wu_wait_seen_mask(tensor_mask, WU_CASE12_3_DONE_BASE) != 0) { + wu_case_fail(0x51u); + return 1; + } + + for (uint32_t lane = 0; lane < NUM_THREADS; ++lane) { + const uint32_t expected = 0x5a120000u | lane; + if (g_case12_3_out[lane] != expected) { + g_aux[0] = lane; + g_aux[1] = g_case12_3_out[lane]; + wu_case_fail(0x52u); + return 1; + } + } + + wu_case_pass(); + return 0; +} diff --git a/kernels/wu_arch_cases/case12_flash_pv_accum/Makefile b/kernels/wu_arch_cases/case12_flash_pv_accum/Makefile new file mode 100644 index 00000000..86d50507 --- /dev/null +++ b/kernels/wu_arch_cases/case12_flash_pv_accum/Makefile @@ -0,0 +1,3 @@ +PROJECT = case12_flash_pv_accum + +include ../case.mk diff --git a/kernels/wu_arch_cases/case12_flash_pv_accum/README.md b/kernels/wu_arch_cases/case12_flash_pv_accum/README.md new file mode 100644 index 00000000..1b17257b --- /dev/null +++ b/kernels/wu_arch_cases/case12_flash_pv_accum/README.md @@ -0,0 +1,7 @@ +# case12_flash_pv_accum + +Validates the first Wu Blackwell FlashAttention PV substage. 
+ +Scalar warp 0 writes a full `16x32` fp16 `P` tile into TMEM A. One tensor warp +uses BWGMMA with a `32x16` fp16 `V` tile in SMEM and a preloaded fp32 TMEM C +tile. The expected result is `O = 3.0 + P @ V = 67.0` for every output element. diff --git a/kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp b/kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp new file mode 100644 index 00000000..b6489ade --- /dev/null +++ b/kernels/wu_arch_cases/case12_flash_pv_accum/kernel.cpp @@ -0,0 +1,127 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE12_INIT_BASE 0x8a00u +#define WU_CASE12_DONE_BASE 0x8b00u +#define WU_CASE12_P_READY 0x8c00u + +extern "C" { +volatile uint32_t g_case12_o_row[4] __attribute__((aligned(16))) = { + WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE, + WU_BW_FP32_THREE}; +volatile uint32_t g_case12_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case12_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case12_o_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "2:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[p_ready]\n\t" + "bne x7, x4, 2b\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_case12_out\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, 
x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "4: j 4b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [init_base] "i"(WU_CASE12_INIT_BASE), + [done_base] "i"(WU_CASE12_DONE_BASE), + [p_ready] "i"(WU_CASE12_P_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case12_out[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_TWO_PACKED); + vx_spawn_tensor(1u << NUM_SCALAR_WARPS, tensor_case12_worker); + if (wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_INIT_BASE) != 0) { + g_case_mem[1] = 0x12u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE12_P_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(1u << NUM_SCALAR_WARPS, WU_CASE12_DONE_BASE) != 0) { + g_case_mem[1] = 0x13u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case12_out, WU_BW_OUT_WORDS, + WU_BW_FP32_SIXTY_SEVEN, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x14u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} 
diff --git a/kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile b/kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile new file mode 100644 index 00000000..fd75f1d4 --- /dev/null +++ b/kernels/wu_arch_cases/case13_flash_pv_two_warps/Makefile @@ -0,0 +1,3 @@ +PROJECT = case13_flash_pv_two_warps + +include ../case.mk diff --git a/kernels/wu_arch_cases/case13_flash_pv_two_warps/README.md b/kernels/wu_arch_cases/case13_flash_pv_two_warps/README.md new file mode 100644 index 00000000..179c5f0a --- /dev/null +++ b/kernels/wu_arch_cases/case13_flash_pv_two_warps/README.md @@ -0,0 +1,8 @@ +# case13_flash_pv_two_warps + +Validates the Wu Blackwell FlashAttention PV substage across both tensor warps. + +Scalar warp 0 writes one `P` tile into each tensor warp's TMEM A partition. +Both tensor warps consume the same SMEM `V` tile, accumulate into their own TMEM +C tiles, copy out contiguous row blocks, and stop. Every fp32 output is expected +to be `67.0`. diff --git a/kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp b/kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp new file mode 100644 index 00000000..1b61a089 --- /dev/null +++ b/kernels/wu_arch_cases/case13_flash_pv_two_warps/kernel.cpp @@ -0,0 +1,135 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE13_INIT_BASE 0x8d00u +#define WU_CASE13_DONE_BASE 0x8e00u +#define WU_CASE13_P_READY 0x8f00u +#define WU_CASE13_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS) + +extern "C" { +volatile uint32_t g_case13_o_row[4] __attribute__((aligned(16))) = { + WU_BW_FP32_THREE, WU_BW_FP32_THREE, WU_BW_FP32_THREE, + WU_BW_FP32_THREE}; +volatile uint32_t g_case13_out[WU_CASE13_OUT_WORDS] + __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case13_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case13_o_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + 
"add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "2:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[p_ready]\n\t" + "bne x7, x4, 2b\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "addi x6, x5, -%[num_scalar_warps]\n\t" + "slli x7, x6, 10\n\t" + "la x3, g_case13_out\n\t" + "add x3, x3, x7\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "4: j 4b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [init_base] "i"(WU_CASE13_INIT_BASE), + [done_base] "i"(WU_CASE13_DONE_BASE), + [p_ready] "i"(WU_CASE13_P_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = vx_tensor_warp_mask(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_CASE13_OUT_WORDS; ++i) { + g_case13_out[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_TWO_PACKED); + 
vx_spawn_tensor(tensor_mask, tensor_case13_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE13_INIT_BASE) != 0) { + g_case_mem[1] = 0x21u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE13_P_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE13_DONE_BASE) != 0) { + g_case_mem[1] = 0x22u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case13_out, WU_CASE13_OUT_WORDS, + WU_BW_FP32_SIXTY_SEVEN, &bad_actual); + if (bad != WU_CASE13_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x23u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case14_flash_pv_k64/Makefile b/kernels/wu_arch_cases/case14_flash_pv_k64/Makefile new file mode 100644 index 00000000..f72cda32 --- /dev/null +++ b/kernels/wu_arch_cases/case14_flash_pv_k64/Makefile @@ -0,0 +1,3 @@ +PROJECT = case14_flash_pv_k64 + +include ../case.mk diff --git a/kernels/wu_arch_cases/case14_flash_pv_k64/README.md b/kernels/wu_arch_cases/case14_flash_pv_k64/README.md new file mode 100644 index 00000000..06fbaf58 --- /dev/null +++ b/kernels/wu_arch_cases/case14_flash_pv_k64/README.md @@ -0,0 +1,7 @@ +# case14_flash_pv_k64 + +Validates PV accumulation over two Blackwell `K=32` BWGMMA steps. + +Both tensor warps run two consecutive BWGMMA operations against the same TMEM C +tile, modeling a `K=64` FlashAttention PV accumulation. With `P=1.0`, +`V=2.0`, and `O_init=5.0`, every fp32 output is expected to be `133.0`. 
diff --git a/kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp b/kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp new file mode 100644 index 00000000..1f225bce --- /dev/null +++ b/kernels/wu_arch_cases/case14_flash_pv_k64/kernel.cpp @@ -0,0 +1,136 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE14_INIT_BASE 0x9000u +#define WU_CASE14_DONE_BASE 0x9100u +#define WU_CASE14_P_READY 0x9200u +#define WU_CASE14_OUT_WORDS (NUM_TENSOR_WARPS * WU_BW_OUT_WORDS) + +extern "C" { +volatile uint32_t g_case14_o_row[4] __attribute__((aligned(16))) = { + WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE, WU_BW_FP32_FIVE}; +volatile uint32_t g_case14_out[WU_CASE14_OUT_WORDS] + __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case14_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case14_o_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "2:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[p_ready]\n\t" + "bne x7, x4, 2b\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "addi x6, x5, -%[num_scalar_warps]\n\t" + "slli x7, x6, 10\n\t" + "la x3, g_case14_out\n\t" + "add x3, x3, x7\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r 
%[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "4: j 4b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [init_base] "i"(WU_CASE14_INIT_BASE), + [done_base] "i"(WU_CASE14_DONE_BASE), + [p_ready] "i"(WU_CASE14_P_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = vx_tensor_warp_mask(); + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_CASE14_OUT_WORDS; ++i) { + g_case14_out[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_TWO_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case14_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE14_INIT_BASE) != 0) { + g_case_mem[1] = 0x31u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(1), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE14_P_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE14_DONE_BASE) != 0) { + g_case_mem[1] = 0x32u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case14_out, WU_CASE14_OUT_WORDS, + WU_BW_FP32_ONE_THIRTY_THREE, &bad_actual); + if (bad != WU_CASE14_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + 
g_case_mem[1] = 0x33u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/Makefile b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/Makefile new file mode 100644 index 00000000..437be381 --- /dev/null +++ b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/Makefile @@ -0,0 +1,3 @@ +PROJECT = case15_flash_softmax_pv_stage + +include ../case.mk diff --git a/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/README.md b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/README.md new file mode 100644 index 00000000..eb2c2752 --- /dev/null +++ b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/README.md @@ -0,0 +1,8 @@ +# case15_flash_softmax_pv_stage + +Validates the scalar softmax-stage handoff into the Blackwell PV GEMM. + +The tensor warp initializes TMEM C with a score/O value of `1.0`. Scalar warp 0 +reads that value through the scalar TMEM load path, writes a full fp16 `P=1.0` +tile into TMEM A, and releases the tensor warp. The tensor warp then runs PV +against `V=2.0`; every fp32 output is expected to be `65.0`. 
diff --git a/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp new file mode 100644 index 00000000..5bf0add2 --- /dev/null +++ b/kernels/wu_arch_cases/case15_flash_softmax_pv_stage/kernel.cpp @@ -0,0 +1,145 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE15_INIT_BASE 0x9300u +#define WU_CASE15_DONE_BASE 0x9400u +#define WU_CASE15_P_READY 0x9500u + +extern "C" { +volatile uint32_t g_case15_score_row[4] __attribute__((aligned(16))) = { + WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE, WU_BW_FP32_ONE}; +volatile uint32_t g_case15_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +volatile uint32_t g_case15_scalar_seen[4] __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case15_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case15_score_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + "2:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[p_ready]\n\t" + "bne x7, x4, 2b\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_case15_out\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, 
x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "4: j 4b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [init_base] "i"(WU_CASE15_INIT_BASE), + [done_base] "i"(WU_CASE15_DONE_BASE), + [p_ready] "i"(WU_CASE15_P_READY) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case15_out[i] = 0; + } + for (uint32_t i = 0; i < 4; ++i) { + g_case15_scalar_seen[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_TWO_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case15_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE15_INIT_BASE) != 0) { + g_case_mem[1] = 0x41u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t c_frag = + wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES; + const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag); + g_case15_scalar_seen[tid] = observed; + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0 && g_case_mem[1] == 0 && + g_case15_scalar_seen[0] != WU_BW_FP32_ONE) { + g_aux[0] = 0; + g_aux[1] = g_case15_scalar_seen[0]; + g_case_mem[1] = 0x42u; + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), WU_BW_FP16_ONE_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + g_case_mem[0] = WU_CASE15_P_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE15_DONE_BASE) != 0) { + g_case_mem[1] = 0x43u; + } + if
(g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case15_out, WU_BW_OUT_WORDS, + WU_BW_FP32_SIXTY_FIVE, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x44u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case16_flash_full_pipeline/Makefile b/kernels/wu_arch_cases/case16_flash_full_pipeline/Makefile new file mode 100644 index 00000000..a027580e --- /dev/null +++ b/kernels/wu_arch_cases/case16_flash_full_pipeline/Makefile @@ -0,0 +1,3 @@ +PROJECT = case16_flash_full_pipeline + +include ../case.mk diff --git a/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md b/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md new file mode 100644 index 00000000..58577a7e --- /dev/null +++ b/kernels/wu_arch_cases/case16_flash_full_pipeline/README.md @@ -0,0 +1,14 @@ +# case16_flash_full_pipeline + +Validates a compact end-to-end FlashAttention-style pipeline on the current Wu +Blackwell path. + +The tensor warp first computes `S = Q @ K` into TMEM C with `Q=1.0`, `K=1.0`, +and `O_init=0.0`, producing a constant fp32 score of `32.0`. Scalar warp 0 reads +the score through scalar TMEM load, records it, writes the uniform softmax result +`P=1/32` into TMEM A, refills SMEM with `V=2.0`, and releases the tensor warp. +The tensor warp reloads TMEM C with `O=0.0`, then computes `O = P @ V`. + +Every output word is expected to be fp32 `2.0`. This case covers the staged +`QK -> scalar softmax handoff -> PV` loop without using the legacy +`kernels/flash_attention` implementation. 
diff --git a/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp b/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp new file mode 100644 index 00000000..d81deb5a --- /dev/null +++ b/kernels/wu_arch_cases/case16_flash_full_pipeline/kernel.cpp @@ -0,0 +1,180 @@ +#include "../common_wu_blackwell_fa.h" + +#define WU_CASE16_INIT_BASE 0x9600u +#define WU_CASE16_P_READY 0x9700u +#define WU_CASE16_DONE_BASE 0x9800u + +#define WU_BW_FP16_ONE_OVER_32_PACKED 0x28002800u +#define WU_BW_FP32_ZERO 0x00000000u +#define WU_BW_FP32_TWO 0x40000000u +#define WU_BW_FP32_THIRTY_TWO 0x42000000u + +extern "C" { +volatile uint32_t g_case16_q_row[4] __attribute__((aligned(16))) = { + WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, WU_BW_FP16_ONE_PACKED, + WU_BW_FP16_ONE_PACKED}; +volatile uint32_t g_case16_zero_row[4] __attribute__((aligned(16))) = { + WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO, WU_BW_FP32_ZERO}; +volatile uint32_t g_case16_out[WU_BW_OUT_WORDS] __attribute__((aligned(16))); +volatile uint32_t g_case16_scalar_seen[4] __attribute__((aligned(16))); +} + +extern "C" void __attribute__((naked, noinline, used)) tensor_case16_worker() { + asm volatile( + "csrr x5, %[csr_wid]\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" + "addi x2, x1, %[c_offset]\n\t" + "la x3, g_case16_q_row\n\t" + "li x7, 0\n\t" + "1:\n\t" + "add x4, x1, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 1b\n\t" + "la x3, g_case16_zero_row\n\t" + "li x7, 0\n\t" + "2:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 2b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[init_base]\n\t" + "or x6, x6, x5\n\t" +
"sw x6, 0(x7)\n\t" + "3:\n\t" + "la x6, g_case_mem\n\t" + "lw x7, 0(x6)\n\t" + "li x4, %[p_ready]\n\t" + "bne x7, x4, 3b\n\t" + "la x3, g_case16_zero_row\n\t" + "li x7, 0\n\t" + "4:\n\t" + "add x4, x2, x7\n\t" + ".insn r %[custom3], 2, 0, x0, x4, x3\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 4b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "li x4, %[smem_base]\n\t" + ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" + ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" + "la x3, g_case16_out\n\t" + "li x7, 0\n\t" + "5:\n\t" + "add x4, x2, x7\n\t" + "add x6, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x6\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 5b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "csrr x5, %[csr_wid]\n\t" + "slli x6, x5, 2\n\t" + "la x7, g_seen\n\t" + "add x7, x7, x6\n\t" + "li x6, %[done_base]\n\t" + "or x6, x6, x5\n\t" + "sw x6, 0(x7)\n\t" + ".insn r %[custom0], 0, 0, x0, x0, x0\n\t" + "6: j 6b\n\t" + : + : [csr_wid] "i"(VX_CSR_WARP_ID), [custom0] "i"(RISCV_CUSTOM0), + [custom3] "i"(RISCV_CUSTOM3), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [c_offset] "i"(WU_BW_TMEM_C_BYTE_OFFSET), + [tile_bytes] "i"(WU_BW_TMEM_TILE_BYTES), + [smem_base] "i"(WU_BW_DEV_SMEM_START_ADDR), + [init_base] "i"(WU_CASE16_INIT_BASE), + [p_ready] "i"(WU_CASE16_P_READY), + [done_base] "i"(WU_CASE16_DONE_BASE) + : "memory"); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + const uint32_t tensor_mask = 1u << NUM_SCALAR_WARPS; + if (tid == 0) { + wu_case_reset(); + for (uint32_t i = 0; i < WU_BW_OUT_WORDS; ++i) { + g_case16_out[i] = 0; + } + for (uint32_t i = 0; i < 4; ++i) { + g_case16_scalar_seen[i] = 0; + } + wu_bw_fill_smem_tile( + reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_ONE_PACKED); + vx_spawn_tensor(tensor_mask, tensor_case16_worker); + if (wu_wait_seen_mask(tensor_mask, WU_CASE16_INIT_BASE) != 0) { +
g_case_mem[1] = 0x61u; + } + } + asm volatile("fence rw, rw" ::: "memory"); + + const uint32_t c_frag = + wu_bw_tmem_c_byte_base(0) / WU_BW_TMEM_FRAGMENT_BYTES; + const uint32_t observed = wu_bw_scalar_tmem_ld(c_frag); + g_case16_scalar_seen[tid] = observed; + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0 && g_case_mem[1] == 0 && + g_case16_scalar_seen[0] != WU_BW_FP32_THIRTY_TWO) { + g_aux[0] = 0; + g_aux[1] = g_case16_scalar_seen[0]; + g_case_mem[1] = 0x62u; + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_case_mem[1] == 0) { + vx_tmc(wu_bw_all_lanes_mask()); + wu_bw_fill_tmem_tile(wu_bw_tmem_a_byte_base(0), + WU_BW_FP16_ONE_OVER_32_PACKED); + vx_tmc_one(); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + if (g_case_mem[1] == 0) { + wu_bw_fill_smem_tile( + reinterpret_cast<volatile uint32_t *>(WU_BW_DEV_SMEM_START_ADDR), + WU_BW_FP16_TWO_PACKED); + } + asm volatile("fence rw, rw" ::: "memory"); + g_case_mem[0] = WU_CASE16_P_READY; + if (g_case_mem[1] == 0 && + wu_wait_seen_mask(tensor_mask, WU_CASE16_DONE_BASE) != 0) { + g_case_mem[1] = 0x63u; + } + if (g_case_mem[1] == 0) { + volatile uint32_t bad_actual = 0; + const uint32_t bad = + wu_bw_verify_constant(g_case16_out, WU_BW_OUT_WORDS, + WU_BW_FP32_TWO, &bad_actual); + if (bad != WU_BW_OUT_WORDS) { + g_aux[0] = bad; + g_aux[1] = bad_actual; + g_case_mem[1] = 0x64u; + } + } + if (g_case_mem[1] != 0) { + wu_case_fail(g_case_mem[1]); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/Makefile b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/Makefile new file mode 100644 index 00000000..f35fc2e8 --- /dev/null +++ b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/Makefile @@ -0,0 +1,3 @@ +PROJECT = case17_flash_exp_softmax_probe + +include ../case.mk diff --git a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md new file mode 100644 index
00000000..719c91d4 --- /dev/null +++ b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/README.md @@ -0,0 +1,16 @@ +# case17_flash_exp_softmax_probe + +Validates that the current scalar Wu path can execute the `e^x` work needed by +non-uniform FlashAttention softmax. The current ISA/RTL does not expose a +dedicated exp or exp2 instruction, so this case uses scalar fp32 arithmetic to +approximate exp. + +The case evaluates a two-element row with scores `{0, ln(2)}`. A numerically +stable softmax computes `exp(score - row_max)`, so the exp inputs are +`{-ln(2), 0}` and the probabilities should be close to `{1/3, 2/3}`. The +inputs are loaded from volatile memory so the compiler cannot fold the result +into constants. + +This is intentionally separate from the tensor PV path. If this case fails, the +problem is in scalar fp32 arithmetic, exp approximation, or normalization rather +than TMEM handoff or BWGMMA. diff --git a/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp new file mode 100644 index 00000000..c5cb8111 --- /dev/null +++ b/kernels/wu_arch_cases/case17_flash_exp_softmax_probe/kernel.cpp @@ -0,0 +1,81 @@ +#include "../common_wu_min.h" + +extern "C" { +volatile uint32_t g_case17_scores_bits[2] __attribute__((aligned(16))) = { + 0x00000000u, 0x3f317218u}; // 0.0, ln(2) +volatile uint32_t g_case17_out_bits[2] __attribute__((aligned(16))); +} + +static inline float wu_case17_bits_to_f32(uint32_t bits) { + union { + uint32_t u; + float f; + } v = {bits}; + return v.f; +} + +static inline uint32_t wu_case17_f32_to_bits(float value) { + union { + float f; + uint32_t u; + } v = {value}; + return v.u; +} + +static inline float wu_case17_absf(float value) { + return value < 0.0f ? 
-value : value; +} + +static inline float wu_case17_exp_neg_ln2_to_0(float x) { + const float x2 = x * x; + const float x3 = x2 * x; + const float x4 = x2 * x2; + const float x5 = x4 * x; + return 1.0f + x + (0.5f * x2) + (0.1666666716f * x3) + + (0.0416666679f * x4) + (0.0083333338f * x5); +} + +extern "C" int wu_main() { + if (vx_core_id() != 0 || vx_warp_id() != 0) { + return 0; + } + + const uint32_t tid = wu_tid(); + if (tid == 0) { + wu_case_reset(); + g_case17_out_bits[0] = 0; + g_case17_out_bits[1] = 0; + } + asm volatile("fence rw, rw" ::: "memory"); + + const float score0 = wu_case17_bits_to_f32(g_case17_scores_bits[0]); + const float score1 = wu_case17_bits_to_f32(g_case17_scores_bits[1]); + const float row_max = score0 > score1 ? score0 : score1; + const float e0 = wu_case17_exp_neg_ln2_to_0(score0 - row_max); + const float e1 = wu_case17_exp_neg_ln2_to_0(score1 - row_max); + const float inv_sum = 1.0f / (e0 + e1); + const float p0 = e0 * inv_sum; + const float p1 = e1 * inv_sum; + + if (tid == 0) { + g_case17_out_bits[0] = wu_case17_f32_to_bits(p0); + g_case17_out_bits[1] = wu_case17_f32_to_bits(p1); + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + const float expected0 = 0.3333333433f; + const float expected1 = 0.6666666865f; + const float tolerance = 0.0015f; + const float err0 = wu_case17_absf(p0 - expected0); + const float err1 = wu_case17_absf(p1 - expected1); + if (err0 > tolerance || err1 > tolerance) { + g_aux[0] = g_case17_out_bits[0]; + g_aux[1] = g_case17_out_bits[1]; + wu_case_fail(0x71u); + return 1; + } + wu_case_pass(); + } + return 0; +} diff --git a/kernels/wu_arch_cases/common_wu_blackwell_fa.h b/kernels/wu_arch_cases/common_wu_blackwell_fa.h new file mode 100644 index 00000000..68f4c029 --- /dev/null +++ b/kernels/wu_arch_cases/common_wu_blackwell_fa.h @@ -0,0 +1,101 @@ +#ifndef WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H +#define WU_ARCH_CASES_COMMON_WU_BLACKWELL_FA_H + +#include "common_wu_min.h" + +#define 
WU_BW_DEV_SMEM_START_ADDR 0xff000000u +#define WU_BW_TMEM_TILE_BYTES 1024u +#define WU_BW_TMEM_FRAGMENT_BYTES 16u +#define WU_BW_TMEM_WARP_BYTES 2048u +#define WU_BW_TMEM_C_BYTE_OFFSET 1024u +#define WU_BW_TMEM_FRAGMENTS \ + (WU_BW_TMEM_TILE_BYTES / WU_BW_TMEM_FRAGMENT_BYTES) +#define WU_BW_TMEM_WORDS (WU_BW_TMEM_TILE_BYTES / sizeof(uint32_t)) +#define WU_BW_OUT_WORDS WU_BW_TMEM_WORDS + +#define WU_BW_FP16_ONE_PACKED 0x3c003c00u +#define WU_BW_FP16_TWO_PACKED 0x40004000u +#define WU_BW_FP32_ONE 0x3f800000u +#define WU_BW_FP32_THREE 0x40400000u +#define WU_BW_FP32_FIVE 0x40a00000u +#define WU_BW_FP32_SIXTY_FIVE 0x42820000u +#define WU_BW_FP32_SIXTY_SEVEN 0x42860000u +#define WU_BW_FP32_ONE_THIRTY_THREE 0x43050000u + +static_assert(NUM_SCALAR_WARPS == 2, + "Wu Blackwell FA staged cases expect two scalar warps"); +static_assert(NUM_TENSOR_WARPS == 2, + "Wu Blackwell FA staged cases expect two tensor warps"); +static_assert(NUM_THREADS == 4, + "Wu Blackwell FA staged cases expect the 4-lane tensor core"); + +static inline uint32_t wu_bw_all_lanes_mask() { + return (1u << NUM_THREADS) - 1u; +} + +static inline uint32_t wu_bw_scalar_tmem_ld(uint32_t frag_addr) { + uint32_t value; + asm volatile(".insn r %[custom1], 0, 0x30, %[value], %[addr], x0" + : [value] "=r"(value) + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr) + : "memory"); + return value; +} + +static inline void wu_bw_scalar_tmem_st(uint32_t frag_addr, uint32_t value) { + asm volatile(".insn r %[custom1], 1, 0x30, x0, %[addr], %[value]" + : + : [custom1] "i"(RISCV_CUSTOM1), [addr] "r"(frag_addr), + [value] "r"(value) + : "memory"); +} + +static inline uint32_t wu_bw_tmem_a_byte_base(uint32_t tensor_block) { + return tensor_block * WU_BW_TMEM_WARP_BYTES; +} + +static inline uint32_t wu_bw_tmem_c_byte_base(uint32_t tensor_block) { + return wu_bw_tmem_a_byte_base(tensor_block) + WU_BW_TMEM_C_BYTE_OFFSET; +} + +static inline void wu_bw_fill_tmem_tile(uint32_t byte_base, uint32_t value) { + const uint32_t 
first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES; + for (uint32_t frag = 0; frag < WU_BW_TMEM_FRAGMENTS; ++frag) { + // Caller must run this with all TMEM fragment lanes active. + wu_bw_scalar_tmem_st(first_frag + frag, value); + } + asm volatile("fence rw, rw" ::: "memory"); +} + +static inline void wu_bw_store_tmem_fragment_lane_values(uint32_t byte_base, + uint32_t frag, + uint32_t lane_value) { + const uint32_t first_frag = byte_base / WU_BW_TMEM_FRAGMENT_BYTES; + wu_bw_scalar_tmem_st(first_frag + frag, lane_value); + asm volatile("fence rw, rw" ::: "memory"); +} + +static inline void wu_bw_fill_smem_tile(volatile uint32_t *smem, + uint32_t value) { + for (uint32_t i = 0; i < WU_BW_TMEM_WORDS; ++i) { + smem[i] = value; + } + asm volatile("fence rw, rw" ::: "memory"); +} + +static inline uint32_t wu_bw_verify_constant(const volatile uint32_t *data, + uint32_t words, + uint32_t expected, + volatile uint32_t *bad_actual) { + for (uint32_t i = 0; i < words; ++i) { + const uint32_t actual = data[i]; + if (actual != expected) { + *bad_actual = actual; + return i; + } + } + *bad_actual = 0; + return words; +} + +#endif diff --git a/kernels/wu_arch_hgemm/README.md b/kernels/wu_arch_hgemm/README.md index 968d53f3..e79a954f 100644 --- a/kernels/wu_arch_hgemm/README.md +++ b/kernels/wu_arch_hgemm/README.md @@ -1,9 +1,15 @@ # wu_arch_hgemm -Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration -with the 4-lane Blackwell tensor-core path. +Two-tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp +configuration with the 4-lane Blackwell tensor-core path. -Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor -warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports -completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM -instruction sequence using 16-byte fragments and then stop themselves. 
+Scalar warp 0 initializes the shared-memory B operand for a 32x16x32 GEMM, +spawns only the tensor warp mask, waits for tensor warps +`NUM_SCALAR_WARPS..NUM_WARPS-1`, verifies the combined 32x16 fp32 output, and +reports completion through `tohost`. + +Each tensor warp computes one 16x16x32 BWGMMA M-block using 16-byte fragments: +warp `NUM_SCALAR_WARPS` produces rows 0..15 and warp `NUM_SCALAR_WARPS + 1` +produces rows 16..31. Both tensor warps share the same B tile, copy their C +tiles back into one contiguous output matrix, mark their block IDs, and then +stop themselves. diff --git a/kernels/wu_arch_hgemm/kernel.cpp b/kernels/wu_arch_hgemm/kernel.cpp index 97b81026..7c3bed45 100644 --- a/kernels/wu_arch_hgemm/kernel.cpp +++ b/kernels/wu_arch_hgemm/kernel.cpp @@ -2,6 +2,24 @@ #define DEV_SMEM_START_ADDR 0xff000000u #define WU_CASE_TENSOR_HGEMM_BASE 0x7500u +#define WU_CASE_TENSOR_HGEMM_BLOCK_BASE 0x7600u + +#define WU_HGEMM_M_BLOCKS 2u +#define WU_HGEMM_M_PER_BLOCK 16u +#define WU_HGEMM_N 16u +#define WU_HGEMM_K 32u +#define WU_HGEMM_TILE_BYTES 1024u +#define WU_HGEMM_FRAGMENT_BYTES 16u +#define WU_HGEMM_B_WORDS (WU_HGEMM_TILE_BYTES / sizeof(uint32_t)) +#define WU_HGEMM_OUT_WORDS \ + (WU_HGEMM_M_BLOCKS * WU_HGEMM_M_PER_BLOCK * WU_HGEMM_N) +#define WU_HGEMM_EXPECTED 0x42820000u +#define WU_HGEMM_NO_FAIL WU_HGEMM_OUT_WORDS + +static_assert(NUM_TENSOR_WARPS == WU_HGEMM_M_BLOCKS, + "wu_arch_hgemm expects two tensor warps"); +static_assert((NUM_THREADS * 2u) <= WU_CASE_MAX_WARPS, + "wu_arch_hgemm uses g_aux[2 * tid + {0,1}] for check scratch"); #define BW_REP2(x) x, x #define BW_REP4(x) BW_REP2(x), BW_REP2(x) @@ -13,6 +31,9 @@ volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = { BW_REP4(0x40004000u)}; volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = { BW_REP4(0x3f800000u)}; +volatile uint32_t g_hgemm_control_fail __attribute__((aligned(16))); +volatile uint32_t g_hgemm_return_code __attribute__((aligned(16))); +volatile uint32_t 
g_hgemm_out[WU_HGEMM_OUT_WORDS] __attribute__((aligned(16))); } #undef BW_REP2 @@ -21,7 +42,8 @@ volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = { extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() { asm volatile( "csrr x5, %[csr_wid]\n\t" - "slli x1, x5, 11\n\t" + "addi x1, x5, -%[num_scalar_warps]\n\t" + "slli x1, x1, 11\n\t" "addi x2, x1, 1024\n\t" "la x6, g_hgemm_a_row\n\t" "la x3, g_hgemm_c_row\n\t" @@ -39,6 +61,25 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() { ".insn r %[custom3], 0, 0, x2, x1, x4\n\t" ".insn r %[custom3], 1, 0, x0, x0, x0\n\t" "csrr x5, %[csr_wid]\n\t" + "addi x6, x5, -%[num_scalar_warps]\n\t" + "slli x7, x6, 10\n\t" + "la x3, g_hgemm_out\n\t" + "add x3, x3, x7\n\t" + "li x7, 0\n\t" + "3:\n\t" + "add x4, x2, x7\n\t" + "add x1, x3, x7\n\t" + ".insn r %[custom3], 6, 0, x0, x4, x1\n\t" + "addi x7, x7, 16\n\t" + "li x4, %[tile_bytes]\n\t" + "blt x7, x4, 3b\n\t" + ".insn r %[custom3], 3, 0, x0, x0, x0\n\t" + "slli x7, x6, 2\n\t" + "la x3, g_case_mem\n\t" + "add x3, x3, x7\n\t" + "li x7, %[hgemm_block_base]\n\t" + "or x7, x7, x6\n\t" + "sw x7, 0(x3)\n\t" "slli x6, x5, 2\n\t" "la x7, g_seen\n\t" "add x7, x7, x6\n\t" @@ -52,34 +93,91 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() { [custom0] "i"(RISCV_CUSTOM0), [custom3] "i"(RISCV_CUSTOM3), [smem_base] "i"(DEV_SMEM_START_ADDR), - [hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE) + [hgemm_base] "i"(WU_CASE_TENSOR_HGEMM_BASE), + [hgemm_block_base] "i"(WU_CASE_TENSOR_HGEMM_BLOCK_BASE), + [num_scalar_warps] "i"(NUM_SCALAR_WARPS), + [tile_bytes] "i"(WU_HGEMM_TILE_BYTES) : "memory"); } extern "C" int wu_main() { - if (!wu_is_leader()) { + if (vx_core_id() != 0 || vx_warp_id() != 0) { return 0; } - wu_case_reset(); + const uint32_t tid = wu_tid(); - volatile uint32_t *smem_b = - reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR); - for (uint32_t frag = 0; frag < 64u; ++frag) { - const uint32_t row = frag * 4u; -
for (uint32_t i = 0; i < 4u; ++i) { - smem_b[row + i] = g_hgemm_b_row[i]; + if (tid == 0) { + wu_case_reset(); + g_hgemm_control_fail = 0; + g_hgemm_return_code = 0; + } + asm volatile("fence rw, rw" ::: "memory"); + + if (tid == 0) { + for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) { + g_hgemm_out[i] = 0; + } + + volatile uint32_t *smem_b = + reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR); + for (uint32_t i = 0; i < WU_HGEMM_B_WORDS; ++i) { + smem_b[i] = g_hgemm_b_row[i & 3u]; } } + asm volatile("fence rw, rw" ::: "memory"); - vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker); + if (tid == 0) { + vx_spawn_tensor(vx_tensor_warp_mask(), tensor_hgemm_worker); - if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, - WU_CASE_TENSOR_HGEMM_BASE) != 0) { - wu_case_fail(0x09u); - return 1; + if (wu_wait_seen_range(NUM_SCALAR_WARPS, NUM_WARPS, + WU_CASE_TENSOR_HGEMM_BASE) != 0) { + g_hgemm_control_fail = 0x09u; + } + + if (g_hgemm_control_fail == 0) { + for (uint32_t block = 0; block < WU_HGEMM_M_BLOCKS; ++block) { + if (g_case_mem[block] != (WU_CASE_TENSOR_HGEMM_BLOCK_BASE | block)) { + g_hgemm_control_fail = 0x0au + block; + break; + } + } + } + } + asm volatile("fence rw, rw" ::: "memory"); + + if (g_hgemm_control_fail != 0) { + if (tid == 0) { + g_hgemm_return_code = 1; + wu_case_fail(g_hgemm_control_fail); + } + asm volatile("fence rw, rw" ::: "memory"); + return static_cast<int>(g_hgemm_return_code); } - wu_case_pass(); - return 0; + if (tid == 0) { + uint32_t bad_index = WU_HGEMM_NO_FAIL; + uint32_t bad_actual = 0; + + for (uint32_t i = 0; i < WU_HGEMM_OUT_WORDS; ++i) { + const uint32_t actual = g_hgemm_out[i]; + if (actual != WU_HGEMM_EXPECTED) { + bad_index = i; + bad_actual = actual; + break; + } + } + + if (bad_index != WU_HGEMM_NO_FAIL) { + g_aux[0] = bad_index; + g_aux[1] = bad_actual; + g_hgemm_return_code = 1; + wu_case_fail(0x20u); + } else { + g_hgemm_return_code = 0; + wu_case_pass(); + } + } + asm volatile("fence rw, rw" ::: "memory"); + return
static_cast<int>(g_hgemm_return_code); }