sgemm_impl: Accept threads_per_threadblock in load_tile_to_smem
Needed for warp-specialized kernels.
This commit is contained in:
@@ -446,16 +446,20 @@ inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count)
|
||||
// `dim_major`: major dimension of the matrix in GMEM, e.g. if K-major, K; or
|
||||
// MN-major, M/N.
|
||||
template <typename T,
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the GMEM tile
|
||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the GMEM tile
|
||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||
uint32_t tile_dim_k, // column dimension of the SMEM tile
|
||||
uint32_t threads_per_threadblock // this needs to be
|
||||
// compile-time in order
|
||||
// to do inline assembly with
|
||||
// constant memory offsets
|
||||
>
|
||||
__attribute__((always_inline)) inline void
|
||||
load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
const uint32_t k_index, const T *global_addr,
|
||||
volatile T *local_addr, const uint32_t tid_in_threadblock) {
|
||||
asm volatile("global_dmem_load_start_new_%=:" ::);
|
||||
asm volatile("load_tile_to_smem_start_%=:" ::);
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. The tensor core hardware assumes
|
||||
@@ -486,9 +490,6 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
const uint32_t local_col_smem =
|
||||
transposed_write ? local_row_gmem : local_col_gmem;
|
||||
|
||||
// FIXME: don't hardcode this here
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
const uint32_t global_row_mn_major = tile_dim_k_packed * k_index + local_row_gmem;
|
||||
const uint32_t global_col_mn_major = gmem_dim_col * mn_index + local_col_gmem;
|
||||
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
|
||||
@@ -506,6 +507,7 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
smem_dim_col * local_row_smem + local_col_smem;
|
||||
|
||||
constexpr uint32_t row_stride = threads_per_threadblock / gmem_dim_col;
|
||||
|
||||
static_assert(row_stride * 8 <= gmem_dim_row,
|
||||
"manual loop unrolling condition not met; tile row dimension "
|
||||
"is too shallow");
|
||||
@@ -598,7 +600,7 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
||||
}
|
||||
}
|
||||
|
||||
asm volatile("global_dmem_load_finish_new_%=:" ::);
|
||||
asm volatile("load_tile_to_smem_finish_new_%=:" ::);
|
||||
}
|
||||
|
||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||
@@ -677,7 +679,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
|
||||
if constexpr (write_to_mem) {
|
||||
// need to protect smem reads in the earlier step from writes in below,
|
||||
// especially when the destination smem address overlaps with the input
|
||||
// especially when the destination address overlaps with the source address
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
@@ -877,17 +879,19 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
#else
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_m, block_m, block_k, A, local_a,
|
||||
tid_in_threadblock);
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BM, BK,
|
||||
threads_per_threadblock>(
|
||||
dim_m, block_m, block_k, A, local_a, tid_in_threadblock);
|
||||
} else {
|
||||
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK>(
|
||||
load_tile_to_smem<T, MemLayout::K_major, MemLayout::MN_major, BM, BK,
|
||||
threads_per_threadblock>(
|
||||
dim_k, block_m, block_k, A, local_a, tid_in_threadblock);
|
||||
}
|
||||
|
||||
// move B
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
dim_n, block_n, block_k, B, local_b, tid_in_threadblock);
|
||||
load_tile_to_smem<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK,
|
||||
threads_per_threadblock>(dim_n, block_n, block_k, B,
|
||||
local_b, tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
Reference in New Issue
Block a user