From a967c262b144e98aed84f4de866b670f41e43f7b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 16:38:22 -0700 Subject: [PATCH] sgemm_impl: Add new block-row-major layout for DMA --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 45 ++++++++++++--------- 1 file changed, 27 insertions(+), 18 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index db8df789..e563f23c 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -70,10 +70,10 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == // To model the case where the A matrix is already stored column-major in GMEM, // set both to 0. #define TRANSPOSE_AT_PRODUCE 0 -#define TRANSPOSE_AT_CONSUME 1 +#define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 1 -#define GEMMINI_DMA_MN_MAJOR 0 +#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) @@ -101,6 +101,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == enum class MemLayout { MN_major, K_major, + block_row_major, // Gemmini DMA }; inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { @@ -253,13 +254,14 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 2 : 1); const int local_k_adjusted = local_k / packed_factor; - static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major) || - GEMMINI_DMA_MN_MAJOR, - "GEMMINI_DMA only supported for K-major A tile"); + static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) || + GEMMINI_DMA_FLEXIBLE_LAYOUT, + "wrong memory layout selected for DMA"); static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), "fp16 is not really tested for K-major A layout"); - if constexpr (layout == MemLayout::K_major) { + if constexpr (layout == MemLayout::K_major || + layout == MemLayout::block_row_major) { constexpr int smem_A_cols = leading_dim; // f8-f15 stores a single row of A @@ -269,8 +271,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -356,8 +359,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int thread_in_warp) { asm volatile ("wmma_load_b_start_%=:" :: ); - static_assert(layout == MemLayout::MN_major, - "only N-major layout for the B tile is supported"); + static_assert( + layout == MemLayout::MN_major || layout == MemLayout::block_row_major, + "only N-major or block-row-major layout are supported for the B tile"); const int tid = thread_in_warp; const int tg = tid / 4; @@ -379,8 +383,9 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level // block-row-major layout const auto [smem_row, smem_col] = - remap_to_gemmini_dma_layout(smem_logical_row, - smem_logical_col); + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -388,10 +393,10 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, smem_B)[smem_B_cols * smem_row + smem_col]); // f8-f15 stores a single column of B // threads read from different columns; no bank conflicts - if constexpr (GEMMINI_DMA) { - // for GEMMINI_DMA, moving rows for the next 7 elements in the same column - // is the same as moving DIM elements forward in the memory because of the - // block-row-major layout + if constexpr (layout == MemLayout::block_row_major) { + // for the block-row-major layout, moving rows for the next 7 elements in + // the same column is the same as moving DIM elements forward in the memory + // because of the block-row-major layout asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr)); @@ -1064,8 +1069,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } constexpr MemLayout layout_a = - TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major; - thread_block_gemm_single_tile(