From fa2b6e2ad0d27da4dfae778626a714281c2ad505 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 29 Mar 2024 02:48:29 -0700 Subject: [PATCH] sgemm_wg: Explicitly limit unroll to reduce stack spilling This needs to be done case-by-case for different BK/TM/TN combinations and examining the assembly. --- tests/regression/sgemm_wg/kernel.cpp | 40 +++++++++++++++++----------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/regression/sgemm_wg/kernel.cpp b/tests/regression/sgemm_wg/kernel.cpp index 11612db1..4833154c 100644 --- a/tests/regression/sgemm_wg/kernel.cpp +++ b/tests/regression/sgemm_wg/kernel.cpp @@ -12,14 +12,15 @@ // but smaller case is not handled. // * Compute: // ( M* N) / (TM*TN) == grid size >= NC*NW*NT +// (BM*BN) / (TM*TN) == threadblock size < NT * NW * CORES_PER_CLUSTER // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields -// BM <= BK*TM*TN. -#define BM 8 +// BM <= BK*TM*TN +#define BM 16 #define BN BM -#define BK 2 -#define TM 2 -#define TN 2 +#define BK 4 +#define TM 4 +#define TN 4 void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { vx_fence(); @@ -32,7 +33,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_dim_y, const uint32_t threadblock_id_x, const uint32_t threadblock_id_y, - const uint32_t threadblock_id_in_core, + const uint32_t threadblock_id_in_cluster, float *sharedmem_per_threadblock) { const float *A = (const float *)arg->addr_a; const float *B = (const float *)arg->addr_b; @@ -75,12 +76,17 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, constexpr uint32_t stride_b = (BM * BN) / BN / (TM * TN); for (uint32_t k = 0; k < dim_k; k += BK) { + // Data move from GMEM to SMEM + // + // Make sure global offset values for A and B are contiguous between + // neighboring threads to ensure GMEM coalescing. for (uint32_t load_offset = 0; load_offset < BM; load_offset += stride_a) { const uint32_t global_a_offset = dim_k * (global_a_row + load_offset) + (k + local_a_col); local_a[BK * (local_a_row + load_offset) + local_a_col] = A[global_a_offset]; } +// #pragma GCC unroll 1 for (uint32_t load_offset = 0; load_offset < BK; load_offset += stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; @@ -88,10 +94,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, B[global_b_offset]; } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_core, + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); // Compute single tile*tile matmul +#pragma GCC unroll 2 for (uint32_t local_k = 0; local_k < BK; local_k++) { // First, pump data from SMEM->RF #pragma GCC unroll TM @@ -120,7 +127,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } - threadblock_barrier(tid_in_threadblock, threadblock_id_in_core, + threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster, threadblock_dim_y); } @@ -137,14 +144,15 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, } } -void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) { +void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN); #ifdef RADIANCE - const uint32_t threadblocks_per_core = - vx_num_threads() * vx_num_warps() / threads_per_threadblock * CORES_PER_CLUSTER; + const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() / + threads_per_threadblock * + CORES_PER_CLUSTER; #else const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() / threads_per_threadblock; @@ -152,7 +160,7 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) { const uint32_t threadblock_dim_x = vx_num_threads(); const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core; const int threadblock_id = task_id / threads_per_threadblock; - const int threadblock_id_in_core = threadblock_id % threadblocks_per_core; + const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_core; const int tid_in_threadblock = task_id % threads_per_threadblock; const uint32_t dim_m = arg->dim_m; @@ -164,10 +172,10 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) { // "static" shared memory allocation. This would determine threadblock // occupancy of a single cluster float *sharedmem_per_threadblock = - (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_core; + (float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster; thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x, threadblock_dim_y, threadblock_id_x, threadblock_id_y, - threadblock_id_in_core, sharedmem_per_threadblock); + threadblock_id_in_cluster, sharedmem_per_threadblock); } int main() { @@ -176,8 +184,8 @@ int main() { #ifdef RADIANCE vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #else - // NOTE: This kernel assumes contiguous thread scheduling for threadblock - // allocation, and therefore does not work with original vx_spawn_tasks + // NOTE: This kernel assumes contiguous thread scheduling for efficient shared + // memory allocation, and therefore does not work with original vx_spawn_tasks vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg); #endif return 0;