sgemm_impl: Split tile offset addr gen from wmma store

& add an option to write to smem in gemm_single_tile.
This commit is contained in:
Hansung Kim
2024-08-18 16:10:29 -07:00
parent 90f6effa97
commit b978bf8757

View File

@@ -342,12 +342,11 @@ inline void initialize_C(const int dest_reg) {
}
}
inline void write_results(const int thread_in_warp, const int warp_col,
const int warp_row, const int wn_iter,
const int wm_iter, const int dim_n,
float *C, const int threadblock_id_x,
const int threadblock_id_y) {
asm volatile ("write_results_start_%=:" :: );
__attribute__((always_inline)) inline void
wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
const int wn_iter, const int wm_iter, const int dim_n,
float *write_addr) {
asm volatile ("wmma_store_start_%=:" :: );
int tid = thread_in_warp;
@@ -359,45 +358,34 @@ inline void write_results(const int thread_in_warp, const int warp_col,
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
float *global_offset_C =
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
// @perf: this likely causes a lot of gmem bank conflicts
if (wm_iter == 0) {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
// asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
// asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
// asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
} else {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
}
asm volatile ("write_results_finish_%=:" :: );
asm volatile ("wmma_store_finish_%=:" :: );
}
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
@@ -648,9 +636,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
// Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T>
template <typename T,
bool write_to_smem = false // if true, write result tile to SMEM at a
// given address
>
__attribute__((always_inline)) inline void
thread_block_gemm_single_tile(const T *local_a, const T *local_b,
thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock) {
// no double-buffering
@@ -688,6 +679,17 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b,
gemmini_fence();
}
}
if constexpr (write_to_smem) {
#pragma GCC unroll
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
local_c);
}
}
}
}
template <typename T, bool write_to_gmem = true,
@@ -914,9 +916,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
local_b_consume = local_b;
}
thread_block_gemm_single_tile(local_a_consume, local_b_consume,
tid_in_threadblock,
threads_per_threadblock);
thread_block_gemm_single_tile(
local_a_consume, local_b_consume,
static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
threads_per_threadblock);
if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma.
@@ -932,12 +935,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
if constexpr (write_to_gmem) {
#pragma GCC unroll 2
#pragma GCC unroll
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll 2
#pragma GCC unroll
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
dim_n, C, block_n, block_m);
float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n;
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n,
global_offset_C);
}
}
}