From 8a15e5251e48121af4d5b878d30b7b004627b646 Mon Sep 17 00:00:00 2001 From: Zhongdi LUO Date: Fri, 3 Jul 2026 08:40:25 +0000 Subject: [PATCH] fix: match blackwell fp8 fragment width --- kernels/blackwell_fp8_e4m3/README.md | 2 ++ kernels/blackwell_fp8_e4m3/kernel.cpp | 16 ++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/kernels/blackwell_fp8_e4m3/README.md b/kernels/blackwell_fp8_e4m3/README.md index b34facf6..90b36f52 100644 --- a/kernels/blackwell_fp8_e4m3/README.md +++ b/kernels/blackwell_fp8_e4m3/README.md @@ -11,6 +11,8 @@ The validation runs one tensor warp on a 16x16x32 tile: - B is FP8 E4M3 2.0 (`0x40`) - C is FP32 1.0 (`0x3f800000`) - Expected output is FP32 65.0 (`0x42820000`) +- `VirgoBlackwellConfig` currently uses 4 core/memory lanes, so one + `tcgen05_cp/cb` fragment is 16 bytes. Build: diff --git a/kernels/blackwell_fp8_e4m3/kernel.cpp b/kernels/blackwell_fp8_e4m3/kernel.cpp index 9758277f..3718a777 100644 --- a/kernels/blackwell_fp8_e4m3/kernel.cpp +++ b/kernels/blackwell_fp8_e4m3/kernel.cpp @@ -8,22 +8,22 @@ #define FP8_N 16u #define FP8_K 32u #define FP8_TILE_BYTES 1024u -#define FP8_FRAGMENT_BYTES 32u +#define FP8_FRAGMENT_BYTES 16u #define FP8_FRAGMENT_WORDS (FP8_FRAGMENT_BYTES / sizeof(uint32_t)) #define FP8_FRAGMENTS (FP8_TILE_BYTES / FP8_FRAGMENT_BYTES) #define FP8_OUT_WORDS (FP8_M * FP8_N) #define FP8_EXPECTED 0x42820000u extern "C" { -volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = { - WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE, +volatile uint32_t g_fp8_a_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = { + WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE, WU_FP8_E4M3_ONE))}; -volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = { - WU_FP8_REP8(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO, +volatile uint32_t g_fp8_b_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = { + WU_FP8_REP4(WU_FP8_PACK4(WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO, WU_FP8_E4M3_TWO))}; -volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(32))) = { - WU_FP8_REP8(0x3f800000u)}; -volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(32))); +volatile uint32_t g_fp8_c_frag[FP8_FRAGMENT_WORDS] __attribute__((aligned(16))) = { + WU_FP8_REP4(0x3f800000u)}; +volatile uint32_t g_fp8_out[FP8_OUT_WORDS] __attribute__((aligned(16))); } #undef WU_FP8_REP2