diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 41cc5a17..e018fb25 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -276,10 +276,11 @@ template -inline void wmma_load_a(volatile const T *smem_A, const int local_k, - const int warp_row, const int wm_iter, - const int thread_in_warp) { - asm volatile ("wmma_load_a_start_%=:" :: ); +inline volatile const uint8_t * +generate_smem_addr_a(volatile const T *smem_A, const int local_k, + const int warp_row, const int wm_iter, + const int thread_in_warp) { + asm volatile ("generate_smem_addr_a_start_%=:" :: ); const int tid = thread_in_warp; const int tg = tid / 4; @@ -316,10 +317,36 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, smem_A_cols>(smem_logical_row, smem_logical_col); - const volatile uint8_t *smem_addr; - smem_addr = reinterpret_cast( + return reinterpret_cast( &reinterpret_cast( smem_A)[smem_A_cols * smem_row + smem_col]); + } else if constexpr (layout == MemLayout::MN_major) { + constexpr int smem_AS_cols = leading_dim; + + return reinterpret_cast( + &reinterpret_cast( + smem_A)[((local_k_adjusted + 0) * smem_AS_cols) + + (WM * warp_row + TCM * wm_iter) + row]); + } else { + static_assert(layout == + MemLayout::K_major /* fake cond that is always false */, + "unsupported memory layout"); + } + + asm volatile ("generate_smem_addr_a_finish_%=:" :: ); +} + +template +inline void wmma_load_a(volatile const T *smem_A, const int local_k, + const int warp_row, const int wm_iter, + const int thread_in_warp) { + asm volatile ("wmma_load_a_start_%=:" :: ); + + if constexpr (layout == MemLayout::K_major || + layout == MemLayout::block_row_major) { + const volatile uint8_t *smem_addr = + generate_smem_addr_a(smem_A, local_k, warp_row, + wm_iter, thread_in_warp); // step to the next column // @perf: bank conflicts; threads read from different rows // below is correct for GEMMINI_DMA; smem_col is always a multiple of 8, @@ -336,11 +363,9 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, } else if constexpr (layout == MemLayout::MN_major) { constexpr int smem_AS_cols = leading_dim; - const volatile uint8_t *smem_addr; - smem_addr = reinterpret_cast( - &reinterpret_cast( - smem_A)[((local_k_adjusted + 0) * smem_AS_cols) + - (WM * warp_row + TCM * wm_iter) + row]); + const volatile uint8_t *smem_addr = + generate_smem_addr_a(smem_A, local_k, warp_row, + wm_iter, thread_in_warp); // f8-f15 stores a single row of A // threads read from different columns; no bank conflicts // asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr)); @@ -393,12 +418,13 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, } // `local_k` is assumed to be multiple of TCK -template -inline void wmma_load_b(const volatile T *smem_B, const int local_k, - const int warp_col, const int wn_iter, - const int thread_in_warp) { - asm volatile ("wmma_load_b_start_%=:" :: ); +template +inline volatile const uint8_t * +generate_smem_addr_b(const volatile T *smem_B, const int local_k, + const int warp_col, const int wn_iter, + const int thread_in_warp) { + asm volatile ("generate_smem_addr_b_start_%=:" :: ); static_assert( layout == MemLayout::MN_major || layout == MemLayout::block_row_major, @@ -417,7 +443,7 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, const int local_k_adjusted = local_k / packed_factor; // B is stored N-major in smem - constexpr int smem_B_cols = tile_dim_n; + constexpr int smem_B_cols = leading_dim; const uint32_t smem_logical_row = local_k_adjusted + 0; const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col; @@ -428,10 +454,27 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, smem_B_cols>(smem_logical_row, smem_logical_col); - const volatile uint8_t *smem_addr; - smem_addr = reinterpret_cast( + return reinterpret_cast( &reinterpret_cast( smem_B)[smem_B_cols * smem_row + smem_col]); + + asm volatile ("generate_smem_addr_b_finish_%=:" :: ); +} + +template +inline void wmma_load_b(const volatile T *smem_B, const int local_k, + const int warp_col, const int wn_iter, + const int thread_in_warp) { + asm volatile ("wmma_load_b_start_%=:" :: ); + + // B is stored N-major in smem + constexpr int smem_B_cols = leading_dim; + + const volatile uint8_t *smem_addr = + generate_smem_addr_b( + smem_B, local_k, warp_col, wn_iter, thread_in_warp); + // f8-f15 stores a single column of B // threads read from different columns; no bank conflicts if constexpr (layout == MemLayout::block_row_major) { @@ -849,7 +892,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( // SMEM -> RF static_assert(leading_dim_b == 0, "leading_dim for wmma_load_b is not implemented yet"); - wmma_load_b( local_b, local_k, warp_col, wn_iter, tid_in_warp); #pragma GCC unroll 2