From ee0295cbefa91092b690940ad050462f96fc5378 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 29 Aug 2024 21:43:57 -0700 Subject: [PATCH] sgemm_impl: Accept threads_per_threadblock in load_tile_to_smem Needed for warp-specialized kernels. --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 36 ++++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d744a8c1..e77fea35 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -446,16 +446,20 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) // `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or // MN-major, M/N. template __attribute__((always_inline)) inline void load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, const uint32_t k_index, const T *global_addr, volatile T *local_addr, const uint32_t tid_in_threadblock) { - asm volatile("global_dmem_load_start_new_%=:" ::); + asm volatile("load_tile_to_smem_start_%=:" ::); // In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do // data movement at the fp32 granularity. The tensor core hardware assumes @@ -486,9 +490,6 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, const uint32_t local_col_smem = transposed_write ? local_row_gmem : local_col_gmem; - // FIXME: don't hardcode this here - constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; - const uint32_t global_row_mn_major = tile_dim_k_packed * k_index + local_row_gmem; const uint32_t global_col_mn_major = gmem_dim_col * mn_index + local_col_gmem; const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem; @@ -506,6 +507,7 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, smem_dim_col * local_row_smem + local_col_smem; constexpr uint32_t row_stride = threads_per_threadblock / gmem_dim_col; + static_assert(row_stride * 8 <= gmem_dim_row, "manual loop unrolling condition not met; tile row dimension " "is too shallow"); @@ -598,7 +600,7 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, } } - asm volatile("global_dmem_load_finish_new_%=:" ::); + asm volatile("load_tile_to_smem_finish_new_%=:" ::); } // Do a single tile*tile matrix multiplication using the matrix data stored in @@ -677,7 +679,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( if constexpr (write_to_mem) { // need to protect smem reads in the earlier step from writes in below, - // especially when the destination smem address overlaps with the input + // especially when the destination address overlaps with the source address threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); @@ -877,17 +879,19 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #else // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { - load_tile_to_smem(dim_m, block_m, block_k, A, local_a, - tid_in_threadblock); + load_tile_to_smem( + dim_m, block_m, block_k, A, local_a, tid_in_threadblock); } else { - load_tile_to_smem( + load_tile_to_smem( dim_k, block_m, block_k, A, local_a, tid_in_threadblock); } // move B - load_tile_to_smem( - dim_n, block_n, block_k, B, local_b, tid_in_threadblock); + load_tile_to_smem(dim_n, block_n, block_k, B, + local_b, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);