sgemm_impl: Rename dmem load function
This commit is contained in:
@@ -395,13 +395,12 @@ template <typename T,
|
|||||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||||
MemLayout smem_layout, // memory layout of the SMEM tile
|
MemLayout smem_layout, // memory layout of the SMEM tile
|
||||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||||
>
|
>
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
load_tile_to_smem(const uint32_t dim_col, const uint32_t mn_index,
|
||||||
const uint32_t k, const T *global_addr,
|
const uint32_t k, const T *global_addr,
|
||||||
volatile T *local_addr,
|
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
||||||
const uint32_t tid_in_threadblock) {
|
|
||||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||||
|
|
||||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||||
@@ -805,19 +804,17 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
#else
|
#else
|
||||||
// move A
|
// move A
|
||||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||||
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
||||||
tid_in_threadblock);
|
tid_in_threadblock);
|
||||||
} else {
|
} else {
|
||||||
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
|
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
|
||||||
BK>(dim_k, block_m, block_k * BK, A, local_a,
|
dim_k, block_m, block_k * BK, A, local_a, tid_in_threadblock);
|
||||||
tid_in_threadblock);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// move B
|
// move B
|
||||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN,
|
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||||
BK>(dim_n, block_n, block_k * BK, B, local_b,
|
dim_n, block_n, block_k * BK, B, local_b, tid_in_threadblock);
|
||||||
tid_in_threadblock);
|
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|||||||
Reference in New Issue
Block a user