From 18e3653d31503598b435f98122e0613bca22f8c7 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 3 Jun 2024 21:10:42 -0700 Subject: [PATCH] sgemm_tcore: Increase RF data reuse for WMITER/WNITER ... by splitting vx_wmma_load to vx_wmma_load_{a,b} and pulling it out of the innermost loop. TODO: there's some duplicate address compute being done in the both functions. --- tests/regression/sgemm_tcore/kernel.cpp | 35 ++++++++++++++++++++----- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 4ac80775..69451813 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -142,12 +142,12 @@ inline void vx_wmma(const int dest_reg) { } // `local_k` is assumed to be multiple of TCK -inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k, - const int warp_col, const int warp_row, const int wn_iter, - const int wm_iter, const int thread_in_warp) { +inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, + const int warp_row, const int wm_iter, const int thread_in_warp) { const int tid = thread_in_warp; const int tg = tid / 4; + // TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b int row = 0; int col = 0; map_operand(tid, row, col); @@ -188,6 +188,25 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i // asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row])); // } } +} + +// `local_k` is assumed to be multiple of TCK +inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, + const int warp_col, const int wn_iter, + const int thread_in_warp) { + const int tid = thread_in_warp; + const int tg = tid / 4; + + int row = 0; + int col = 0; + map_operand(tid, row, col); + + constexpr int smem_A_rows = BM; + constexpr int smem_A_cols = BK; + constexpr int smem_AS_rows = BK; + constexpr int smem_AS_cols = BM; + constexpr int smem_B_rows = BK; + constexpr int smem_B_cols = BN; // f8-f15 stores a single column of B asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col])); @@ -401,9 +420,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS #pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); + // vx_wmma_load_b(local_b, 0, 0, 0, tid_in_warp); #pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #if TC_SINGLE_WARP if (warp_in_threadblock == 0) { #endif @@ -419,8 +440,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // asm volatile("addi a0, a0, 0"); // } // SMEM -> RF - vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row, - wn_iter, wm_iter, tid_in_warp); + vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); + // vx_wmma_load_a(local_a, 0, 0, 0, tid_in_warp); // compute vx_wmma(wm_iter); #if TC_SINGLE_WARP