sgemm_impl: Parameterize BM/BN/BK in single_tile
This commit is contained in:
@@ -606,6 +606,9 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
 template <typename T,
 MemLayout layout_a, // memory layout of `local_a`
 MemLayout layout_b, // memory layout of `local_b`
+uint32_t tile_dim_m,
+uint32_t tile_dim_n,
+uint32_t tile_dim_k,
 bool load_accum = false, // if true, load the accumulation registers
 // with `local_c`. used for the (C + A*B)
 // operation
@@ -621,8 +624,8 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
 // FIXME: duplicated from thread_block_gemm
 const uint32_t threads_per_warpgroup = threads_per_threadblock;
 const uint32_t warp_id_in_warpgroup = tid_in_threadblock / NUM_THREADS;
-const uint32_t warp_row = warp_id_in_warpgroup / (BN / WN);
-const uint32_t warp_col = warp_id_in_warpgroup % (BN / WN);
+const uint32_t warp_row = warp_id_in_warpgroup / (tile_dim_n / WN);
+const uint32_t warp_col = warp_id_in_warpgroup % (tile_dim_n / WN);
 const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS;
 const uint32_t warps_per_threadblock_per_core =
 NUM_WARPS / threadblocks_per_cluster;
@@ -636,8 +639,8 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
 #pragma GCC unroll
 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
 // FIXME: template parameter-ize BM
-wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
-local_c);
+wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
+tile_dim_n, local_c);
 }
 }
 }
@@ -645,7 +648,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
 #pragma GCC unroll 1
 for (int i = 0; i < BK_LOOP; i++) {
 #pragma GCC unroll 4
-for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
+for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
 #pragma GCC unroll 2
 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
 // SMEM -> RF
@@ -682,7 +685,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
 #pragma GCC unroll
 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
-wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
+wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, tile_dim_n,
 result_addr);
 }
 }
@@ -918,6 +921,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
 constexpr MemLayout layout_a =
 TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
 thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
+BM, BN, BK,
 /*load_accum=*/false,
 /*write_to_mem=*/false>(
 local_a_consume, local_b_consume,
Reference in New Issue
Block a user