sgemm_impl: Rename to wmma

This commit is contained in:
Hansung Kim
2024-08-18 16:21:22 -07:00
parent b978bf8757
commit b44b202a21
2 changed files with 31 additions and 32 deletions

View File

@@ -205,15 +205,15 @@ inline void vx_wmma(const int dest_reg) {
// `local_k` is assumed to be multiple of TCK
template <typename T>
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
inline void wmma_load_a(volatile const T *smem_A, const int local_k,
const int warp_row, const int wm_iter,
const int thread_in_warp) {
asm volatile ("vx_wmma_load_a_start_%=:" :: );
asm volatile ("wmma_load_a_start_%=:" :: );
const int tid = thread_in_warp;
const int tg = tid / 4;
// @perf: this is duplicately computed in vx_wmma_load_a and vx_wmma_load_b
// @perf: this is duplicately computed in wmma_load_a and wmma_load_b
int row = 0;
int col = 0;
map_operand(tid, row, col);
@@ -273,15 +273,15 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
}
asm volatile ("vx_wmma_load_a_finish_%=:" :: );
asm volatile ("wmma_load_a_finish_%=:" :: );
}
// `local_k` is assumed to be multiple of TCK
template <typename T>
inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
inline void wmma_load_b(const volatile T *smem_B, const int local_k,
const int warp_col, const int wn_iter,
const int thread_in_warp) {
asm volatile ("vx_wmma_load_b_start_%=:" :: );
asm volatile ("wmma_load_b_start_%=:" :: );
const int tid = thread_in_warp;
const int tg = tid / 4;
@@ -290,7 +290,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
int col = 0;
map_operand(tid, row, col);
// see comment in vx_wmma_load_a
// see comment in wmma_load_a
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr int BK_adjusted = BN / packed_factor;
constexpr int BN_adjusted = BN / packed_factor;
@@ -316,7 +316,7 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
asm volatile ("vx_wmma_load_b_finish_%=:" :: );
asm volatile ("wmma_load_b_finish_%=:" :: );
}
inline void initialize_C(const int dest_reg) {
@@ -659,11 +659,11 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
vx_wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
wmma_load_b<T>(local_b, local_k, warp_col, wn_iter, tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
vx_wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
wmma_load_a<T>(local_a, local_k, warp_row, wm_iter, tid_in_warp);
// perform mma
vx_wmma(wm_iter);
}