Files
kernels/kernels/blackwell_fp8_e4m3/kernel.cpp
2026-07-03 08:40:25 +00:00

122 lines
3.6 KiB
C++

#include "fp8_common.hpp"
#include "../wu_arch_cases/common_wu_min.h"
// Address at which wu_main() stages the B operand and which the worker passes (in x4) to the
// custom-3 funct3=0 instruction.
// NOTE(review): presumably the device shared-memory base — confirm against the platform memory map.
#define DEV_SMEM_START_ADDR 0xff000000u
// Tag the worker ORs with its warp id before storing it into g_seen[]; wu_main() passes the same
// value to wu_wait_seen_mask().
#define FP8_VALIDATION_BASE 0x7800u
// FP8_M * FP8_N gives the number of 32-bit result words (see FP8_OUT_WORDS).
#define FP8_M 16u
#define FP8_N 16u
// Not referenced elsewhere in the visible code.
// NOTE(review): presumably the reduction depth; 32 * (1 * 2) + 1 = 65 would match FP8_EXPECTED — confirm.
#define FP8_K 32u
// Byte span walked by both worker loops, and the offset of the second tile used by the worker.
#define FP8_TILE_BYTES 1024u
// Step of the worker loops: one custom-3 transfer instruction is issued per 16-byte fragment.
#define FP8_FRAGMENT_BYTES 16u
// 16 / 4 = 4 32-bit words per fragment. Expands to a size_t expression, so it cannot be used in #if.
#define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t))
// 1024 / 16 = 64 fragments per tile.
#define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES)
// 16 * 16 = 256 words, i.e. 1024 bytes — the same size as FP8_TILE_BYTES.
#define FP8_OUT_WORDS (FP8_M * FP8_N)
// Bit pattern every word of g_fp8_out must hold for the case to pass.
// 0x42820000 is the IEEE-754 binary32 encoding of 65.0.
#define FP8_EXPECTED 0x42820000u
// Operand and result buffers. They have C linkage because the worker's assembly refers to
// g_fp8_a_frag, g_fp8_c_frag and g_fp8_out by their unmangled symbol names.
// NOTE(review): WU_FP8_REP4, WU_FP8_PACK4 and the WU_FP8_E4M3_* constants are defined outside this
// file; the comments below assume REP4 repeats its argument four times and PACK4 packs four
// 8-bit values into one 32-bit word — confirm against the header.
extern "C" {
// Single A-operand fragment, built from WU_FP8_E4M3_ONE. The worker passes this same fragment
// address for every fragment offset of the tile at offset 0.
volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE,
WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))};
// Single B-operand fragment, built from WU_FP8_E4M3_TWO. wu_main() copies it FP8_FRAGMENTS times
// into memory starting at DEV_SMEM_START_ADDR.
volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO,
WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))};
// Single C (accumulator) fragment. 0x3f800000 is the IEEE-754 binary32 encoding of 1.0. The worker
// passes this fragment address for every fragment offset of the tile at offset FP8_TILE_BYTES.
volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = {
WU_FP8_REP4(0x3f800000u)};
// Result buffer: zeroed by wu_main() before the worker is spawned, named as the rs2 operand of the
// worker's funct3=6 loop, and compared word-by-word against FP8_EXPECTED afterwards.
volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(16)));
}
// Remove the WU_FP8_REP* helper macros from the macro namespace for the rest of this file.
// NOTE(review): only WU_FP8_REP4 is used in the visible code; REP2 and REP8 are presumably defined
// by the same included header — confirm.
#undef WU_FP8_REP2
#undef WU_FP8_REP4
#undef WU_FP8_REP8
// Worker entry point that wu_main() hands to vx_spawn_tensor().
//
// The body is a single hand-written asm sequence. The function is `naked`, so the compiler emits
// no prologue or epilogue; the code overwrites x1..x7 (x1 and x2 are ra and sp in the standard
// RISC-V calling convention) and never returns — it ends in a self-loop.
//
// `.insn r opcode, funct3, funct7, rd, rs1, rs2` emits a raw R-type instruction, so the second
// operand of every `.insn` below is the funct3 selector.
//
// NOTE(review): the semantics of the custom-3 / custom-0 instructions are not defined in this
// file. From the operands used here, funct3=2 looks like "load fragment from memory (rs2) into
// tile storage at offset rs1", funct3=6 like the matching store, funct3=0 like the
// multiply-accumulate, and funct3=1 / funct3=3 like completion waits or fences — confirm against
// the ISA extension specification before relying on any of this.
extern "C" void __attribute__((naked, noinline, used)) fp8_validation_worker() {
asm volatile(
// Register setup: x1 = 0 (first tile offset), x2 = FP8_TILE_BYTES (second tile offset),
// x6 = &g_fp8_a_frag, x3 = &g_fp8_c_frag, x7 = running byte offset.
"li x1, 0\n\t"
"li x2, %[tile_bytes]\n\t"
"la x6, g_fp8_a_frag\n\t"
"la x3, g_fp8_c_frag\n\t"
"li x7, 0\n\t"
// Loop 1: for every fragment offset x7 in [0, FP8_TILE_BYTES), step FP8_FRAGMENT_BYTES.
"1:\n\t"
// funct3=2 with rs1 = x1 + x7 and rs2 = &g_fp8_a_frag (the same source fragment every iteration).
"add x4, x1, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
// funct3=2 with rs1 = x2 + x7 and rs2 = &g_fp8_c_frag.
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
// x4 is scratch, so the loop bound is reloaded on every iteration.
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 1b\n\t"
// funct3=3 with all-zero register fields, issued after the loop.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// funct3=0 with register fields rd=x2, rs1=x1, rs2=x4, where x4 holds DEV_SMEM_START_ADDR.
// NOTE(review): x2 is read again as the tile offset in loop 2, so this sequence relies on the
// instruction leaving x2 equal to FP8_TILE_BYTES — confirm how the rd field is used.
"li x4, %[smem_base]\n\t"
".insn r %[custom3], 0, 0, x2, x1, x4\n\t"
// funct3=1 with all-zero register fields, issued right after funct3=0.
".insn r %[custom3], 1, 0, x0, x0, x0\n\t"
// Loop 2: same offset walk; funct3=6 with rs1 = x2 + x7 and rs2 = &g_fp8_out + x7.
// x1 is reused here as the destination-address scratch register.
"la x3, g_fp8_out\n\t"
"li x7, 0\n\t"
"2:\n\t"
"add x4, x2, x7\n\t"
"add x1, x3, x7\n\t"
".insn r %[custom3], 6, 0, x0, x4, x1\n\t"
"addi x7, x7, %[frag_bytes]\n\t"
"li x4, %[tile_bytes]\n\t"
"blt x7, x4, 2b\n\t"
// funct3=3 again, before the completion word is published below.
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
// Completion report: read the warp id CSR into x5 and store (FP8_VALIDATION_BASE | warp id)
// into the 32-bit slot g_seen[warp id] (x5 << 2 is the byte offset of that slot).
"csrr x5, %[csr_wid]\n\t"
"slli x6, x5, 2\n\t"
"la x7, g_seen\n\t"
"add x7, x7, x6\n\t"
"li x6, %[done_base]\n\t"
"or x6, x6, x5\n\t"
"sw x6, 0(x7)\n\t"
// custom-0 funct3=0 with all-zero register fields.
// NOTE(review): presumably deactivates this warp's threads — confirm. The self-loop that follows
// keeps execution from running off the end of the function if the warp keeps executing.
".insn r %[custom0], 0, 0, x0, x0, x0\n\t"
"3: j 3b\n\t"
:
// All operands are compile-time immediates substituted into the template text.
: [csr_wid] "i"(VX_CSR_WARP_ID),
[custom0] "i"(RISCV_CUSTOM0),
[custom3] "i"(RISCV_CUSTOM3),
[smem_base] "i"(DEV_SMEM_START_ADDR),
[done_base] "i"(FP8_VALIDATION_BASE),
[tile_bytes] "i"(FP8_TILE_BYTES),
[frag_bytes] "i"(FP8_FRAGMENT_BYTES)
: "memory");
}
// Test driver for the FP8 validation case.
//
// Callers for which wu_is_leader() is false return 0 immediately. The leader resets the case
// state, clears g_fp8_out, stages FP8_FRAGMENTS copies of g_fp8_b_frag at DEV_SMEM_START_ADDR,
// spawns fp8_validation_worker on warp NUM_SCALAR_WARPS and waits for it to report. It then
// requires every word of g_fp8_out to equal FP8_EXPECTED.
//
// Returns 0 on pass (or for non-leaders) and 1 on failure. Failure code 0x09 is raised when
// wu_wait_seen_mask() returns nonzero; failure code 0x20 is raised on the first mismatching
// output word, with its index and value recorded in g_aux[0] and g_aux[1].
extern "C" int wu_main() {
  if (!wu_is_leader()) {
    return 0;
  }

  wu_case_reset();

  // Clear the result buffer so stale contents cannot satisfy the final comparison.
  uint32_t out_idx = 0;
  while (out_idx < FP8_OUT_WORDS) {
    g_fp8_out[out_idx] = 0;
    ++out_idx;
  }

  // Replicate the single B fragment across the whole staging area, one fragment after another.
  volatile uint32_t *stage_cursor =
      reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
  for (uint32_t frag_no = 0; frag_no < FP8_FRAGMENTS; ++frag_no) {
    for (uint32_t word = 0; word < FP8_FRAGMENT_WORDS; ++word) {
      *stage_cursor = g_fp8_b_frag[word];
      ++stage_cursor;
    }
  }

  // The worker runs on the warp whose index is NUM_SCALAR_WARPS; the same one-bit mask is used
  // both to spawn it and to wait for its completion word.
  const uint32_t tensor_wid = NUM_SCALAR_WARPS;
  const unsigned int tensor_mask = 1u << tensor_wid;
  vx_spawn_tensor(tensor_mask, fp8_validation_worker);

  const bool worker_reported =
      (wu_wait_seen_mask(tensor_mask, FP8_VALIDATION_BASE) == 0);
  if (!worker_reported) {
    wu_case_fail(0x09u);
    return 1;
  }

  // Stop at the first word that differs from the expected bit pattern.
  for (uint32_t check_idx = 0; check_idx < FP8_OUT_WORDS; ++check_idx) {
    if (g_fp8_out[check_idx] == FP8_EXPECTED) {
      continue;
    }
    g_aux[0] = check_idx;
    g_aux[1] = g_fp8_out[check_idx];
    wu_case_fail(0x20u);
    return 1;
  }

  wu_case_pass();
  return 0;
}