From 526c2bd334a5f1fa10c64377fb1745952c77a6eb Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 20 Aug 2024 17:46:35 -0700 Subject: [PATCH] sgemm_impl: load_tile: accept k_index for consistency + fix gmem addr gen --- tests/regression/sgemm_tcore/sgemm_impl.hpp | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 1f461aa0..f26697e1 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -453,7 +453,7 @@ template __attribute__((always_inline)) inline void load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, - const uint32_t k, const T *global_addr, + const uint32_t k_index, const T *global_addr, volatile T *local_addr, const uint32_t tid_in_threadblock) { asm volatile("global_dmem_load_start_new_%=:" ::); @@ -469,15 +469,11 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, (gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed; constexpr uint32_t gmem_dim_col = (gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn; - constexpr uint32_t smem_dim_row = - (smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed; constexpr uint32_t smem_dim_col = (smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn; const uint32_t dim_major_ = (gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major; - // FIXME: unsure about this - const uint32_t k_ = k / packed_factor; // threads in the threadblock always do contiguous accesses in the gmem const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col; @@ -493,10 +489,10 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index, // FIXME: don't hardcode this here constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD; - const uint32_t global_row_mn_major = k_ + local_row_gmem; - const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem; + const uint32_t global_row_mn_major = tile_dim_k_packed * k_index + local_row_gmem; + const uint32_t global_col_mn_major = gmem_dim_col * mn_index + local_col_gmem; const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem; - const uint32_t global_col_k_major = k_ + local_col_gmem; + const uint32_t global_col_k_major = tile_dim_k_packed * k_index + local_col_gmem; const uint32_t global_row = (gmem_layout == MemLayout::K_major) ? global_row_k_major : global_row_mn_major; @@ -879,16 +875,16 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // move A if constexpr (!TRANSPOSE_AT_PRODUCE) { load_tile_to_smem(dim_m, block_m, block_k * BK, A, local_a, + BK>(dim_m, block_m, block_k, A, local_a, tid_in_threadblock); } else { load_tile_to_smem( - dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock); + dim_k, block_m, block_k, A, local_a, tid_in_threadblock); } // move B load_tile_to_smem( - dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock); + dim_n, block_n, block_k, B, local_b, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);