sgemm_impl: Split tile offset addr gen from wmma store

Also add an option to write the result tile to SMEM in thread_block_gemm_single_tile.
This commit is contained in:
Hansung Kim
2024-08-18 16:10:29 -07:00
parent 90f6effa97
commit b978bf8757

View File

@@ -342,12 +342,11 @@ inline void initialize_C(const int dest_reg) {
} }
} }
inline void write_results(const int thread_in_warp, const int warp_col, __attribute__((always_inline)) inline void
const int warp_row, const int wn_iter, wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
const int wm_iter, const int dim_n, const int wn_iter, const int wm_iter, const int dim_n,
float *C, const int threadblock_id_x, float *write_addr) {
const int threadblock_id_y) { asm volatile ("wmma_store_start_%=:" :: );
asm volatile ("write_results_start_%=:" :: );
int tid = thread_in_warp; int tid = thread_in_warp;
@@ -359,45 +358,34 @@ inline void write_results(const int thread_in_warp, const int warp_col,
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row; int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col; int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
float *global_offset_C =
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
// @perf: this likely causes a lot of gmem bank conflicts // @perf: this likely causes a lot of gmem bank conflicts
if (wm_iter == 0) { if (wm_iter == 0) {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>( volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]); &write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f19, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f20, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile ("fsw f21, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile ("fsw f22, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f23, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
// asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
// asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
// asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
// asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
// asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
// asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
// asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
// asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
} else { } else {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>( volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]); &write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float); volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f27, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f28, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile ("fsw f29, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr)); asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp)); asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
} }
asm volatile ("write_results_finish_%=:" :: ); asm volatile ("wmma_store_finish_%=:" :: );
} }
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
@@ -648,9 +636,12 @@ inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const u
// Do a single tile*tile matrix multiplication using the matrix data stored in // Do a single tile*tile matrix multiplication using the matrix data stored in
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope. // SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
template <typename T> template <typename T,
bool write_to_smem = false // if true, write result tile to SMEM at a
// given address
>
__attribute__((always_inline)) inline void __attribute__((always_inline)) inline void
thread_block_gemm_single_tile(const T *local_a, const T *local_b, thread_block_gemm_single_tile(const T *local_a, const T *local_b, T *local_c,
const uint32_t tid_in_threadblock, const uint32_t tid_in_threadblock,
const uint32_t threads_per_threadblock) { const uint32_t threads_per_threadblock) {
// no double-buffering // no double-buffering
@@ -688,6 +679,17 @@ thread_block_gemm_single_tile(const T *local_a, const T *local_b,
gemmini_fence(); gemmini_fence();
} }
} }
if constexpr (write_to_smem) {
#pragma GCC unroll
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, BN,
local_c);
}
}
}
} }
template <typename T, bool write_to_gmem = true, template <typename T, bool write_to_gmem = true,
@@ -914,9 +916,10 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
local_b_consume = local_b; local_b_consume = local_b;
} }
thread_block_gemm_single_tile(local_a_consume, local_b_consume, thread_block_gemm_single_tile(
tid_in_threadblock, local_a_consume, local_b_consume,
threads_per_threadblock); static_cast<volatile T *>(nullptr) /*ignore*/, tid_in_threadblock,
threads_per_threadblock);
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
// Call gemmini fence at the end of the loop to overlap dma & wmma. // Call gemmini fence at the end of the loop to overlap dma & wmma.
@@ -932,12 +935,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
} }
if constexpr (write_to_gmem) { if constexpr (write_to_gmem) {
#pragma GCC unroll 2 #pragma GCC unroll
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll 2 #pragma GCC unroll
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n;
dim_n, C, block_n, block_m); wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n,
global_offset_C);
} }
} }
} }