diff --git a/kernels/wu_arch_hgemm/README.md b/kernels/wu_arch_hgemm/README.md
index 8046f77d..968d53f3 100644
--- a/kernels/wu_arch_hgemm/README.md
+++ b/kernels/wu_arch_hgemm/README.md
@@ -1,8 +1,9 @@
 # wu_arch_hgemm
 
-Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration.
+Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration
+with the 4-lane Blackwell tensor-core path.
 
 Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
 warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
 completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
-instruction sequence and then stop themselves.
+instruction sequence using 16-byte fragments and then stop themselves.
diff --git a/kernels/wu_arch_hgemm/kernel.cpp b/kernels/wu_arch_hgemm/kernel.cpp
index c5925531..97b81026 100644
--- a/kernels/wu_arch_hgemm/kernel.cpp
+++ b/kernels/wu_arch_hgemm/kernel.cpp
@@ -5,20 +5,18 @@
 
 #define BW_REP2(x) x, x
 #define BW_REP4(x) BW_REP2(x), BW_REP2(x)
-#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
 
 extern "C" {
-volatile uint32_t g_hgemm_a_row[8] __attribute__((aligned(32))) = {
-    BW_REP8(0x3c003c00u)};
-volatile uint32_t g_hgemm_b_row[8] __attribute__((aligned(32))) = {
-    BW_REP8(0x40004000u)};
-volatile uint32_t g_hgemm_c_row[8] __attribute__((aligned(32))) = {
-    BW_REP8(0x3f800000u)};
+volatile uint32_t g_hgemm_a_row[4] __attribute__((aligned(16))) = {
+    BW_REP4(0x3c003c00u)};
+volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
+    BW_REP4(0x40004000u)};
+volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
+    BW_REP4(0x3f800000u)};
 }
 
 #undef BW_REP2
 #undef BW_REP4
-#undef BW_REP8
 
 extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
   asm volatile(
@@ -33,7 +31,7 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
       ".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
       "add x4, x2, x7\n\t"
       ".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
-      "addi x7, x7, 32\n\t"
+      "addi x7, x7, 16\n\t"
       "li x4, 1024\n\t"
       "blt x7, x4, 1b\n\t"
       ".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
@@ -67,9 +65,9 @@ extern "C" int wu_main() {
   volatile uint32_t *smem_b =
       reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
 
-  for (uint32_t frag = 0; frag < 32u; ++frag) {
-    const uint32_t row = frag * 8u;
-    for (uint32_t i = 0; i < 8u; ++i) {
+  for (uint32_t frag = 0; frag < 64u; ++frag) {
+    const uint32_t row = frag * 4u;
+    for (uint32_t i = 0; i < 4u; ++i) {
       smem_b[row + i] = g_hgemm_b_row[i];
     }
   }