sgemm_tcore: Increase RF data reuse for WMITER/WNITER
... by splitting vx_wmma_load to vx_wmma_load_{a,b} and pulling it
out of the innermost loop.
TODO: there's some duplicate address compute being done in the both
functions.
This commit is contained in:
@@ -142,12 +142,12 @@ inline void vx_wmma(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// `local_k` is assumed to be multiple of TCK
|
// `local_k` is assumed to be multiple of TCK
|
||||||
inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const int local_k,
|
inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
||||||
const int warp_col, const int warp_row, const int wn_iter,
|
const int warp_row, const int wm_iter, const int thread_in_warp) {
|
||||||
const int wm_iter, const int thread_in_warp) {
|
|
||||||
const int tid = thread_in_warp;
|
const int tid = thread_in_warp;
|
||||||
const int tg = tid / 4;
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
// TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b
|
||||||
int row = 0;
|
int row = 0;
|
||||||
int col = 0;
|
int col = 0;
|
||||||
map_operand(tid, row, col);
|
map_operand(tid, row, col);
|
||||||
@@ -188,6 +188,25 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i
|
|||||||
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row]));
|
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + i) * smem_A_rows) + (WM * warp_row + TCM * wm_iter) + row]));
|
||||||
// }
|
// }
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `local_k` is assumed to be multiple of TCK
|
||||||
|
inline void vx_wmma_load_b(volatile float *smem_B, const int local_k,
|
||||||
|
const int warp_col, const int wn_iter,
|
||||||
|
const int thread_in_warp) {
|
||||||
|
const int tid = thread_in_warp;
|
||||||
|
const int tg = tid / 4;
|
||||||
|
|
||||||
|
int row = 0;
|
||||||
|
int col = 0;
|
||||||
|
map_operand(tid, row, col);
|
||||||
|
|
||||||
|
constexpr int smem_A_rows = BM;
|
||||||
|
constexpr int smem_A_cols = BK;
|
||||||
|
constexpr int smem_AS_rows = BK;
|
||||||
|
constexpr int smem_AS_cols = BM;
|
||||||
|
constexpr int smem_B_rows = BK;
|
||||||
|
constexpr int smem_B_cols = BN;
|
||||||
|
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
@@ -401,9 +420,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
// FIXME: this is wrong!! need separate accumulation register for
|
// FIXME: this is wrong!! need separate accumulation register for
|
||||||
// WM/WN_ITERS
|
// WM/WN_ITERS
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
|
vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||||
|
// vx_wmma_load_b(local_b, 0, 0, 0, tid_in_warp);
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
#if TC_SINGLE_WARP
|
#if TC_SINGLE_WARP
|
||||||
if (warp_in_threadblock == 0) {
|
if (warp_in_threadblock == 0) {
|
||||||
#endif
|
#endif
|
||||||
@@ -419,8 +440,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
// asm volatile("addi a0, a0, 0");
|
// asm volatile("addi a0, a0, 0");
|
||||||
// }
|
// }
|
||||||
// SMEM -> RF
|
// SMEM -> RF
|
||||||
vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row,
|
vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||||
wn_iter, wm_iter, tid_in_warp);
|
// vx_wmma_load_a(local_a, 0, 0, 0, tid_in_warp);
|
||||||
// compute
|
// compute
|
||||||
vx_wmma(wm_iter);
|
vx_wmma(wm_iter);
|
||||||
#if TC_SINGLE_WARP
|
#if TC_SINGLE_WARP
|
||||||
|
|||||||
Reference in New Issue
Block a user