From 09b23ffe87a8398f9accba2a63e90aacc55ae4ee Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 13 May 2024 14:52:33 -0700 Subject: [PATCH] sgemm_tg: 1-octet 8-lane kernel --- tests/regression/sgemm_tcore/kernel.cpp | 78 +++++++++---------------- 1 file changed, 29 insertions(+), 49 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 8913d95a..1484d555 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -18,14 +18,16 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 16 +#define BM 8 #define BN BM #define BK 8 -#define TCM 16 -#define TCN 16 +#define TCM 8 +#define TCN 8 #define TM 1 #define TN 1 +#define NUM_LANES 8 + inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -67,6 +69,16 @@ inline constexpr void map_operand_8lanes(const int tid, int &row, int &col) { col += tg * 4; } +inline constexpr void map_operand(const int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_operand_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_operand_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + inline constexpr void map_c_32lanes(const int tid, int &row, int &col) { const int tg = tid / 4; @@ -93,6 +105,16 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c(const int tid, int &row, int &col) { + if constexpr (NUM_LANES == 32) { + map_c_32lanes(tid, row, col); + } else if constexpr (NUM_LANES == 8) { + map_c_8lanes(tid, row, col); + } else { + // FIXME: not allowed + } +} + inline void vx_wmma() { asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); } @@ -104,7 +126,7 @@ void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int row = 0; int col = 0; - map_operand_32lanes(tid, row, col); + map_operand(tid, row, col); int smem_A_rows = BM; int smem_A_cols = BK; @@ -154,8 +176,7 @@ inline void write_results(volatile float *local_warp_results, // these are [0, TCM/TCN) int local_row = 0; int local_col = 0; - - map_c_32lanes(tid, local_row, local_col); + map_c(tid, local_row, local_col); float *global_offset_C = C + (BM * threadblock_id_y) * dim_n + @@ -189,19 +210,10 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const float *B = (const float *)arg->addr_b; float *C = (float *)arg->addr_c; - // assumes NT == NW == matrix_dim const uint32_t dim_m = arg->dim_m; const uint32_t dim_n = arg->dim_n; const uint32_t dim_k = arg->dim_k; - // FIXME: Output block size is assumed to be square, i.e. BM == BN - // const uint32_t BM = threadblock_dim_y; - // const uint32_t BN = threadblock_dim_y; - // const uint32_t BK = threadblock_dim_x; - // constexpr uint32_t BM = 8; - // constexpr uint32_t BN = 8; - // constexpr uint32_t BK = 2; - const uint32_t local_a_row = tid_in_threadblock / BK; const uint32_t local_a_col = tid_in_threadblock % BK; const uint32_t local_b_row = tid_in_threadblock / BN; @@ -217,8 +229,8 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, float reg_a[TM] = { 0.0f }; float reg_b[TN] = { 0.0f }; - const uint32_t warp_in_threadblock = tid_in_threadblock / 32; - const uint32_t tid_in_warp = tid_in_threadblock % 32; + const uint32_t warp_in_threadblock = tid_in_threadblock / NUM_LANES; + const uint32_t tid_in_warp = tid_in_threadblock % NUM_LANES; const uint32_t warp_x = warp_in_threadblock % 2; const uint32_t warp_y = warp_in_threadblock / 2; @@ -272,38 +284,6 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, vx_wmma(); } -#if 0 - // Compute single tile*tile matmul -#pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < BK; local_k++) { - // First, pump data from SMEM->RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { - reg_a[res_idx_m] = - local_a[BK * (TM * local_c_row + res_idx_m) + local_k]; - } -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - reg_b[res_idx_n] = - local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - - // Next, compute multiple result elements (TM*TN) by reusing data in RF -#pragma GCC unroll TM - for (uint32_t res_idx_m = 0; res_idx_m < TM; res_idx_m++) { -#pragma GCC unroll TN - for (uint32_t res_idx_n = 0; res_idx_n < TN; res_idx_n++) { - // NOTE use of local_b_row - reg_c[TN * res_idx_m + res_idx_n] += - reg_a[res_idx_m] * reg_b[res_idx_n]; - // reg_c[TN * res_idx_m + res_idx_n] += - // local_a[BK * (TM * local_c_row + res_idx_m) + local_k] * - // local_b[BN * local_k + (TN * local_c_col + res_idx_n)]; - } - } - } -#endif - threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); }