From 443a37be6ca93f22ddafade18349ccee8bdd617d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 14:56:48 -0700 Subject: [PATCH] sgemm_impl: Add DMA_FAST option; fix dbuf offset for dma --- tests/regression/sgemm_tcore/kernel.cpp | 21 ++++++++++++----- tests/regression/sgemm_tcore/sgemm_impl.hpp | 26 +++++++++++++-------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index bc77ac2a..bb904baf 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -90,13 +90,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_gemm( - (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, - (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, - tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, - sharedmem_per_threadblock); + /*smem_b_dbuf_offset=*/(2 * BM * BK + BK * BN) * sizeof(float_type) +#endif + >((const float_type *)arg->addr_a, + (const float_type *)arg->addr_b, (float *)arg->addr_c, + arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster, + sharedmem_per_threadblock); float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d2e88ace..7ba19992 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -72,8 +72,9 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 0 -#define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 +#define GEMMINI_DMA 1 +#define GEMMINI_DMA_FAST 1 +#define GEMMINI_DMA_FLEXIBLE_LAYOUT 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) #define SMEM_ADDR_Q1 ((float * const) 0xff001000) @@ -207,7 +208,7 @@ template inline constexpr std::pair remap_to_gemmini_dma_layout(const uint32_t logical_row, const uint32_t logical_col) { - static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8, + static_assert(!use_dma || DIM == 8, "GEMMINI_DMA layout remapping code only written for DIM == 8"); if constexpr (use_dma) { @@ -915,7 +916,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // pipeline initiation if (tid_in_threadblock == 0) { // configure dma gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), @@ -963,7 +963,6 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, #if (GEMMINI_DMA == 1) if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { // configure dma gmem address to load from - // FIXME: block_k is wrong ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), @@ -976,7 +975,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) const uint32_t opcode = 11 - (block_k & 1); - GEMMINI_CISC_CMD_R(opcode); + GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow // if (block_k & 1) { // GEMMINI_CISC_CMD_I(12); @@ -1061,8 +1060,12 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // local_b_consume = reinterpret_cast( // (mask_odd & reinterpret_cast(local_b_buf)) | // (mask_even & reinterpret_cast(local_b))); - local_a_consume = local_a + (block_k & 1) * (BM * BK); - local_b_consume = local_b + (block_k & 1) * (BK * BN); + local_a_consume = local_a + (block_k & 1) * + (smem_a_dbuf_offset - smem_a_offset) / + sizeof(T); + local_b_consume = local_b + (block_k & 1) * + (smem_b_dbuf_offset - smem_b_offset) / + sizeof(T); } else { // no double-buffering without DMA local_a_consume = local_a; @@ -1071,11 +1074,14 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, asm volatile("dbuf_sel_end_%=:" ::); constexpr MemLayout layout_a = - GEMMINI_DMA ? MemLayout::block_row_major + GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major + : MemLayout::block_row_major) : (TRANSPOSE_AT_CONSUME ? MemLayout::K_major : MemLayout::MN_major); constexpr MemLayout layout_b = - GEMMINI_DMA ? MemLayout::block_row_major : MemLayout::MN_major; + GEMMINI_DMA ? (GEMMINI_DMA_FAST ? MemLayout::MN_major + : MemLayout::block_row_major) + : MemLayout::MN_major; thread_block_gemm_single_tile