sgemm_impl: Refactor DMA layout remap logic into constexpr func
This commit is contained in:
@@ -72,7 +72,7 @@ using float_type = float16_t;
|
|||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
|
|
||||||
#define GEMMINI_DMA 0
|
#define GEMMINI_DMA 1
|
||||||
#define GEMMINI_DMA_MN_MAJOR 1
|
#define GEMMINI_DMA_MN_MAJOR 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||||
@@ -200,6 +200,28 @@ inline void vx_wmma(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Remap logical row/col coordinate of a matrix element to a memory index that
|
||||||
|
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
||||||
|
template <bool use_dma, uint32_t dim_col>
|
||||||
|
inline constexpr std::pair<uint32_t, uint32_t>
|
||||||
|
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
||||||
|
const uint32_t logical_col) {
|
||||||
|
static_assert(DIM == 8,
|
||||||
|
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||||
|
|
||||||
|
if constexpr (use_dma) {
|
||||||
|
constexpr int dim_blocks_in_row = (dim_col / DIM);
|
||||||
|
const uint32_t row =
|
||||||
|
(logical_row / dim_blocks_in_row) * DIM + (logical_col / DIM);
|
||||||
|
const uint32_t col =
|
||||||
|
(logical_row % dim_blocks_in_row) * DIM + (logical_col % DIM);
|
||||||
|
return {row, col};
|
||||||
|
} else {
|
||||||
|
// pass-through
|
||||||
|
return {logical_row, logical_col};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// `local_k` is assumed to be multiple of TCK
|
// `local_k` is assumed to be multiple of TCK
|
||||||
template <typename T, MemLayout layout,
|
template <typename T, MemLayout layout,
|
||||||
uint32_t leading_dim // stride in sizeof(T) between consecutive
|
uint32_t leading_dim // stride in sizeof(T) between consecutive
|
||||||
@@ -242,24 +264,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
|
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
|
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
|
||||||
const uint32_t smem_logical_col = local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
|
const uint32_t smem_logical_col =
|
||||||
uint32_t smem_row;
|
local_k_adjusted + 0; /* FIXME: fp16 adjust necessary? */
|
||||||
uint32_t smem_col;
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
if constexpr (GEMMINI_DMA) {
|
// block-row-major layout
|
||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
const auto [smem_row, smem_col] =
|
||||||
// block-row-major layout
|
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_A_cols>(smem_logical_row,
|
||||||
static_assert(
|
smem_logical_col);
|
||||||
DIM == 8,
|
|
||||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
|
||||||
constexpr int dim_blocks_in_row = (smem_A_cols / DIM);
|
|
||||||
smem_row = (smem_logical_row / dim_blocks_in_row) * DIM +
|
|
||||||
(smem_logical_col / DIM);
|
|
||||||
smem_col = (smem_logical_row % dim_blocks_in_row) * DIM +
|
|
||||||
(smem_logical_col % DIM);
|
|
||||||
} else {
|
|
||||||
smem_row = smem_logical_row;
|
|
||||||
smem_col = smem_logical_col;
|
|
||||||
}
|
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -356,20 +367,11 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
|
|
||||||
const uint32_t smem_logical_row = local_k_adjusted + 0;
|
const uint32_t smem_logical_row = local_k_adjusted + 0;
|
||||||
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
|
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
|
||||||
uint32_t smem_row;
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
uint32_t smem_col;
|
// block-row-major layout
|
||||||
if constexpr (GEMMINI_DMA) {
|
const auto [smem_row, smem_col] =
|
||||||
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
remap_to_gemmini_dma_layout<GEMMINI_DMA, smem_B_cols>(smem_logical_row,
|
||||||
// block-row-major layout
|
smem_logical_col);
|
||||||
constexpr int dim_blocks_in_row = (smem_B_cols / DIM);
|
|
||||||
smem_row =
|
|
||||||
(smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM);
|
|
||||||
smem_col =
|
|
||||||
(smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM);
|
|
||||||
} else {
|
|
||||||
smem_row = smem_logical_row;
|
|
||||||
smem_col = smem_logical_col;
|
|
||||||
}
|
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
@@ -475,6 +477,7 @@ wmma_load_accum(const int thread_in_warp, const int warp_col,
|
|||||||
asm volatile("wmma_load_accum_finish_%=:" ::);
|
asm volatile("wmma_load_accum_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Write out the matrix data stored in RF to memory
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||||
const int wn_iter, const int wm_iter, const int dim_n,
|
const int wn_iter, const int wm_iter, const int dim_n,
|
||||||
|
|||||||
Reference in New Issue
Block a user