sgemm_impl: Split tile offset addr gen from wmma store
& add an option to write to smem in gemm_single_tile.
This commit is contained in:
@@ -342,12 +342,11 @@ inline void initialize_C(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void write_results(const int thread_in_warp, const int warp_col,
|
__attribute__((always_inline)) inline void
|
||||||
const int warp_row, const int wn_iter,
|
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||||
const int wm_iter, const int dim_n,
|
const int wn_iter, const int wm_iter, const int dim_n,
|
||||||
float *C, const int threadblock_id_x,
|
float *write_addr) {
|
||||||
const int threadblock_id_y) {
|
asm volatile ("wmma_store_start_%=:" :: );
|
||||||
asm volatile ("write_results_start_%=:" :: );
|
|
||||||
|
|
||||||
int tid = thread_in_warp;
|
int tid = thread_in_warp;
|
||||||
|
|
||||||
@@ -359,45 +358,34 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
|||||||
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
|
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
|
||||||
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
|
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
|
||||||
|
|
||||||
float *global_offset_C =
|
|
||||||
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
|
|
||||||
|
|
||||||
// @perf: this likely causes a lot of gmem bank conflicts
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
if (wm_iter == 0) {
|
if (wm_iter == 0) {
|
||||||
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
|
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
|
||||||
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
|
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
|
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
|
||||||
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
|
||||||
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
|
||||||
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
|
||||||
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
|
||||||
// asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
|
||||||
// asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
|
||||||
// asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
|
||||||
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
|
||||||
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
|
||||||
} else {
|
} else {
|
||||||
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
|
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
|
||||||
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
|
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||||
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
|
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
|
||||||
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
|
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
|
||||||
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
|
||||||
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
|
asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
|
||||||
}
|
}
|
||||||
|
|
||||||
asm volatile ("write_results_finish_%=:" :: );
|
asm volatile ("wmma_store_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||||
@@ -648,9 +636,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
|
|||||||
|
|
||||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||||
template <typename T>
|
template <typename T,
|
||||||
|
bool write_to_smem = false // if true, write result tile to SMEM at a
|
||||||
|
// given address
|
||||||
|
>
|
||||||
__attribute__((always_inline)) inline void
|
__attribute__((always_inline)) inline void
|
||||||
thread_block_gemm_single_tile(const T *local_a, const T *local_b,
|
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
|
||||||
const uint32_t tid_in_threadblock,
|
const uint32_t tid_in_threadblock,
|
||||||
const uint32_t threads_per_threadblock) {
|
const uint32_t threads_per_threadblock) {
|
||||||
// no double-buffering
|
// no double-buffering
|
||||||
@@ -688,6 +679,17 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b,
|
|||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if constexpr (write_to_smem) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
|
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
|
||||||
|
local_c);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, bool write_to_gmem = true,
|
template <typename T, bool write_to_gmem = true,
|
||||||
@@ -914,9 +916,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
local_b_consume = local_b;
|
local_b_consume = local_b;
|
||||||
}
|
}
|
||||||
|
|
||||||
thread_block_gemm_single_tile(local_a_consume, local_b_consume,
|
thread_block_gemm_single_tile(
|
||||||
tid_in_threadblock,
|
local_a_consume, local_b_consume,
|
||||||
threads_per_threadblock);
|
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
|
||||||
|
threads_per_threadblock);
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
// Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||||
@@ -932,12 +935,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_gmem) {
|
if constexpr (write_to_gmem) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n;
|
||||||
dim_n, C, block_n, block_m);
|
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n,
|
||||||
|
global_offset_C);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user