From a06b2dd20ea702f4e3824cb519a0081d46312cfc Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 28 Feb 2024 21:17:42 -0800 Subject: [PATCH] sgemm_wg: Cleanup & proper unroll --- tests/regression/sgemm_wg/kernel.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_wg/kernel.cpp b/tests/regression/sgemm_wg/kernel.cpp index 69ef9f14..9b767d35 100644 --- a/tests/regression/sgemm_wg/kernel.cpp +++ b/tests/regression/sgemm_wg/kernel.cpp @@ -40,30 +40,30 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; - // each thread generates one output element + // each thread generates TM output elements float reg_c[TM] = { 0.0f }; - for (uint32_t k = 0; k < dim_k; k += BK) { - float *local_a = sharedmem_per_threadblock; - size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; - float *local_b = sharedmem_per_threadblock + local_a_elems; + volatile float *local_a = sharedmem_per_threadblock; + const size_t local_a_elems = threadblock_dim_x * threadblock_dim_y; + volatile float *local_b = sharedmem_per_threadblock + local_a_elems; + for (uint32_t k = 0; k < dim_k; k += BK) { uint32_t global_a_offset = dim_k * global_a_row + (k + local_a_col); uint32_t global_b_offset = dim_n * (k + local_b_row) + global_b_col; - // NOTE: local_b is transposed to column-major to facilitate better memory - // access. 
local_a[BK * local_a_row + local_a_col] = A[global_a_offset]; local_b[BN * local_b_row + local_b_col] = B[global_b_offset]; vx_barrier(threadblock_id_in_core, threadblock_dim_y); vx_fence(); +#pragma GCC unroll TM for (uint32_t local_k = 0; local_k < BK; local_k++) { // Compute multiple result elements (TM) per thread const float local_b_tmp = local_b[BN * local_k + local_b_col]; -#pragma GCC unroll 4 +#pragma GCC unroll TM for (uint32_t result_idx = 0; result_idx < TM; result_idx++) { + // NOTE use of local_b_row reg_c[result_idx] += local_a[BK * (TM * local_b_row + result_idx) + local_k] * local_b_tmp; @@ -74,8 +74,9 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_fence(); } -#pragma GCC unroll 4 +#pragma GCC unroll TM for (uint32_t result_idx = 0; result_idx < TM; result_idx++) { + // NOTE use of local_b_row and global_b_col here C[dim_n * (BM * threadblock_id_y + TM * local_b_row + result_idx) + global_b_col] = reg_c[result_idx]; }