sgemm_impl: Add param to load accumulation tile in single_tile
This commit is contained in:
@@ -343,13 +343,61 @@ template <int accum_reg_set> inline void initialize_accum_regs() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// `C` is expected to be in N-major layout.
|
||||||
|
__attribute__((always_inline)) inline void
|
||||||
|
wmma_load_accum(const int thread_in_warp, const int warp_col,
|
||||||
|
const int warp_row, const int wn_iter, const int wm_iter,
|
||||||
|
const int dim_n, const float *C) {
|
||||||
|
asm volatile("wmma_load_accum_start_%=:" ::);
|
||||||
|
|
||||||
|
const int tid = thread_in_warp;
|
||||||
|
|
||||||
|
// these are [0, TCM/TCN)
|
||||||
|
int tid_row = 0;
|
||||||
|
int tid_col = 0;
|
||||||
|
map_c(tid, tid_row, tid_col);
|
||||||
|
|
||||||
|
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
|
||||||
|
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
|
||||||
|
|
||||||
|
// @copypaste from wmma_store
|
||||||
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
|
if (wm_iter == 0) {
|
||||||
|
const uint8_t *addr = reinterpret_cast<const uint8_t *>(
|
||||||
|
&C[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
|
const uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
|
||||||
|
asm volatile("flw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
} else {
|
||||||
|
const uint8_t *addr = reinterpret_cast<const uint8_t *>(
|
||||||
|
&C[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
|
const uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
|
||||||
|
asm volatile("flw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
|
||||||
|
asm volatile("flw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
asm volatile("flw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
|
||||||
|
}
|
||||||
|
|
||||||
|
asm volatile("wmma_load_accum_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||||
const int wn_iter, const int wm_iter, const int dim_n,
|
const int wn_iter, const int wm_iter, const int dim_n,
|
||||||
float *write_addr) {
|
float *write_addr) {
|
||||||
asm volatile ("wmma_store_start_%=:" :: );
|
asm volatile ("wmma_store_start_%=:" :: );
|
||||||
|
|
||||||
int tid = thread_in_warp;
|
const int tid = thread_in_warp;
|
||||||
|
|
||||||
// these are [0, TCM/TCN)
|
// these are [0, TCM/TCN)
|
||||||
int tid_row = 0;
|
int tid_row = 0;
|
||||||
@@ -560,17 +608,19 @@ load_tile_to_smem(const uint32_t dim_major, const uint32_t mn_index,
|
|||||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||||
template <typename T,
|
template <typename T,
|
||||||
MemLayout layout_a, // memory layout of `local_a`
|
MemLayout layout_a, // memory layout of `local_a`
|
||||||
MemLayout layout_b, // memory layout of `local_b`
|
MemLayout layout_b, // memory layout of `local_b`
|
||||||
bool write_to_smem = false // if true, write result tile to SMEM at a
|
bool load_accum = false, // if true, load the accumulation registers
|
||||||
// given address
|
// with `local_c`. used for the (C + A*B)
|
||||||
|
// operation
|
||||||
|
bool write_to_mem = false // if true, write the single result tile to
|
||||||
|
// the memory at a given address
|
||||||
>
|
>
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||||
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
const T *local_a, const T *local_b, const T *local_c, T *result_addr,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock,
|
||||||
const uint32_t threads_per_threadblock,
|
const uint32_t threadblocks_per_cluster,
|
||||||
const uint32_t threadblocks_per_cluster,
|
const uint32_t threadblock_id_in_cluster) {
|
||||||
const uint32_t threadblock_id_in_cluster) {
|
|
||||||
// no double-buffering
|
// no double-buffering
|
||||||
// FIXME: duplicated from thread_block_gemm
|
// FIXME: duplicated from thread_block_gemm
|
||||||
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
const uint32_t threads_per_warpgroup = threads_per_threadblock;
|
||||||
@@ -581,6 +631,21 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
|||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threadblocks_per_cluster;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
|
// TODO: it would be useful if this bit is split out into a function, so that
|
||||||
|
// preloading accumulation tile can be used for full GEMMs at the start of
|
||||||
|
// the K-loop.
|
||||||
|
if constexpr (load_accum) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
|
// FIXME: template parameter-ize BM
|
||||||
|
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
|
||||||
|
local_c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (int i = 0; i < BK_LOOP; i++) {
|
for (int i = 0; i < BK_LOOP; i++) {
|
||||||
#pragma GCC unroll 4
|
#pragma GCC unroll 4
|
||||||
@@ -611,7 +676,7 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_smem) {
|
if constexpr (write_to_mem) {
|
||||||
// need to protect smem reads in the earlier step from writes in below,
|
// need to protect smem reads in the earlier step from writes in below,
|
||||||
// especially when the destination smem address overlaps with the input
|
// especially when the destination smem address overlaps with the input
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
@@ -622,7 +687,7 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
|||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
|
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
|
||||||
local_c);
|
result_addr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -857,9 +922,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major;
|
||||||
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
thread_block_gemm_single_tile<T, layout_a, MemLayout::MN_major,
|
||||||
/*write_to_smem=*/false>(
|
/*load_accum=*/false,
|
||||||
|
/*write_to_mem=*/false>(
|
||||||
local_a_consume, local_b_consume,
|
local_a_consume, local_b_consume,
|
||||||
static_cast<T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
static_cast<T *>(nullptr) /*ignore accum*/,
|
||||||
|
static_cast<T *>(nullptr) /*ignore result*/, tid_in_threadblock,
|
||||||
threads_per_threadblock, threadblocks_per_cluster,
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user