From 7c4d850074ef5820c9bec91df5522ba01e989dbd Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 6 Jun 2024 18:38:43 -0700 Subject: [PATCH] sgemm_tcore: Experiment with high K; 48% util --- tests/regression/sgemm_tcore/kernel.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 11187644..6db7ae3d 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -363,7 +363,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const float *global_a = A + dim_k * global_a_row + (k + local_a_col); volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col; -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a) { // const uint32_t global_a_offset = @@ -392,7 +392,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BK % (row_stride_as * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as * 8) { // @perf: bank conflicts here @@ -446,7 +446,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BM % (row_stride_a * 8)) == 0, "manual loop unrolling condition not met; BM should be power-of-two"); -#pragma GCC unroll 4 +#pragma GCC unroll 1 for (uint32_t local_row_offset = 0; local_row_offset < BM; local_row_offset += row_stride_a * 8) { // const uint32_t global_a_offset = @@ -498,7 +498,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, (BK % (row_stride_b * 8)) == 0, "manual loop unrolling condition not met; BK should be power-of-two"); -#pragma GCC unroll 2 +#pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b * 8) { // const uint32_t global_b_offset = @@ -609,7 +609,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // right-shift int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < (dim_k) - BK; k += BK) { + for (uint32_t k = 0; k < (8 * dim_k) - BK; k += BK) { volatile float *local_a_produce; volatile float *local_b_produce; if constexpr (DOUBLE_BUFFER) { @@ -656,7 +656,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // right-shift int32_t k_index = 0; #pragma GCC unroll 1 - for (uint32_t k = 0; k < (dim_k); k += BK) { + for (uint32_t k = 0; k < (8 * dim_k); k += BK) { volatile float *local_a_consume; volatile float *local_b_consume; if constexpr (DOUBLE_BUFFER) { @@ -682,19 +682,19 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // vx_wmma_load #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { // perform wmma // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, // tid_in_warp); // FIXME: this is wrong!! need separate accumulation register for // WM/WN_ITERS -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 1 +#pragma GCC unroll 2 for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { // if ((threadblock_id_in_cluster % 2) == 0) { // asm volatile("addi a0, a0, 0");