From d8944db36950d306583cd5074ebbd34990613d74 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 4 Jun 2024 18:23:27 -0700 Subject: [PATCH] sgemm_tcore: Double-buffer over K-dimension TODO: Not completely parameterized with DOUBLE_BUFFER yet. --- tests/regression/sgemm_tcore/kernel.cpp | 241 +++++++++++++++--------- tests/regression/sgemm_tcore/main.cpp | 6 +- 2 files changed, 155 insertions(+), 92 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 69451813..3e3bed78 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -15,6 +15,8 @@ #define BK_LOOP 1 #define TRANSPOSE_AS 1 +#define DOUBLE_BUFFER 1 + // Constraints on parameters: // * Memory: // (BM + BN) * BK * sizeof(float) <= sharedmem size. @@ -29,7 +31,7 @@ // BM <= BK*TM*TN #define BM 32 #define BN 32 -#define BK 32 +#define BK 8 #define WM 16 #define WN 8 #define TCM 8 @@ -44,7 +46,12 @@ #define TM 1 #define TN 1 #endif -#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN) +#define ELEM_PER_THREAD (WMITER * WNITER * TM * TN / (DOUBLE_BUFFER ? 2 : 1)) + +// FIXME: NUM_THREADS and NUM_WARPS hardcoded +#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8)) +#error "threadblock size too big for cluster" +#endif inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -156,8 +163,6 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k, constexpr int smem_A_cols = BK; constexpr int smem_AS_rows = BK; constexpr int smem_AS_cols = BM; - constexpr int smem_B_rows = BK; - constexpr int smem_B_cols = BN; if constexpr (!TRANSPOSE_AS) { int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols; @@ -201,10 +206,6 @@ inline void vx_wmma_load_b(volatile float *smem_B, const int local_k, int col = 0; map_operand(tid, row, col); - constexpr int smem_A_rows = BM; - constexpr int smem_A_cols = BK; - constexpr int smem_AS_rows = BK; - constexpr int smem_AS_cols = BM; constexpr int smem_B_rows = BK; constexpr int smem_B_cols = BN; @@ -294,11 +295,21 @@ inline void threadblock_barrier(unsigned int tid_in_threadblock, inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, const float *A, const float *B, volatile float *local_a, - volatile float *local_b, const uint32_t threadblock_id_x, - const uint32_t threadblock_id_y, const uint32_t local_a_row, - const uint32_t local_a_col, const uint32_t local_as_row, - const uint32_t local_as_col, const uint32_t local_b_row, - const uint32_t local_b_col) { + volatile float *local_b, const uint32_t tid_in_threadblock, + const uint32_t threadblock_id_x, + const uint32_t threadblock_id_y) { + constexpr uint32_t BM_d = BM; + constexpr uint32_t BN_d = BN; + + const uint32_t local_a_row = tid_in_threadblock / BK; + const uint32_t local_a_col = tid_in_threadblock % BK; + const uint32_t local_as_row = tid_in_threadblock / BM; + const uint32_t local_as_col = tid_in_threadblock % BM; + const uint32_t local_b_row = tid_in_threadblock / BN; + const uint32_t local_b_col = tid_in_threadblock % BN; + + constexpr uint32_t threads_in_warpgroup = + (BM * BN) / ELEM_PER_THREAD / (DOUBLE_BUFFER ? 2 : 1); // FIXME // Data move from GMEM to SMEM // @@ -307,24 +318,24 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // // TODO: Sharedmem swizzling is important here if constexpr (!TRANSPOSE_AS) { - const uint32_t global_a_row = BM * threadblock_id_y + local_a_row; + const uint32_t global_a_row = BM_d * threadblock_id_y + local_a_row; // number of rows a full TB can read at a time - constexpr uint32_t row_stride_a = (BM * BN) / ELEM_PER_THREAD / BK; + constexpr uint32_t row_stride_a = threads_in_warpgroup / BK; #pragma GCC unroll 1 - for (uint32_t local_row_offset = 0; local_row_offset < BM; + for (uint32_t local_row_offset = 0; local_row_offset < BM_d; local_row_offset += row_stride_a) { const uint32_t global_a_offset = dim_k * (global_a_row + local_row_offset) + (k + local_a_col); // NOTE: all threads in TB will do this load; make sure this is not - // out-of-bounds of BM*BK + // out-of-bounds of BM_d*BK local_a[BK * (local_a_row + local_row_offset) + local_a_col] = A[global_a_offset]; } } else { - const uint32_t global_a_row = BM * threadblock_id_y + local_as_col; - // const uint32_t global_a_row = BM * threadblock_id_y + local_as_row; - constexpr uint32_t row_stride_as = (BM * BN) / ELEM_PER_THREAD / BM; -#pragma GCC unroll 1 + const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_col; + // const uint32_t global_a_row = BM_d * threadblock_id_y + local_as_row; + constexpr uint32_t row_stride_as = threads_in_warpgroup / BM_d; +#pragma GCC unroll 4 for (uint32_t local_row_offset = 0; local_row_offset < BK; local_row_offset += row_stride_as) { // @perf: bank conflicts here @@ -333,25 +344,26 @@ global_dmem_load(const uint32_t dim_n, const uint32_t dim_k, const uint32_t k, // FIXME experimenting with global coalescing // const uint32_t global_a_offset = // dim_k * (global_a_row + local_row_offset) + (k + local_as_col); - local_a[BM * (local_as_row + local_row_offset) + local_as_col] = + local_a[BM_d * (local_as_row + local_row_offset) + local_as_col] = A[global_a_offset]; } } - constexpr uint32_t row_stride_b = (BM * BN) / ELEM_PER_THREAD / BN; - const uint32_t global_b_col = BN * threadblock_id_x + local_b_col; -#pragma GCC unroll 1 + constexpr uint32_t row_stride_b = threads_in_warpgroup / BN_d; + const uint32_t global_b_col = BN_d * threadblock_id_x + local_b_col; +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; load_offset += row_stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; - local_b[BN * (local_b_row + load_offset) + local_b_col] = + local_b[BN_d * (local_b_row + load_offset) + local_b_col] = B[global_b_offset]; } } void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t tid_in_threadblock, + const uint32_t threads_per_threadblock, const uint32_t threadblock_dim_x, const uint32_t threadblock_dim_y, const uint32_t threadblock_id_x, @@ -376,14 +388,20 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t local_c_row = tid_in_threadblock / (BN / TN); const uint32_t local_c_col = tid_in_threadblock % (BN / TN); +#if !USE_TENSOR_CORE // each thread generates TM output element float reg_c[TM * TN] = { 0.0f }; float reg_a[TM] = { 0.0f }; float reg_b[TN] = { 0.0f }; +#endif - const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; - const uint32_t warp_row = warp_in_threadblock / (BN / WN); - const uint32_t warp_col = warp_in_threadblock % (BN / WN); + const uint32_t threads_per_warpgroup = threads_per_threadblock / (DOUBLE_BUFFER ? 2 : 1); + const uint32_t warpgroup_id = tid_in_threadblock / threads_per_warpgroup; + const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME + const uint32_t warp_in_warpgroup = tid_in_warpgroup / NUM_LANES; + // FIXME: warp_row / BN should be warp-specialized? + const uint32_t warp_row = warp_in_warpgroup / (BN / WN); + const uint32_t warp_col = warp_in_warpgroup % (BN / WN); const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; volatile float *local_a = sharedmem_per_threadblock; @@ -391,69 +409,109 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const size_t local_a_elems = (BM * BK); volatile float *local_b = sharedmem_per_threadblock + local_a_elems; const size_t local_b_elems = (BK * BN); + + volatile float *local_a_buf = local_b + local_b_elems; + volatile float *local_b_buf = local_a_buf + local_a_elems; + volatile float *local_warp_results = - local_b + local_b_elems + (warp_in_threadblock * TCM * TCN); + local_b_buf + local_b_elems + (warp_in_warpgroup * TCM * TCN); // clear out C initialize_C(0); initialize_C(1); -#pragma GCC unroll 1 - for (uint32_t k = 0; k < dim_k; k += BK) { - global_dmem_load(dim_n, dim_k, k, A, B, local_a, local_b, - threadblock_id_x, threadblock_id_y, local_a_row, - local_a_col, local_as_row, local_as_col, local_b_row, - local_b_col); - - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, - threadblock_dim_y); - -#if USE_TENSOR_CORE - // @perf: this loop spills to stack a lot because of all the flws in - // vx_wmma_load -#pragma GCC unroll 1 - for (int i = 0; i < BK_LOOP; i++) { -#pragma GCC unroll 1 - for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { - // perform wmma - // vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp); - // FIXME: this is wrong!! need separate accumulation register for - // WM/WN_ITERS -#pragma GCC unroll 2 - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - vx_wmma_load_b(local_b, local_k, warp_col, wn_iter, tid_in_warp); - // vx_wmma_load_b(local_b, 0, 0, 0, tid_in_warp); -#pragma GCC unroll 2 - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { -#if TC_SINGLE_WARP - if (warp_in_threadblock == 0) { -#endif - // if ((threadblock_id_in_cluster % 2) == 0) { - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // asm volatile("addi a0, a0, 0"); - // } - // SMEM -> RF - vx_wmma_load_a(local_a, local_k, warp_row, wm_iter, tid_in_warp); - // vx_wmma_load_a(local_a, 0, 0, 0, tid_in_warp); - // compute - vx_wmma(wm_iter); -#if TC_SINGLE_WARP - } -#endif - } - } - } + if constexpr (DOUBLE_BUFFER) { + // initiate software pipeline + if (warpgroup_id == 0) { + global_dmem_load(dim_n, dim_k, 0 /*k*/, A, B, local_a, local_b, + tid_in_warpgroup, threadblock_id_x, threadblock_id_y); } threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); + } + + uint32_t k_index = 0; + +#pragma GCC unroll 1 + for (uint32_t k = 0; k < dim_k; k += BK) { + volatile float *local_a_produce; + volatile float *local_b_produce; + volatile float *local_a_consume; + volatile float *local_b_consume; + if constexpr (DOUBLE_BUFFER) { + local_a_produce = (k_index % 2) ? local_a : local_a_buf; + local_b_produce = (k_index % 2) ? local_b : local_b_buf; + local_a_consume = (k_index % 2) ? local_a_buf : local_a; + local_b_consume = (k_index % 2) ? local_b_buf : local_b; + } else { + local_a_produce = local_a; + local_b_produce = local_b; + local_a_consume = local_a; + local_b_consume = local_b; + } + k_index++; + + if (warpgroup_id == 0) { + if (k != (dim_k - BK)) { + global_dmem_load(dim_n, dim_k, k + BK /*runahead*/, A, B, + local_a_produce, local_b_produce, tid_in_warpgroup, + threadblock_id_x, threadblock_id_y); + } + + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); + } + + else { +#if USE_TENSOR_CORE + // @perf: this loop spills to stack a lot because of all the flws in + // vx_wmma_load +#pragma GCC unroll 1 + for (int i = 0; i < BK_LOOP; i++) { +#pragma GCC unroll 1 + for (uint32_t local_k = 0; local_k < BK; local_k += TCK) { + // perform wmma + // vx_wmma_load(local_a_consume, local_b_consume, warp_x, warp_y, tid_in_warp); + // FIXME: this is wrong!! need separate accumulation register for + // WM/WN_ITERS +#pragma GCC unroll 2 + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp); + // vx_wmma_load_b(local_b_consume, 0, 0, 0, tid_in_warp); +#pragma GCC unroll 1 + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { +#if TC_SINGLE_WARP + if (warp_in_warpgroup == 0) { +#endif + // if ((threadblock_id_in_cluster % 2) == 0) { + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // asm volatile("addi a0, a0, 0"); + // } + // SMEM -> RF + vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter, + tid_in_warp); + // vx_wmma_load_a(local_a_consume, 0, 0, 0, tid_in_warp); + // compute + vx_wmma(wm_iter); +#if TC_SINGLE_WARP + } +#endif + } + } + } + } + + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, + threadblock_dim_y); + } #else @@ -498,11 +556,13 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, #pragma GCC unroll 1 for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { #if TC_SINGLE_WARP - if (warp_in_threadblock == 0) { + if (warp_in_warpgroup == 0) { #endif - write_results(local_warp_results, tid_in_warp, warp_col, warp_row, - wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, - threadblock_id_y); + if (warpgroup_id == 1) { + write_results(local_warp_results, tid_in_warp, warp_col, warp_row, + wn_iter, wm_iter, dim_m, dim_n, C, threadblock_id_x, + threadblock_id_y); + } #if TC_SINGLE_WARP } #endif @@ -554,9 +614,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // occupancy of a single cluster float *sharedmem_per_threadblock = (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; - thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, - threadblock_dim_y, threadblock_id_x, threadblock_id_y, - threadblock_id_in_cluster, sharedmem_per_threadblock); + + const int warp_id = vx_warp_id(); + thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock, + threadblock_dim_x, threadblock_dim_y, threadblock_id_x, + threadblock_id_y, threadblock_id_in_cluster, + sharedmem_per_threadblock); } int main() { diff --git a/tests/regression/sgemm_tcore/main.cpp b/tests/regression/sgemm_tcore/main.cpp index 0fbd838b..e6f18317 100644 --- a/tests/regression/sgemm_tcore/main.cpp +++ b/tests/regression/sgemm_tcore/main.cpp @@ -155,9 +155,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 16; - uint32_t dim_n = 16; - uint32_t dim_k = 16; + uint32_t dim_m = 32; + uint32_t dim_n = 32; + uint32_t dim_k = 32; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k);