sgemm_tcore: Experiment with high K; 48% util

This commit is contained in:
Hansung Kim
2024-06-06 18:38:43 -07:00
parent 062403066e
commit 7c4d850074

View File

@@ -363,7 +363,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
-#pragma GCC unroll 2
+#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a) {
// const uint32_t global_a_offset =
@@ -392,7 +392,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
(BK % (row_stride_as * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
-#pragma GCC unroll 2
+#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BK;
local_row_offset += row_stride_as * 8) {
// @perf: bank conflicts here
@@ -446,7 +446,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
(BM % (row_stride_a * 8)) == 0,
"manual loop unrolling condition not met; BM should be power-of-two");
-#pragma GCC unroll 4
+#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a * 8) {
// const uint32_t global_a_offset =
@@ -498,7 +498,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
(BK % (row_stride_b * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
-#pragma GCC unroll 2
+#pragma GCC unroll 1
for (uint32_t load_offset = 0; load_offset < BK;
load_offset += row_stride_b * 8) {
// const uint32_t global_b_offset =
@@ -609,7 +609,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// right-shift
int32_t k_index = 0;
#pragma GCC unroll 1
-for (uint32_t k = 0; k < (dim_k) - BK; k += BK) {
+for (uint32_t k = 0; k < (8 * dim_k) - BK; k += BK) {
volatile float *local_a_produce;
volatile float *local_b_produce;
if constexpr (DOUBLE_BUFFER) {
@@ -656,7 +656,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// right-shift
int32_t k_index = 0;
#pragma GCC unroll 1
-for (uint32_t k = 0; k < (dim_k); k += BK) {
+for (uint32_t k = 0; k < (8 * dim_k); k += BK) {
volatile float *local_a_consume;
volatile float *local_b_consume;
if constexpr (DOUBLE_BUFFER) {
@@ -682,19 +682,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
// vx_wmma_load
#pragma GCC unroll 1
for (int i = 0; i < BK_LOOP; i++) {
-#pragma GCC unroll 1
+#pragma GCC unroll 2
for (uint32_t local_k = 0; local_k < BK; local_k += TCK) {
// perform wmma
// vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y,
// tid_in_warp);
// FIXME: this is wrong!! need separate accumulation register for
// WM/WN_ITERS
-#pragma GCC unroll 1
+#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter,
tid_in_warp);
// vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp);
-#pragma GCC unroll 1
+#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// if ((threadblock_id_in_cluster % 2) == 0) {
// asm volatile("addi a0, a0, 0");