sgemm_impl: load_tile: accept k_index for consistency + fix gmem addr gen

This commit is contained in:
Hansung Kim
2024-08-20 17:46:35 -07:00
parent 60aec1de8d
commit 526c2bd334

View File

@@ -453,7 +453,7 @@ template <typename T,
>
__attribute__((always_inline)) inline void
load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
const uint32_t k, const T *global_addr,
const uint32_t k_index, const T *global_addr,
volatile T *local_addr, const uint32_t tid_in_threadblock) {
asm volatile("global_dmem_load_start_new_%=:" ::);
@@ -469,15 +469,11 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
(gmem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
constexpr uint32_t gmem_dim_col =
(gmem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
constexpr uint32_t smem_dim_row =
(smem_layout == MemLayout::K_major) ? tile_dim_mn : tile_dim_k_packed;
constexpr uint32_t smem_dim_col =
(smem_layout == MemLayout::K_major) ? tile_dim_k_packed : tile_dim_mn;
const uint32_t dim_major_ =
(gmem_layout == MemLayout::K_major) ? dim_major / packed_factor : dim_major;
// FIXME: unsure about this
const uint32_t k_ = k / packed_factor;
// threads in the threadblock always do contiguous accesses in the gmem
const uint32_t local_row_gmem = tid_in_threadblock / gmem_dim_col;
@@ -493,10 +489,10 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
// FIXME: don't hardcode this here
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
const uint32_t global_row_mn_major = k_ + local_row_gmem;
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
const uint32_t global_row_mn_major = tile_dim_k_packed * k_index + local_row_gmem;
const uint32_t global_col_mn_major = gmem_dim_col * mn_index + local_col_gmem;
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
const uint32_t global_col_k_major = k_ + local_col_gmem;
const uint32_t global_col_k_major = tile_dim_k_packed * k_index + local_col_gmem;
const uint32_t global_row = (gmem_layout == MemLayout::K_major)
? global_row_k_major
: global_row_mn_major;
@@ -879,16 +875,16 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
BK>(dim_m, block_m, block_k * BK, A, local_a,
BK>(dim_m, block_m, block_k, A, local_a,
tid_in_threadblock);
} else {
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock);
dim_k, block_m, block_k, A, local_a, tid_in_threadblock);
}
// move B
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock);
dim_n, block_n, block_k, B, local_b, tid_in_threadblock);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);