sgemm_impl: Split out smem addr gen to functions
so that the addr gen code can also be used for wgmma.
This commit is contained in:
@@ -276,10 +276,11 @@ template <typename T, MemLayout layout,
|
||||
// becomes the stride between the 1st M-dim
|
||||
// vector and the 2nd M-dim vector.
|
||||
>
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_a_start_%=:" :: );
|
||||
inline volatile const uint8_t *
|
||||
generate_smem_addr_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("generate_smem_addr_a_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
@@ -316,10 +317,36 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
smem_A_cols>(smem_logical_row,
|
||||
smem_logical_col);
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
return reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_A)[smem_A_cols * smem_row + smem_col]);
|
||||
} else if constexpr (layout == MemLayout::MN_major) {
|
||||
constexpr int smem_AS_cols = leading_dim;
|
||||
|
||||
return reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
|
||||
(WM * warp_row + TCM * wm_iter) + row]);
|
||||
} else {
|
||||
static_assert(layout ==
|
||||
MemLayout::K_major /* fake cond that is always false */,
|
||||
"unsupported memory layout");
|
||||
}
|
||||
|
||||
asm volatile ("generate_smem_addr_a_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T, MemLayout layout, uint32_t leading_dim>
|
||||
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_a_start_%=:" :: );
|
||||
|
||||
if constexpr (layout == MemLayout::K_major ||
|
||||
layout == MemLayout::block_row_major) {
|
||||
const volatile uint8_t *smem_addr =
|
||||
generate_smem_addr_a<T, layout, leading_dim>(smem_A, local_k, warp_row,
|
||||
wm_iter, thread_in_warp);
|
||||
// step to the next column
|
||||
// @perf: bank conflicts; threads read from different rows
|
||||
// below is correct for GEMMINI_DMA; smem_col is always a multiple of 8,
|
||||
@@ -336,11 +363,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
} else if constexpr (layout == MemLayout::MN_major) {
|
||||
constexpr int smem_AS_cols = leading_dim;
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
|
||||
(WM * warp_row + TCM * wm_iter) + row]);
|
||||
const volatile uint8_t *smem_addr =
|
||||
generate_smem_addr_a<T, layout, leading_dim>(smem_A, local_k, warp_row,
|
||||
wm_iter, thread_in_warp);
|
||||
// f8-f15 stores a single row of A
|
||||
// threads read from different columns; no bank conflicts
|
||||
// asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -393,12 +418,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be multiple of TCK
|
||||
template <typename T, MemLayout layout, uint32_t tile_dim_m,
|
||||
uint32_t tile_dim_n, uint32_t tile_dim_k>
|
||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||
template <typename T, MemLayout layout, uint32_t leading_dim,
|
||||
uint32_t tile_dim_k>
|
||||
inline volatile const uint8_t *
|
||||
generate_smem_addr_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("generate_smem_addr_b_start_%=:" :: );
|
||||
|
||||
static_assert(
|
||||
layout == MemLayout::MN_major || layout == MemLayout::block_row_major,
|
||||
@@ -417,7 +443,7 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
// B is stored N-major in smem
|
||||
constexpr int smem_B_cols = tile_dim_n;
|
||||
constexpr int smem_B_cols = leading_dim;
|
||||
|
||||
const uint32_t smem_logical_row = local_k_adjusted + 0;
|
||||
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
|
||||
@@ -428,10 +454,27 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
smem_B_cols>(smem_logical_row,
|
||||
smem_logical_col);
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
return reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_B)[smem_B_cols * smem_row + smem_col]);
|
||||
|
||||
asm volatile ("generate_smem_addr_b_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T, MemLayout layout, uint32_t leading_dim,
|
||||
uint32_t tile_dim_k>
|
||||
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("wmma_load_b_start_%=:" :: );
|
||||
|
||||
// B is stored N-major in smem
|
||||
constexpr int smem_B_cols = leading_dim;
|
||||
|
||||
const volatile uint8_t *smem_addr =
|
||||
generate_smem_addr_b<T, layout, leading_dim, tile_dim_k>(
|
||||
smem_B, local_k, warp_col, wn_iter, thread_in_warp);
|
||||
|
||||
// f8-f15 stores a single column of B
|
||||
// threads read from different columns; no bank conflicts
|
||||
if constexpr (layout == MemLayout::block_row_major) {
|
||||
@@ -849,7 +892,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
// SMEM -> RF
|
||||
static_assert(leading_dim_b == 0,
|
||||
"leading_dim for wmma_load_b is not implemented yet");
|
||||
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
|
||||
wmma_load_b<T, layout_b, tile_dim_n,
|
||||
tile_dim_k /*leading_dim_b is TODO */>(
|
||||
local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||
#pragma GCC unroll 2
|
||||
|
||||
Reference in New Issue
Block a user