From 36eb50060f1e5c14be769b74e6d3f8c4a36bbdfb Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Mon, 28 Oct 2024 12:47:20 -0700
Subject: [PATCH] sgemm_impl: Add skeleton wgmma routine for single_tile

---
 tests/regression/sgemm_tcore/sgemm_impl.hpp | 101 ++++++++++++--------
 1 file changed, 62 insertions(+), 39 deletions(-)

diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp
index b3aacf48..626b2f52 100644
--- a/tests/regression/sgemm_tcore/sgemm_impl.hpp
+++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp
@@ -233,6 +233,18 @@ inline void vx_wmma(const int dest_reg) {
   }
 }
 
+inline void vx_wgmma() {
+  // .insn r opcode6, func3, func7, rd, rs1, rs2
+  // https://www.rowleydownload.co.uk/arm/documentation/gnu/as/RISC_002dV_002dFormats.html#RISC_002dV_002dFormats
+  asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
+}
+
+inline void vx_wgmma_wait() {
+  // .insn r opcode6, func3, func7, rd, rs1, rs2
+  // func3 == 1 encodes wait
+  asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
+}
+
 // Remap logical row/col coordinate of a matrix element to a memory index that
 // follows the 2-level block-row-major layout that Gemmini DMA uses
 template
@@ -779,15 +791,15 @@ template
 __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
     const T *local_a, const T *local_b, const T *local_c, T *result_addr,
@@ -804,44 +816,55 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
   const uint32_t warps_per_threadblock_per_core =
       NUM_WARPS / threadblocks_per_cluster;
 
-  // TODO: it would be useful if this bit is split out into a function, so that
-  // preloading accumulation tile can be used for full GEMMs at the start of
-  // the K-loop.
-  if constexpr (load_accum) {
-#pragma GCC unroll
-    for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
-#pragma GCC unroll
-      for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
-        wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
-                        tile_dim_n, local_c);
+  if constexpr (TENSOR_HOPPER) {
+#pragma GCC unroll 1
+    for (int i = 0; i < BK_LOOP; i++) {
+#pragma GCC unroll 4
+      for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
+        // FIXME: use local_a and local_b here
+        vx_wgmma();
+      }
+    }
+  } else {
+    // TODO: it would be useful if this bit is split out into a function, so
+    // that preloading accumulation tile can be used for full GEMMs at the start
+    // of the K-loop.
+    if constexpr (load_accum) {
+#pragma GCC unroll
+      for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
+#pragma GCC unroll
+        for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
+          wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
+                          tile_dim_n, local_c);
+        }
       }
     }
-  }
 #pragma GCC unroll 1
-  for (int i = 0; i < BK_LOOP; i++) {
+    for (int i = 0; i < BK_LOOP; i++) {
 #pragma GCC unroll 4
-    for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
+      for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
 #pragma GCC unroll 2
-      for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
-        // SMEM -> RF
-        static_assert(leading_dim_b == 0,
-                      "leading_dim for wmma_load_b is not implemented yet");
-        wmma_load_b(
-            local_b, local_k, warp_col, wn_iter, tid_in_warp);
-#pragma GCC unroll 2
-        for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
+        for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
           // SMEM -> RF
-          if constexpr (leading_dim_a == 0) {
-            wmma_load_a(
-                local_a, local_k, warp_row, wm_iter, tid_in_warp);
-          } else {
-            wmma_load_a(local_a, local_k, warp_row,
-                        wm_iter, tid_in_warp);
+          static_assert(leading_dim_b == 0,
+                        "leading_dim for wmma_load_b is not implemented yet");
+          wmma_load_b(
+              local_b, local_k, warp_col, wn_iter, tid_in_warp);
+#pragma
GCC unroll 2
+          for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
+            // SMEM -> RF
+            if constexpr (leading_dim_a == 0) {
+              wmma_load_a(
+                  local_a, local_k, warp_row, wm_iter, tid_in_warp);
+            } else {
+              wmma_load_a(
+                  local_a, local_k, warp_row, wm_iter, tid_in_warp);
+            }
+            // perform mma
+            vx_wmma(wm_iter);
           }
-          // perform mma
-          vx_wmma(wm_iter);
         }
       }
     }