diff --git a/kernels/blackwell_fp8_e4m3/Makefile b/kernels/blackwell_fp8_e4m3/Makefile
new file mode 100644
index 00000000..57cd1c04
--- /dev/null
+++ b/kernels/blackwell_fp8_e4m3/Makefile
@@ -0,0 +1,13 @@
+# Build description for the standalone FP8 E4M3 BWGMMA validation kernel.
+PROJECT = blackwell_fp8_e4m3
+
+VX_SRCS = kernel.cpp
+VX_INCLUDES = fp8_common.hpp
+OPTS ?= -n1
+
+include ../common.mk
+
+# kernel.cpp defines its operand data in-source; each runtime input file
+# here is a plain copy of ../wu_arch_cases/zero.bin.
+args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
+	cp $< $@
diff --git a/kernels/blackwell_fp8_e4m3/README.md b/kernels/blackwell_fp8_e4m3/README.md
new file mode 100644
index 00000000..b34facf6
--- /dev/null
+++ b/kernels/blackwell_fp8_e4m3/README.md
@@ -0,0 +1,19 @@
+# blackwell_fp8_e4m3
+
+Standalone FP8 E4M3 validation kernel for the Wu Blackwell BWGMMA branch.
+
+This directory is the only kernel area used by the FP8 branch work. Existing
+FP16 HGEMM, `wu_arch_cases`, and flash kernels are intentionally left unchanged.
+
+The validation runs one tensor warp on a 16x16x32 (M x N x K) tile:
+
+- A is FP8 E4M3 1.0 (`0x38`)
+- B is FP8 E4M3 2.0 (`0x40`)
+- C is FP32 1.0 (`0x3f800000`)
+- Expected output is FP32 65.0 (`0x42820000`), i.e. 32 x 1.0 x 2.0 + 1.0
+
+Build (from the repository root):
+
+```bash
+make -C kernels/blackwell_fp8_e4m3
+```
diff --git a/kernels/blackwell_fp8_e4m3/fp8_common.hpp b/kernels/blackwell_fp8_e4m3/fp8_common.hpp
new file mode 100644
index 00000000..e32ab723
--- /dev/null
+++ b/kernels/blackwell_fp8_e4m3/fp8_common.hpp
@@ -0,0 +1,74 @@
+#ifndef BLACKWELL_FP8_E4M3_COMMON_HPP
+#define BLACKWELL_FP8_E4M3_COMMON_HPP
+
+// FP8 E4M3 constants and thin wrappers around the custom-3 opcode
+// instructions used by the blackwell_fp8_e4m3 validation kernel.
+
+#include <stdint.h>
+// NOTE(review): header name reconstructed; it must provide RISCV_CUSTOM3.
+#include <vx_intrinsics.h>
+
+// FP8 E4M3 bit patterns (1 sign, 4 exponent, 3 mantissa bits; bias 7).
+#define WU_FP8_E4M3_ZERO 0x00u  // 0.0
+#define WU_FP8_E4M3_HALF 0x30u  // 0.5
+#define WU_FP8_E4M3_ONE 0x38u   // 1.0
+#define WU_FP8_E4M3_TWO 0x40u   // 2.0
+
+// Packs four bytes into one 32-bit word: `a` in bits 7:0 up to `d` in
+// bits 31:24. Each argument is masked to its low 8 bits.
+#define WU_FP8_PACK4(a, b, c, d)                                       \
+  ((((uint32_t)(a) & 0xffu) << 0) | (((uint32_t)(b) & 0xffu) << 8) | \
+   (((uint32_t)(c) & 0xffu) << 16) | (((uint32_t)(d) & 0xffu) << 24))
+
+// Comma-separated repetition helpers for array initializers.
+#define WU_FP8_REP2(x) x, x
+#define WU_FP8_REP4(x) WU_FP8_REP2(x), WU_FP8_REP2(x)
+#define WU_FP8_REP8(x) WU_FP8_REP4(x), WU_FP8_REP4(x)
+
+// Issues custom-3 funct3=2 with rs1=addr_tmem and rs2=addr_gmem.
+// NOTE(review): presumably starts a global-memory -> tensor-memory copy
+// of one fragment; confirm against the hardware spec.
+static inline void wu_tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
+  asm volatile(".insn r %0, 2, 0, x0, %1, %2"
+               :
+               : "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
+               : "memory");
+}
+
+// Issues custom-3 funct3=3 with no register operands.
+// NOTE(review): presumably waits for outstanding copies; confirm.
+static inline void wu_tcgen05_cp_wait() {
+  asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
+               : "memory");
+}
+
+// Issues custom-3 funct3=6 with rs1=addr_tmem and rs2=addr_gmem.
+// kernel.cpp uses this encoding to move results into g_fp8_out, i.e. a
+// tensor-memory -> global-memory copy-back.
+static inline void wu_tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
+  asm volatile(".insn r %0, 6, 0, x0, %1, %2"
+               :
+               : "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
+               : "memory");
+}
+
+// Issues custom-3 funct3=0 (the FP8 BWGMMA). addr_tmem_c sits in the rd
+// slot of the R-type encoding but is declared as an input operand, so
+// the compiler treats that register as read-only.
+static inline void wu_bwgmma_fp8(uint32_t addr_tmem_c, uint32_t addr_tmem_a,
+                                 uint32_t addr_smem_b) {
+  asm volatile(".insn r %0, 0, 0, %1, %2, %3"
+               :
+               : "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
+                 "r"(addr_smem_b)
+               : "memory");
+}
+
+// Issues custom-3 funct3=1 with no register operands.
+// NOTE(review): presumably blocks until the BWGMMA completes; confirm.
+static inline void wu_bwgmma_wait() {
+  asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
+               : "memory");
+}
+
+#endif
diff --git a/kernels/blackwell_fp8_e4m3/kernel.cpp b/kernels/blackwell_fp8_e4m3/kernel.cpp
new file mode 100644
index 00000000..9758277f
--- /dev/null
+++ b/kernels/blackwell_fp8_e4m3/kernel.cpp
@@ -0,0 +1,149 @@
+#include "fp8_common.hpp"
+#include "../wu_arch_cases/common_wu_min.h"
+
+// FP8 E4M3 BWGMMA validation on one tensor warp. A is 1.0, B is 2.0 and
+// C is 1.0 everywhere, so with K = 32 every output word is expected to be
+// 32 * 1.0 * 2.0 + 1.0 = 65.0 (FP8_EXPECTED).
+
+#define DEV_SMEM_START_ADDR 0xff000000u
+#define FP8_VALIDATION_BASE 0x7800u
+
+#define FP8_M 16u
+#define FP8_N 16u
+#define FP8_K 32u
+#define FP8_TILE_BYTES 1024u
+#define FP8_FRAGMENT_BYTES 32u
+#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
+#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
+#define FP8_OUT_WORDS (FP8_M * FP8_N)
+#define FP8_EXPECTED 0x42820000u  // FP32 65.0
+
+// One 32-byte fragment per operand; the worker and wu_main reuse the same
+// fragment for every 32-byte slice of the corresponding tile.
+extern "C" {
+volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
+    WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
+                             WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
+volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
+    WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
+                             WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
+volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
+    WU_FP8_REP8(0x3f800000u)};
+// Result tile; filled by the worker's copy-back loop, checked by wu_main.
+volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(32)));
+}
+
+#undef WU_FP8_REP2
+#undef WU_FP8_REP4
+#undef WU_FP8_REP8
+
+// Tensor-warp entry point. Naked: it overwrites x1-x7 (which include
+// ra/sp/gp/tp) and never returns; it ends in a spin loop.
+// Tensor-memory layout: A tile at offset 0 (x1), C tile at FP8_TILE_BYTES (x2).
+extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
+  asm volatile(
+      // x1/x2 = tmem offsets of A/C, x6/x3 = source fragments,
+      // x7 = running byte offset inside the tile.
+      "li x1, 0\n\t"
+      "li x2, %[tile_bytes]\n\t"
+      "la x6, g_fp8_a_frag\n\t"
+      "la x3, g_fp8_c_frag\n\t"
+      "li x7, 0\n\t"
+      // Loop 1: issue the copy encoding (funct3=2) for the A and C
+      // fragments at every 32-byte slice of their tensor-memory tiles.
+      "1:\n\t"
+      "add x4, x1, x7\n\t"
+      ".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
+      "add x4, x2, x7\n\t"
+      ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
+      "addi x7, x7, %[frag_bytes]\n\t"
+      "li x4, %[tile_bytes]\n\t"
+      "blt x7, x4, 1b\n\t"
+      // funct3=3: same encoding as wu_tcgen05_cp_wait().
+      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
+      // funct3=0 BWGMMA: rd=C tmem offset, rs1=A tmem offset, rs2=B in
+      // shared memory; followed by funct3=1 (wu_bwgmma_wait encoding).
+      "li x4, %[smem_base]\n\t"
+      ".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
+      ".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
+      // Loop 2: copy-back encoding (funct3=6) from the C tile to
+      // g_fp8_out, one fragment per pass; x1 is reused as the destination.
+      "la x3, g_fp8_out\n\t"
+      "li x7, 0\n\t"
+      "2:\n\t"
+      "add x4, x2, x7\n\t"
+      "add x1, x3, x7\n\t"
+      ".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
+      "addi x7, x7, %[frag_bytes]\n\t"
+      "li x4, %[tile_bytes]\n\t"
+      "blt x7, x4, 2b\n\t"
+      ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
+      // Publish completion: g_seen[warp_id] = FP8_VALIDATION_BASE | warp_id.
+      "csrr x5, %[csr_wid]\n\t"
+      "slli x6, x5, 2\n\t"
+      "la x7, g_seen\n\t"
+      "add x7, x7, x6\n\t"
+      "li x6, %[done_base]\n\t"
+      "or x6, x6, x5\n\t"
+      "sw x6, 0(x7)\n\t"
+      // NOTE(review): custom-0 funct3=0 with rs1=x0 presumably parks this
+      // warp; the spin loop is the fallback if execution continues.
+      ".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
+      "3: j 3b\n\t"
+      :
+      : [csr_wid] "i"(VX_CSR_WARP_ID),
+        [custom0] "i"(RISCV_CUSTOM0),
+        [custom3] "i"(RISCV_CUSTOM3),
+        [smem_base] "i"(DEV_SMEM_START_ADDR),
+        [done_base] "i"(FP8_VALIDATION_BASE),
+        [tile_bytes] "i"(FP8_TILE_BYTES),
+        [frag_bytes] "i"(FP8_FRAGMENT_BYTES)
+      : "memory");
+}
+
+// Test driver; only the leader does any work. Returns 0 for non-leaders
+// and on success, 1 after recording a failure via wu_case_fail().
+extern "C" int wu_main() {
+  if (!wu_is_leader()) {
+    return 0;
+  }
+
+  wu_case_reset();
+
+  // Clear the result tile so stale contents cannot satisfy the check below.
+  for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
+    g_fp8_out[i] = 0;
+  }
+
+  // Fill the B tile at DEV_SMEM_START_ADDR by repeating the B fragment.
+  volatile uint32_t *smem_b =
+      reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
+  for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
+    const uint32_t row = frag * FP8_FRAGMENT_WORDS;
+    for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
+      smem_b[row + i] = g_fp8_b_frag[i];
+    }
+  }
+
+  // Spawn the worker on warp id NUM_SCALAR_WARPS and wait for its marker.
+  const uint32_t tensor_wid = NUM_SCALAR_WARPS;
+  vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
+
+  if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
+    wu_case_fail(0x09u);
+    return 1;
+  }
+
+  // Every output word must match; on mismatch record its index and value.
+  for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
+    if (g_fp8_out[i] != FP8_EXPECTED) {
+      g_aux[0] = i;
+      g_aux[1] = g_fp8_out[i];
+      wu_case_fail(0x20u);
+      return 1;
+    }
+  }
+
+  wu_case_pass();
+  return 0;
+}