diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp
index 13226a76..5bd694dd 100644
--- a/tests/regression/sgemm_tcore/sgemm_impl.hpp
+++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp
@@ -104,7 +104,13 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
 #define TRANSPOSE_AT_PRODUCE 0
 #define TRANSPOSE_AT_CONSUME 0
 
-#define GEMMINI_DMA 1
+// If 1, wmma_store() ignores the register <-> matrix fragment mapping scheme
+// and instead performs fast, coalesced GMEM writes when moving results out.
+// This does not necessarily break correctness; it means that the final result
+// matrix is stored in a swizzled form in global memory.
+#define WMMA_STORE_FAST 1
+
+#define GEMMINI_DMA 0
 #define GEMMINI_DMA_FAST 1
 #if SMEM_SIZE == 0x4000
 #define SMEM_ADDR_Q0 ((float * const) 0xff000000)
@@ -213,10 +219,8 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
   col += ((tid % 4) / 2) * 2;
 }
 
-inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) {
+// Coalesced C mapping for 8 lanes: lane `tid` maps to row 0, column `tid`.
+inline constexpr void map_c_8lanes_coalesced(const int tid, int &row, int &col) {
-  const int tg = tid / 2;
-
-  // FIXME wrong!!!
   row = 0;
   col = tid;
 }
@@ -225,8 +229,8 @@ inline constexpr void map_c(const int tid, int &row, int &col) {
   if constexpr (NUM_THREADS == 32) {
     map_c_32lanes(tid, row, col);
   } else if constexpr (NUM_THREADS == 8) {
-    if constexpr (TENSOR_HOPPER) {
-      map_c_8lanes_hopper(tid, row, col);
+    if constexpr (TENSOR_HOPPER || WMMA_STORE_FAST) {
+      map_c_8lanes_coalesced(tid, row, col);
     } else {
       map_c_8lanes(tid, row, col);
     }
@@ -664,26 +668,54 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
     volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
         &write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
     volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
-    asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
+    if constexpr (!WMMA_STORE_FAST) {
+      // Layout-preserving store: columns {0,1,4,5} of rows local_row and
+      // local_row + 2, relative to (local_row, local_col).
+      asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
+    } else {
+      // Fast store: the eight accumulators go to addr at a stride of WN
+      // floats, ignoring the register <-> fragment mapping.
+      asm volatile("fsw f16, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f17, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f18, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f19, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f20, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f21, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f22, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f23, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
+    }
   } else {
     volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
         &write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
     volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
-    asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
-    asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
-    asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
+    if constexpr (!WMMA_STORE_FAST) {
+      // Layout-preserving store; same offsets as the f16-f23 case.
+      asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
+      asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
+      asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
+    } else {
+      // Fast store; same WN-float stride as the f16-f23 case.
+      asm volatile("fsw f24, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f25, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f26, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f27, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f28, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f29, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f30, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
+      asm volatile("fsw f31, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
+    }
   }
 
   asm volatile ("wmma_store_finish_%=:" :: );