From 793779aa6cd0fe0e316ff455e3a5dbee4635a7a7 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 24 Apr 2024 21:08:31 -0700 Subject: [PATCH] sgemm_wg: 128x128 config --- tests/regression/sgemm_wg/kernel.cpp | 14 +++++++------- tests/regression/sgemm_wg/main.cpp | 6 +++--- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/regression/sgemm_wg/kernel.cpp b/tests/regression/sgemm_wg/kernel.cpp index e9f898a0..86b7309d 100644 --- a/tests/regression/sgemm_wg/kernel.cpp +++ b/tests/regression/sgemm_wg/kernel.cpp @@ -16,11 +16,11 @@ // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 8 +#define BM 32 #define BN BM -#define BK 2 -#define TM 2 -#define TN 2 +#define BK 8 +#define TM 4 +#define TN 4 void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) { vx_fence(); @@ -80,14 +80,14 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, // // Make sure global offset values for A and B are contiguous between // neighboring threads to ensure GMEM coalescing. -// #pragma GCC unroll 1 +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BM; load_offset += stride_a) { const uint32_t global_a_offset = dim_k * (global_a_row + load_offset) + (k + local_a_col); local_a[BK * (local_a_row + load_offset) + local_a_col] = A[global_a_offset]; } -// #pragma GCC unroll 1 +#pragma GCC unroll 2 for (uint32_t load_offset = 0; load_offset < BK; load_offset += stride_b) { const uint32_t global_b_offset = dim_n * (k + local_b_row + load_offset) + global_b_col; @@ -99,7 +99,7 @@ void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg, threadblock_dim_y); // Compute single tile*tile matmul -// #pragma GCC unroll 2 +#pragma GCC unroll 4 for (uint32_t local_k = 0; local_k < BK; local_k++) { // First, pump data from SMEM->RF #pragma GCC unroll TM diff --git a/tests/regression/sgemm_wg/main.cpp b/tests/regression/sgemm_wg/main.cpp index 709d804c..62625c44 100644 --- a/tests/regression/sgemm_wg/main.cpp +++ b/tests/regression/sgemm_wg/main.cpp @@ -166,9 +166,9 @@ int main(int argc, char *argv[]) { RT_CHECK(vx_dev_open(&device)); // FIXME: hardcoded - uint32_t dim_m = 32; - uint32_t dim_n = 32; - uint32_t dim_k = 32; + uint32_t dim_m = 128; + uint32_t dim_n = 128; + uint32_t dim_k = 128; generate_source_matrix(dim_m, dim_n, dim_k); generate_reference_matmul(dim_m, dim_n, dim_k);