feat: add blackwell fp8 e4m3 kernel

This commit is contained in:
Zhongdi LUO
2026-07-02 10:40:29 +00:00
parent f1aa1303d2
commit 3f7ce1f1c9
4 changed files with 203 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
PROJECT = blackwell_fp8_e4m3
VX_SRCS = kernel.cpp
VX_INCLUDES = fp8_common.hpp
OPTS ?= -n1
include ../common.mk
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
cp $< $@

View File

@@ -0,0 +1,19 @@
# blackwell_fp8_e4m3
Standalone FP8 E4M3 validation kernel for the Wu Blackwell BWGMMA branch.
This directory is the only kernel area used by the FP8 branch work. Existing
FP16 HGEMM, `wu_arch_cases`, and flash kernels are intentionally left unchanged.
The validation runs one tensor warp on a 16x16x32 tile:
- A is FP8 E4M3 1.0 (`0x38`)
- B is FP8 E4M3 2.0 (`0x40`)
- C is FP32 1.0 (`0x3f800000`)
- Expected output is FP32 65.0 (`0x42820000`)
Build:
```bash
make -C /home/lzd/wu/wuarch/virgo-kernels/kernels/blackwell_fp8_e4m3
```

View File

@@ -0,0 +1,53 @@
#ifndef BLACKWELL_FP8_E4M3_COMMON_HPP
#define BLACKWELL_FP8_E4M3_COMMON_HPP
#include <stdint.h>
#include <vx_intrinsics.h>
#define WU_FP8_E4M3_ZERO 0x00u
#define WU_FP8_E4M3_HALF 0x30u
#define WU_FP8_E4M3_ONE 0x38u
#define WU_FP8_E4M3_TWO 0x40u
#define WU_FP8_PACK4(a, b, c, d) \
((((uint32_t)(a) & 0xffu) << 0) | (((uint32_t)(b) & 0xffu) << 8) | \
(((uint32_t)(c) & 0xffu) << 16) | (((uint32_t)(d) & 0xffu) << 24))
#define WU_FP8_REP2(x) x, x
#define WU_FP8_REP4(x) WU_FP8_REP2(x), WU_FP8_REP2(x)
#define WU_FP8_REP8(x) WU_FP8_REP4(x), WU_FP8_REP4(x)
static inline void wu_tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
:
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
: "memory");
}
static inline void wu_tcgen05_cp_wait() {
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
: "memory");
}
static inline void wu_tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
asm volatile(".insn r %0, 6, 0, x0, %1, %2"
:
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
: "memory");
}
static inline void wu_bwgmma_fp8(uint32_t addr_tmem_c, uint32_t addr_tmem_a,
uint32_t addr_smem_b) {
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
:
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
"r"(addr_smem_b)
: "memory");
}
static inline void wu_bwgmma_wait() {
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
: "memory");
}
#endif

View File

@@ -0,0 +1,121 @@
#include "fp8_common.hpp"
#include "../wu_arch_cases/common_wu_min.h"
#define DEV_SMEM_START_ADDR 0xff000000u
#define FP8_VALIDATION_BASE 0x7800u
#define FP8_M 16u
#define FP8_N 16u
#define FP8_K 32u
#define FP8_TILE_BYTES 1024u
#define FP8_FRAGMENT_BYTES 32u
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
#define FP8_OUT_WORDS (FP8_M * FP8_N)
#define FP8_EXPECTED 0x42820000u
extern "C" {
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
WU_FP8_REP8(0x3f800000u)};
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(32)));
}
#undef WU_FP8_REP2
#undef WU_FP8_REP4
#undef WU_FP8_REP8
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
asm volatile(
"li x1, 0\n\t"
"li x2, %[tile_bytes]\n\t"
"la x6, g_fp8_a_frag\n\t"
"la x3, g_fp8_c_frag\n\t"
"li x7, 0\n\t"
"1:\n\t"
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
"la x3, g_fp8_out\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
"add x1, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"3: j 3b\n\t"
:
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[done_base] "i"(FP8_VALIDATION_BASE),
[tile_bytes] "i"(FP8_TILE_BYTES),
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
: "memory");
}
extern "C" int wu_main() {
if (!wu_is_leader()) {
return 0;
}
wu_case_reset();
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
g_fp8_out[i] = 0;
}
volatile uint32_t *smem_b =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
const uint32_t row = frag * FP8_FRAGMENT_WORDS;
for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
smem_b[row + i] = g_fp8_b_frag[i];
}
}
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
wu_case_fail(0x09u);
return 1;
}
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
if (g_fp8_out[i] != FP8_EXPECTED) {
g_aux[0] = i;
g_aux[1] = g_fp8_out[i];
wu_case_fail(0x20u);
return 1;
}
}
wu_case_pass();
return 0;
}