diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp
index 8f8be348..1f461aa0 100644
--- a/tests/regression/sgemm_tcore/sgemm_impl.hpp
+++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp
@@ -343,13 +343,95 @@ template inline void initialize_accum_regs() {
   }
 }
 
+// Loads this thread's share of one accumulator sub-tile from `C` into the
+// floating-point registers f16-f31.  `C` is expected to be in N-major layout
+// with `dim_n` floats per row.
+//
+// The sub-tile origin is (WM * warp_row + TCM * wm_iter,
+// WN * warp_col + TCN * wn_iter); map_c() supplies this thread's offset from
+// it.  Eight floats are read: columns +0, +1, +4 and +5 of that row and of
+// the row two below it.
+//
+// wm_iter == 0 fills f16-f23, any other value fills f24-f31; wn_iter only
+// changes the address.
+// NOTE(review): this looks like it assumes WMITER <= 2 and WNITER == 1, as
+// later iterations would otherwise overwrite earlier ones -- confirm.
+__attribute__((always_inline)) inline void
+wmma_load_accum(const int thread_in_warp, const int warp_col,
+                const int warp_row, const int wn_iter, const int wm_iter,
+                const int dim_n, const float *C) {
+  asm volatile("wmma_load_accum_start_%=:" ::);
+
+  const int tid = thread_in_warp;
+
+  // these are [0, TCM/TCN)
+  int tid_row = 0;
+  int tid_col = 0;
+  map_c(tid, tid_row, tid_col);
+
+  const int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
+  const int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
+
+  // Both branches below read from the same addresses, so compute them once.
+  const uint8_t *addr =
+      reinterpret_cast<const uint8_t *>(&C[dim_n * local_row + local_col]);
+  const uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
+
+  // Every flw reads `C` through a plain register operand, so it carries a
+  // "memory" clobber; without it the compiler is free to move earlier stores
+  // to `C` past the load.  The destination registers are not listed as
+  // clobbers because the loaded values have to outlive this function.
+  // NOTE(review): presumably f16-f31 are kept away from the register
+  // allocator by the build flags -- confirm.
+  //
+  // @copypaste from wmma_store
+  // @perf: this likely causes a lot of gmem bank conflicts
+  if (wm_iter == 0) {
+    asm volatile("flw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+  } else {
+    asm volatile("flw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr)
+                 : "memory");
+    asm volatile("flw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+    asm volatile("flw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow)
+                 : "memory");
+  }
+
+  asm volatile("wmma_load_accum_finish_%=:" ::);
+}
+
 __attribute__((always_inline)) inline void
 wmma_store(const int thread_in_warp, const int warp_col,
            const int warp_row, const int wn_iter, const int wm_iter,
            const int dim_n, float *write_addr) {
   asm volatile ("wmma_store_start_%=:" :: );
 
-  int tid = thread_in_warp;
+  const int tid = thread_in_warp;
 
   // these are [0, TCM/TCN)
   int tid_row = 0;
@@ -560,17 +642,22 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
 // Do a single tile*tile matrix multiplication using the matrix data stored in
 // SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
+// With `load_accum`, the accumulators are first loaded from `local_c` (BN
+// floats per row); with `write_to_mem`, the result is stored to
+// `result_addr`.
 template
-__attribute__((always_inline)) inline void
-thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
-                              const uint32_t tid_in_threadblock,
-                              const uint32_t threads_per_threadblock,
-                              const uint32_t threadblocks_per_cluster,
-                              const uint32_t threadblock_id_in_cluster) {
+__attribute__((always_inline)) inline void thread_block_gemm_single_tile(
+    const T *local_a, const T *local_b, const T *local_c, T *result_addr,
+    const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
+    const uint32_t threadblocks_per_cluster,
+    const uint32_t threadblock_id_in_cluster) {
   // no double-buffering
   // FIXME: duplicated from thread_block_gemm
   const uint32_t threads_per_warpgroup = threads_per_threadblock;
@@ -581,6 +668,21 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
   const uint32_t warps_per_threadblock_per_core =
       NUM_WARPS / threadblocks_per_cluster;
 
+  // TODO: it would be useful to split this bit out into a function, so that
+  // preloading the accumulation tile can also be used for full GEMMs at the
+  // start of the K-loop.
+  if constexpr (load_accum) {
+#pragma GCC unroll
+    for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
+#pragma GCC unroll
+      for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
+        // FIXME: template parameter-ize the tile stride (hard-coded to BN)
+        wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
+                        local_c);
+      }
+    }
+  }
+
 #pragma GCC unroll 1
   for (int i = 0; i < BK_LOOP; i++) {
 #pragma GCC unroll 4
@@ -611,7 +713,7 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
     }
   }
 
-  if constexpr (write_to_smem) {
+  if constexpr (write_to_mem) {
     // need to protect smem reads in the earlier step from writes in below,
     // especially when the destination smem address overlaps with the input
     threadblock_barrier(threadblock_id_in_cluster,
@@ -622,7 +724,7 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
 #pragma GCC unroll
       for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
         wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
-                   local_c);
+                   result_addr);
       }
     }
   }
@@ -857,9 +959,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
       constexpr MemLayout layout_a =
           TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
       thread_block_gemm_single_tile(
+          /*load_accum=*/false,
+          /*write_to_mem=*/false>(
           local_a_consume, local_b_consume,
-          static_cast(nullptr) /*ignore*/, tid_in_threadblock,
+          static_cast(nullptr) /*ignore accum*/,
+          static_cast(nullptr) /*ignore result*/, tid_in_threadblock,
           threads_per_threadblock, threadblocks_per_cluster,
           threadblock_id_in_cluster);