From ced98a6ff45ade1879ee9755ef3f7968e32ae6a8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 3 Sep 2024 16:20:31 -0700 Subject: [PATCH] sgemm_impl: Refactor DMA layout remap logic into constexpr func --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 69 +++++++++++---------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index bcb4f13e..6674edd0 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -72,7 +72,7 @@ using float_type = float16_t; #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 0 +#define GEMMINI_DMA 1 #define GEMMINI_DMA_MN_MAJOR 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) @@ -200,6 +200,28 @@ inline void vx_wmma(const int dest_reg) { } } +// Remap logical row/col coordinate of a matrix element to a memory index that +// follows the 2-level block-row-major layout that Gemmini DMA uses +template +inline constexpr std::pair +remap_to_gemmini_dma_layout(const uint32_t logical_row, + const uint32_t logical_col) { + static_assert(DIM == 8, + "GEMMINI_DMA layout remapping code only written for DIM == 8"); + + if constexpr (use_dma) { + constexpr int dim_blocks_in_row = (dim_col / DIM); + const uint32_t row = + (logical_row / dim_blocks_in_row) * DIM + (logical_col / DIM); + const uint32_t col = + (logical_row % dim_blocks_in_row) * DIM + (logical_col % DIM); + return {row, col}; + } else { + // pass-through + return {logical_row, logical_col}; + } +} + // `local_k` is assumed to be multiple of TCK template (smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -356,20 +367,11 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const uint32_t smem_logical_row = local_k_adjusted + 0; const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col; - uint32_t smem_row; - uint32_t smem_col; - if constexpr (GEMMINI_DMA) { - // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level - // block-row-major layout - constexpr int dim_blocks_in_row = (smem_B_cols / DIM); - smem_row = - (smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM); - smem_col = - (smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM); - } else { - smem_row = smem_logical_row; - smem_col = smem_logical_col; - } + // if using Gemmini DMA, remap logical row/col to Gemmini's 2-level + // block-row-major layout + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(smem_logical_row, + smem_logical_col); const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -475,6 +477,7 @@ wmma_load_accum(const int thread_in_warp, const int warp_col, asm volatile("wmma_load_accum_finish_%=:" ::); } +// Write out the matrix data stored in RF to memory __attribute__((always_inline)) inline void wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, const int wn_iter, const int wm_iter, const int dim_n,