Use 16-byte fragments in wu_arch_hgemm

This commit is contained in:
Zhongdi LUO
2026-05-27 05:54:55 +00:00
parent 122a048ea6
commit ed16541c8e
2 changed files with 13 additions and 14 deletions

View File

@@ -1,8 +1,9 @@
# wu_arch_hgemm
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration.
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration
with the 4-lane Blackwell tensor-core path.
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
instruction sequence and then stop themselves.
instruction sequence using 16-byte fragments and then stop themselves.

View File

@@ -5,20 +5,18 @@
#define BW_REP2(x) x, x
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
extern "C" {
volatile uint32_t g_hgemm_a_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3c003c00u)};
volatile uint32_t g_hgemm_b_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x40004000u)};
volatile uint32_t g_hgemm_c_row[8] __attribute__((aligned(32))) = {
BW_REP8(0x3f800000u)};
volatile uint32_t g_hgemm_a_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3c003c00u)};
volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x40004000u)};
volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
BW_REP4(0x3f800000u)};
}
#undef BW_REP2
#undef BW_REP4
#undef BW_REP8
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
asm volatile(
@@ -33,7 +31,7 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
"add x4, x2, x7\n\t"
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
"addi x7, x7, 32\n\t"
"addi x7, x7, 16\n\t"
"li x4, 1024\n\t"
"blt x7, x4, 1b\n\t"
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
@@ -67,9 +65,9 @@ extern "C" int wu_main() {
volatile uint32_t *smem_b =
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
for (uint32_t frag = 0; frag < 32u; ++frag) {
const uint32_t row = frag * 8u;
for (uint32_t i = 0; i < 8u; ++i) {
for (uint32_t frag = 0; frag < 64u; ++frag) {
const uint32_t row = frag * 4u;
for (uint32_t i = 0; i < 4u; ++i) {
smem_b[row + i] = g_hgemm_b_row[i];
}
}