sgemm_tcore: Support two accumulation reg tiles
This commit is contained in:
@@ -27,10 +27,10 @@
|
|||||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||||
// BM <= BK*TM*TN
|
// BM <= BK*TM*TN
|
||||||
#define BM 8
|
#define BM 32
|
||||||
#define BN 8
|
#define BN 32
|
||||||
#define BK 8
|
#define BK 32
|
||||||
#define WM 8
|
#define WM 16
|
||||||
#define WN 8
|
#define WN 8
|
||||||
#define TCM 8
|
#define TCM 8
|
||||||
#define TCN 8
|
#define TCN 8
|
||||||
@@ -133,8 +133,12 @@ inline constexpr void map_c(const int tid, int &row, int &col) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void vx_wmma() {
|
inline void vx_wmma(const int dest_reg) {
|
||||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
if (dest_reg == 0) {
|
||||||
|
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
} else {
|
||||||
|
asm volatile (".insn r %0, 0, 0, x1, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// `local_k` is assumed to be a multiple of TCK
|
// `local_k` is assumed to be a multiple of TCK
|
||||||
@@ -196,23 +200,35 @@ inline void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, const i
|
|||||||
asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void initialize_C() {
|
inline void initialize_C(const int dest_reg) {
|
||||||
// initialize C to zeros
|
// initialize C to zeros
|
||||||
asm volatile("fmv.w.x f16, x0");
|
if (dest_reg == 0) {
|
||||||
asm volatile("fmv.w.x f17, x0");
|
asm volatile("fmv.w.x f16, x0");
|
||||||
asm volatile("fmv.w.x f18, x0");
|
asm volatile("fmv.w.x f17, x0");
|
||||||
asm volatile("fmv.w.x f19, x0");
|
asm volatile("fmv.w.x f18, x0");
|
||||||
asm volatile("fmv.w.x f20, x0");
|
asm volatile("fmv.w.x f19, x0");
|
||||||
asm volatile("fmv.w.x f21, x0");
|
asm volatile("fmv.w.x f20, x0");
|
||||||
asm volatile("fmv.w.x f22, x0");
|
asm volatile("fmv.w.x f21, x0");
|
||||||
asm volatile("fmv.w.x f23, x0");
|
asm volatile("fmv.w.x f22, x0");
|
||||||
|
asm volatile("fmv.w.x f23, x0");
|
||||||
|
} else {
|
||||||
|
asm volatile("fmv.w.x f24, x0");
|
||||||
|
asm volatile("fmv.w.x f25, x0");
|
||||||
|
asm volatile("fmv.w.x f26, x0");
|
||||||
|
asm volatile("fmv.w.x f27, x0");
|
||||||
|
asm volatile("fmv.w.x f28, x0");
|
||||||
|
asm volatile("fmv.w.x f29, x0");
|
||||||
|
asm volatile("fmv.w.x f30, x0");
|
||||||
|
asm volatile("fmv.w.x f31, x0");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void write_results(volatile float *local_warp_results,
|
inline void write_results(volatile float *local_warp_results,
|
||||||
int thread_in_warp, int warp_col, int warp_row,
|
const int thread_in_warp, const int warp_col,
|
||||||
int wn_iter, int wm_iter, int dim_m, int dim_n,
|
const int warp_row, const int wn_iter,
|
||||||
float *C, int threadblock_id_x,
|
const int wm_iter, const int dim_m, const int dim_n,
|
||||||
int threadblock_id_y) {
|
float *C, const int threadblock_id_x,
|
||||||
|
const int threadblock_id_y) {
|
||||||
int tid = thread_in_warp;
|
int tid = thread_in_warp;
|
||||||
int tg = tid / 4;
|
int tg = tid / 4;
|
||||||
|
|
||||||
@@ -229,14 +245,25 @@ inline void write_results(volatile float *local_warp_results,
|
|||||||
BN * threadblock_id_x;
|
BN * threadblock_id_x;
|
||||||
|
|
||||||
// @perf: this likely causes a lot of gmem bank conflicts
|
// @perf: this likely causes a lot of gmem bank conflicts
|
||||||
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
if (wm_iter == 0) {
|
||||||
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
asm volatile ("fsw f16, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||||
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
asm volatile ("fsw f17, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||||
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
asm volatile ("fsw f18, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||||
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
asm volatile ("fsw f19, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
||||||
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
asm volatile ("fsw f20, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
||||||
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
asm volatile ("fsw f21, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
||||||
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
asm volatile ("fsw f22, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||||
|
asm volatile ("fsw f23, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||||
|
} else {
|
||||||
|
asm volatile ("fsw f24, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]));
|
||||||
|
asm volatile ("fsw f25, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 1)]));
|
||||||
|
asm volatile ("fsw f26, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 0)]));
|
||||||
|
asm volatile ("fsw f27, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 1)]));
|
||||||
|
asm volatile ("fsw f28, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 4)]));
|
||||||
|
asm volatile ("fsw f29, %0" :: "m"(global_offset_C[dim_n * (local_row + 0) + (local_col + 5)]));
|
||||||
|
asm volatile ("fsw f30, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 4)]));
|
||||||
|
asm volatile ("fsw f31, %0" :: "m"(global_offset_C[dim_n * (local_row + 2) + (local_col + 5)]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void threadblock_barrier(unsigned int tid_in_threadblock,
|
inline void threadblock_barrier(unsigned int tid_in_threadblock,
|
||||||
@@ -349,7 +376,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
local_b + local_b_elems + (warp_in_threadblock * TCM * TCN);
|
local_b + local_b_elems + (warp_in_threadblock * TCM * TCN);
|
||||||
|
|
||||||
// clear out C
|
// clear out C
|
||||||
initialize_C();
|
initialize_C(0);
|
||||||
|
initialize_C(1);
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||||
@@ -394,7 +422,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row,
|
vx_wmma_load(local_a, local_b, local_k, warp_col, warp_row,
|
||||||
wn_iter, wm_iter, tid_in_warp);
|
wn_iter, wm_iter, tid_in_warp);
|
||||||
// compute
|
// compute
|
||||||
vx_wmma();
|
vx_wmma(wm_iter);
|
||||||
#if TC_SINGLE_WARP
|
#if TC_SINGLE_WARP
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|||||||
Reference in New Issue
Block a user