diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 14ba1760..0d18fcee 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -235,7 +235,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, // int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; - // @perf: bank conflicts // f8-f15 stores a single row of A const volatile uint8_t *smem_addr; smem_addr = reinterpret_cast( @@ -243,7 +242,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k, smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k /* FIXME: adjust for fp16? */]); // step to the next column - // threads read from different rows; bank conflicts + // @perf: bank conflicts; threads read from different rows asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr)); asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr)); @@ -408,6 +407,7 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) // TODO: reduce args by passing leading A/B dimensions template +__attribute__((always_inline)) inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const T *A, const T *B, volatile T *local_a, volatile T *local_b, @@ -646,10 +646,66 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u asm volatile ("global_dmem_load_finish_%=:" :: ); } -template +// Do a single tile*tile matrix multiplication using the matrix data stored in +// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope. +template +__attribute__((always_inline)) inline void +thread_block_gemm_single_tile(const T *local_a, const T *local_b, + const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock) { + // no double-buffering + // FIXME: duplicated from thread_block_gemm + const uint32_t threads_per_warpgroup = threads_per_threadblock; + const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS; + const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN); + const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN); + const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; + +#pragma GCC unroll 1 + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 4 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { +#pragma GCC unroll 2 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + // SMEM -> RF + vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); +#pragma GCC unroll 2 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + // SMEM -> RF + vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); + // perform mma + vx_wmma(wm_iter); + } + } + } + } + + if constexpr (GEMMINI_DMA) { + // Call gemmini fence at the end of the loop to overlap dma & wmma. + // Usually, by this time, dma has finished the copy so that this + // becomes a no-op. + if (tid_in_threadblock == 0) { + gemmini_fence(); + } + } +} + +template inline void thread_block_gemm(const T *A, const T *B, float *C, - const uint32_t dim_m, - const uint32_t dim_n, + const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -672,13 +728,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; - volatile T *local_a = reinterpret_cast(sharedmem_per_threadblock); - constexpr size_t local_a_elems = (BM * BK); - volatile T *local_a_buf = local_a + local_a_elems; - - volatile T *local_b = local_a_buf + local_a_elems; - constexpr size_t local_b_elems = (BK * BN); - volatile T *local_b_buf = local_a_buf + local_b_elems; + volatile T *local_a = + reinterpret_cast(sharedmem_per_threadblock + smem_a_offset); + volatile T *local_a_buf = + reinterpret_cast(sharedmem_per_threadblock + smem_a_dbuf_offset); + volatile T *local_b = + reinterpret_cast(sharedmem_per_threadblock + smem_b_offset); + volatile T *local_b_buf = + reinterpret_cast(sharedmem_per_threadblock + smem_b_dbuf_offset); constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, @@ -849,34 +906,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); - local_a_consume = local_a + (block_k & 1) * (local_a_elems); - local_b_consume = local_b + (block_k & 1) * (local_b_elems); + local_a_consume = local_a + (block_k & 1) * (BM * BK); + local_b_consume = local_b + (block_k & 1) * (BK * BN); } else { // no double-buffering without DMA local_a_consume = local_a; local_b_consume = local_b; } -#pragma GCC unroll 1 - for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { -#pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - // SMEM -> RF - vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, - tid_in_warp); -#pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { - // SMEM -> RF - vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, - tid_in_warp); - // perform mma - vx_wmma(wm_iter); - } - } - } - } + thread_block_gemm_single_tile(local_a_consume, local_b_consume, + tid_in_threadblock, + threads_per_threadblock); if constexpr (GEMMINI_DMA) { // Call gemmini fence at the end of the loop to overlap dma & wmma.