Use 16-byte fragments in wu_arch_hgemm
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
# wu_arch_hgemm
|
||||
|
||||
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration.
|
||||
Tensor-warp HGEMM smoke test for the Wu split scalar/tensor warp configuration
|
||||
with the 4-lane Blackwell tensor-core path.
|
||||
|
||||
Scalar warp 0 initializes the shared-memory B operand, spawns only the tensor
|
||||
warp mask, waits for tensor warps `NUM_SCALAR_WARPS..NUM_WARPS-1`, and reports
|
||||
completion through `tohost`. Tensor warps execute the Blackwell custom HGEMM
|
||||
instruction sequence and then stop themselves.
|
||||
instruction sequence using 16-byte fragments and then stop themselves.
|
||||
|
||||
@@ -5,20 +5,18 @@
|
||||
|
||||
#define BW_REP2(x) x, x
|
||||
#define BW_REP4(x) BW_REP2(x), BW_REP2(x)
|
||||
#define BW_REP8(x) BW_REP4(x), BW_REP4(x)
|
||||
|
||||
extern "C" {
|
||||
volatile uint32_t g_hgemm_a_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_row[8] __attribute__((aligned(32))) = {
|
||||
BW_REP8(0x3f800000u)};
|
||||
volatile uint32_t g_hgemm_a_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3c003c00u)};
|
||||
volatile uint32_t g_hgemm_b_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x40004000u)};
|
||||
volatile uint32_t g_hgemm_c_row[4] __attribute__((aligned(16))) = {
|
||||
BW_REP4(0x3f800000u)};
|
||||
}
|
||||
|
||||
#undef BW_REP2
|
||||
#undef BW_REP4
|
||||
#undef BW_REP8
|
||||
|
||||
extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
asm volatile(
|
||||
@@ -33,7 +31,7 @@ extern "C" void __attribute__((naked, noinline, used)) tensor_hgemm_worker() {
|
||||
".insn r %[custom3], 2, 0, x0, x4, x6\n\t"
|
||||
"add x4, x2, x7\n\t"
|
||||
".insn r %[custom3], 2, 0, x0, x4, x3\n\t"
|
||||
"addi x7, x7, 32\n\t"
|
||||
"addi x7, x7, 16\n\t"
|
||||
"li x4, 1024\n\t"
|
||||
"blt x7, x4, 1b\n\t"
|
||||
".insn r %[custom3], 3, 0, x0, x0, x0\n\t"
|
||||
@@ -67,9 +65,9 @@ extern "C" int wu_main() {
|
||||
|
||||
volatile uint32_t *smem_b =
|
||||
reinterpret_cast<volatile uint32_t *>(DEV_SMEM_START_ADDR);
|
||||
for (uint32_t frag = 0; frag < 32u; ++frag) {
|
||||
const uint32_t row = frag * 8u;
|
||||
for (uint32_t i = 0; i < 8u; ++i) {
|
||||
for (uint32_t frag = 0; frag < 64u; ++frag) {
|
||||
const uint32_t row = frag * 4u;
|
||||
for (uint32_t i = 0; i < 4u; ++i) {
|
||||
smem_b[row + i] = g_hgemm_b_row[i];
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user