sgemm_wg: Explicitly limit unroll to reduce stack spilling

This needs to be done case-by-case for different BK/TM/TN combinations and
examining the assembly.
This commit is contained in:
Hansung Kim
2024-03-29 02:48:29 -07:00
parent 537b97eb20
commit fa2b6e2ad0

View File

@@ -12,14 +12,15 @@
// but smaller case is not handled.
// * Compute:
// ( M* N) / (TM*TN) == grid size >= NC*NW*NT
// (BM*BN) / (TM*TN) == threadblock size < NT * NW * CORES_PER_CLUSTER
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
// BM <= BK*TM*TN.
#define BM 8
// BM <= BK*TM*TN
#define BM 16
#define BN BM
#define BK 2
#define TM 2
#define TN 2
#define BK 4
#define TM 4
#define TN 4
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
vx_fence();
@@ -32,7 +33,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t threadblock_dim_y,
const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y,
const uint32_t threadblock_id_in_core,
const uint32_t threadblock_id_in_cluster,
float *sharedmem_per_threadblock) {
const float *A = (const float *)arg->addr_a;
const float *B = (const float *)arg->addr_b;
@@ -75,12 +76,17 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
constexpr uint32_t stride_b = (BM * BN) / BN / (TM * TN);
for (uint32_t k = 0; k < dim_k; k += BK) {
// Data move from GMEM to SMEM
//
// Make sure global offset values for A and B are contiguous between
// neighboring threads to ensure GMEM coalescing.
for (uint32_t load_offset = 0; load_offset < BM; load_offset += stride_a) {
const uint32_t global_a_offset =
dim_k * (global_a_row + load_offset) + (k + local_a_col);
local_a[BK * (local_a_row + load_offset) + local_a_col] =
A[global_a_offset];
}
// #pragma GCC unroll 1
for (uint32_t load_offset = 0; load_offset < BK; load_offset += stride_b) {
const uint32_t global_b_offset =
dim_n * (k + local_b_row + load_offset) + global_b_col;
@@ -88,10 +94,11 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
B[global_b_offset];
}
threadblock_barrier(tid_in_threadblock, threadblock_id_in_core,
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
threadblock_dim_y);
// Compute single tile*tile matmul
#pragma GCC unroll 2
for (uint32_t local_k = 0; local_k < BK; local_k++) {
// First, pump data from SMEM->RF
#pragma GCC unroll TM
@@ -120,7 +127,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
}
}
threadblock_barrier(tid_in_threadblock, threadblock_id_in_core,
threadblock_barrier(tid_in_threadblock, threadblock_id_in_cluster,
threadblock_dim_y);
}
@@ -137,14 +144,15 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
}
}
void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// @perf: All threads are running these compute whose result is mostly same
// across the threadblock
const uint32_t threads_per_threadblock = (BM * BN) / (TM * TN);
#ifdef RADIANCE
const uint32_t threadblocks_per_core =
vx_num_threads() * vx_num_warps() / threads_per_threadblock * CORES_PER_CLUSTER;
const uint32_t threadblocks_per_core = vx_num_threads() * vx_num_warps() /
threads_per_threadblock *
CORES_PER_CLUSTER;
#else
const uint32_t threadblocks_per_core =
vx_num_threads() * vx_num_warps() / threads_per_threadblock;
@@ -152,7 +160,7 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
const uint32_t threadblock_dim_x = vx_num_threads();
const uint32_t threadblock_dim_y = vx_num_warps() / threadblocks_per_core;
const int threadblock_id = task_id / threads_per_threadblock;
const int threadblock_id_in_core = threadblock_id % threadblocks_per_core;
const int threadblock_id_in_cluster = threadblock_id % threadblocks_per_core;
const int tid_in_threadblock = task_id % threads_per_threadblock;
const uint32_t dim_m = arg->dim_m;
@@ -164,10 +172,10 @@ void kernel_body(int task_id, kernel_arg_t* __UNIFORM__ arg) {
// "static" shared memory allocation. This would determine threadblock
// occupancy of a single cluster
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_core;
(float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster;
thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x,
threadblock_dim_y, threadblock_id_x, threadblock_id_y,
threadblock_id_in_core, sharedmem_per_threadblock);
threadblock_id_in_cluster, sharedmem_per_threadblock);
}
int main() {
@@ -176,8 +184,8 @@ int main() {
#ifdef RADIANCE
vx_spawn_tasks_cluster(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#else
// NOTE: This kernel assumes contiguous thread scheduling for threadblock
// allocation, and therefore does not work with original vx_spawn_tasks
// NOTE: This kernel assumes contiguous thread scheduling for efficient shared
// memory allocation, and therefore does not work with original vx_spawn_tasks
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#endif
return 0;