feat: add blackwell fp8 e4m3 kernel
This commit is contained in:
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
10
kernels/blackwell_fp8_e4m3/Makefile
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
PROJECT = blackwell_fp8_e4m3
|
||||||
|
|
||||||
|
VX_SRCS = kernel.cpp
|
||||||
|
VX_INCLUDES = fp8_common.hpp
|
||||||
|
OPTS ?= -n1
|
||||||
|
|
||||||
|
include ../common.mk
|
||||||
|
|
||||||
|
args.bin input.a.bin input.b.bin input.c.bin: ../wu_arch_cases/zero.bin
|
||||||
|
cp $< $@
|
||||||
19
kernels/blackwell_fp8_e4m3/README.md
Normal file
19
kernels/blackwell_fp8_e4m3/README.md
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# blackwell_fp8_e4m3
|
||||||
|
|
||||||
|
Standalone FP8 E4M3 validation kernel for the Wu Blackwell BWGMMA branch.
|
||||||
|
|
||||||
|
This directory is the only kernel area used by the FP8 branch work. Existing
|
||||||
|
FP16 HGEMM, `wu_arch_cases`, and flash kernels are intentionally left unchanged.
|
||||||
|
|
||||||
|
The validation runs one tensor warp on a 16x16x32 tile:
|
||||||
|
|
||||||
|
- A is FP8 E4M3 1.0 (`0x38`)
|
||||||
|
- B is FP8 E4M3 2.0 (`0x40`)
|
||||||
|
- C is FP32 1.0 (`0x3f800000`)
|
||||||
|
- Expected output is FP32 65.0 (`0x42820000`)
|
||||||
|
|
||||||
|
Build:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make -C /home/lzd/wu/wuarch/virgo-kernels/kernels/blackwell_fp8_e4m3
|
||||||
|
```
|
||||||
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
53
kernels/blackwell_fp8_e4m3/fp8_common.hpp
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
#ifndef BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||||
|
#define BLACKWELL_FP8_E4M3_COMMON_HPP
|
||||||
|
|
||||||
|
#include <stdint.h>
|
||||||
|
#include <vx_intrinsics.h>
|
||||||
|
|
||||||
|
#define WU_FP8_E4M3_ZERO 0x00u
|
||||||
|
#define WU_FP8_E4M3_HALF 0x30u
|
||||||
|
#define WU_FP8_E4M3_ONE 0x38u
|
||||||
|
#define WU_FP8_E4M3_TWO 0x40u
|
||||||
|
|
||||||
|
#define WU_FP8_PACK4(a, b, c, d) \
|
||||||
|
((((uint32_t)(a) & 0xffu) << 0) | (((uint32_t)(b) & 0xffu) << 8) | \
|
||||||
|
(((uint32_t)(c) & 0xffu) << 16) | (((uint32_t)(d) & 0xffu) << 24))
|
||||||
|
|
||||||
|
#define WU_FP8_REP2(x) x, x
|
||||||
|
#define WU_FP8_REP4(x) WU_FP8_REP2(x), WU_FP8_REP2(x)
|
||||||
|
#define WU_FP8_REP8(x) WU_FP8_REP4(x), WU_FP8_REP4(x)
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cp(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 2, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cp_wait() {
|
||||||
|
asm volatile(".insn r %0, 3, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_tcgen05_cb(uint32_t addr_tmem, uint32_t addr_gmem) {
|
||||||
|
asm volatile(".insn r %0, 6, 0, x0, %1, %2"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem), "r"(addr_gmem)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_bwgmma_fp8(uint32_t addr_tmem_c, uint32_t addr_tmem_a,
|
||||||
|
uint32_t addr_smem_b) {
|
||||||
|
asm volatile(".insn r %0, 0, 0, %1, %2, %3"
|
||||||
|
:
|
||||||
|
: "i"(RISCV_CUSTOM3), "r"(addr_tmem_c), "r"(addr_tmem_a),
|
||||||
|
"r"(addr_smem_b)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline void wu_bwgmma_wait() {
|
||||||
|
asm volatile(".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
121
kernels/blackwell_fp8_e4m3/kernel.cpp
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
#include "fp8_common.hpp"
|
||||||
|
#include "../wu_arch_cases/common_wu_min.h"
|
||||||
|
|
||||||
|
#define DEV_SMEM_START_ADDR 0xff000000u
|
||||||
|
#define FP8_VALIDATION_BASE 0x7800u
|
||||||
|
|
||||||
|
#define FP8_M 16u
|
||||||
|
#define FP8_N 16u
|
||||||
|
#define FP8_K 32u
|
||||||
|
#define FP8_TILE_BYTES 1024u
|
||||||
|
#define FP8_FRAGMENT_BYTES 32u
|
||||||
|
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
|
||||||
|
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
|
||||||
|
#define FP8_OUT_WORDS (FP8_M * FP8_N)
|
||||||
|
#define FP8_EXPECTED 0x42820000u
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
|
||||||
|
WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
|
||||||
|
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
|
||||||
|
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
|
||||||
|
WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
|
||||||
|
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
|
||||||
|
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = {
|
||||||
|
WU_FP8_REP8(0x3f800000u)};
|
||||||
|
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(32)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#undef WU_FP8_REP2
|
||||||
|
#undef WU_FP8_REP4
|
||||||
|
#undef WU_FP8_REP8
|
||||||
|
|
||||||
|
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
|
||||||
|
asm volatile(
|
||||||
|
"li x1, 0\n\t"
|
||||||
|
"li x2, %[tile_bytes]\n\t"
|
||||||
|
"la x6, g_fp8_a_frag\n\t"
|
||||||
|
"la x3, g_fp8_c_frag\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"1:\n\t"
|
||||||
|
"add x4, x1, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||||
|
"addi x7, x7, %[frag_bytes]\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 1b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"li x4, %[smem_base]\n\t"
|
||||||
|
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
|
||||||
|
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
|
||||||
|
"la x3, g_fp8_out\n\t"
|
||||||
|
"li x7, 0\n\t"
|
||||||
|
"2:\n\t"
|
||||||
|
"add x4, x2, x7\n\t"
|
||||||
|
"add x1, x3, x7\n\t"
|
||||||
|
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
|
||||||
|
"addi x7, x7, %[frag_bytes]\n\t"
|
||||||
|
"li x4, %[tile_bytes]\n\t"
|
||||||
|
"blt x7, x4, 2b\n\t"
|
||||||
|
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||||
|
"csrr x5, %[csr_wid]\n\t"
|
||||||
|
"slli x6, x5, 2\n\t"
|
||||||
|
"la x7, g_seen\n\t"
|
||||||
|
"add x7, x7, x6\n\t"
|
||||||
|
"li x6, %[done_base]\n\t"
|
||||||
|
"or x6, x6, x5\n\t"
|
||||||
|
"sw x6, 0(x7)\n\t"
|
||||||
|
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
|
||||||
|
"3: j 3b\n\t"
|
||||||
|
:
|
||||||
|
: [csr_wid] "i"(VX_CSR_WARP_ID),
|
||||||
|
[custom0] "i"(RISCV_CUSTOM0),
|
||||||
|
[custom3] "i"(RISCV_CUSTOM3),
|
||||||
|
[smem_base] "i"(DEV_SMEM_START_ADDR),
|
||||||
|
[done_base] "i"(FP8_VALIDATION_BASE),
|
||||||
|
[tile_bytes] "i"(FP8_TILE_BYTES),
|
||||||
|
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
|
||||||
|
: "memory");
|
||||||
|
}
|
||||||
|
|
||||||
|
extern "C" int wu_main() {
|
||||||
|
if (!wu_is_leader()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_reset();
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||||
|
g_fp8_out[i] = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
volatile uint32_t *smem_b =
|
||||||
|
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||||
|
for (uint32_t frag = 0; frag < FP8_FRAGMENTS; ++frag) {
|
||||||
|
const uint32_t row = frag * FP8_FRAGMENT_WORDS;
|
||||||
|
for (uint32_t i = 0; i < FP8_FRAGMENT_WORDS; ++i) {
|
||||||
|
smem_b[row + i] = g_fp8_b_frag[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t tensor_wid = NUM_SCALAR_WARPS;
|
||||||
|
vx_spawn_tensor(1u << tensor_wid, fp8_validation_worker);
|
||||||
|
|
||||||
|
if (wu_wait_seen_mask(1u << tensor_wid, FP8_VALIDATION_BASE) != 0) {
|
||||||
|
wu_case_fail(0x09u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < FP8_OUT_WORDS; ++i) {
|
||||||
|
if (g_fp8_out[i] != FP8_EXPECTED) {
|
||||||
|
g_aux[0] = i;
|
||||||
|
g_aux[1] = g_fp8_out[i];
|
||||||
|
wu_case_fail(0x20u);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wu_case_pass();
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user