sgemm_impl: Implement fast coalesced wmma_store

Enables a fairer comparison between core-coupled tensor core to Hopper
tensor core, where the latter benefits from coalesced full-throughput
moveout to GMEM because it does not use the 1x2 interleaved register
mapping.  This means the result matrix will be stored swizzled in the
GMEM, without breaking correctness.
This commit is contained in:
Hansung Kim
2024-10-29 22:34:22 -07:00
parent 6b39a6fe70
commit 21b6655c10

View File

@@ -104,7 +104,13 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
#define GEMMINI_DMA 1
// if 1, wmma_store() will not respect the register <-> matrix fragment mapping
// scheme and instead do a fast coalesced GMEM writes for move out. This
// doesn't necessarily mean breaking correctness; it means that the final
// result matrix will be stored in a swizzled form in the global memory.
#define WMMA_STORE_FAST 1
#define GEMMINI_DMA 0
#define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
@@ -213,10 +219,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
col += ((tid % 4) / 2) * 2;
}
inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) {
inline constexpr void map_c_8lanes_coalesced(const int tid, int &row, int &col) {
const int tg = tid / 2;
// FIXME wrong!!!
row = 0;
col = tid;
}
@@ -225,8 +230,8 @@ inline constexpr void map_c(const int tid, int &row, int &col) {
if constexpr (NUM_THREADS == 32) {
map_c_32lanes(tid, row, col);
} else if constexpr (NUM_THREADS == 8) {
if constexpr (TENSOR_HOPPER) {
map_c_8lanes_hopper(tid, row, col);
if constexpr (TENSOR_HOPPER || WMMA_STORE_FAST) {
map_c_8lanes_coalesced(tid, row, col);
} else {
map_c_8lanes(tid, row, col);
}
@@ -664,26 +669,48 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
if constexpr (!WMMA_STORE_FAST) {
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
} else {
asm volatile("fsw f16, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f19, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f23, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
}
} else {
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
if constexpr (!WMMA_STORE_FAST) {
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(2 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f27, %0(%1)" ::"i"(3 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(6 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f31, %0(%1)" ::"i"(7 * sizeof(float)), "r"(addr_tworow));
} else {
asm volatile("fsw f24, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f27, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f31, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
}
}
asm volatile ("wmma_store_finish_%=:" :: );