diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index e018fb25..7e6c7ee8 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -213,11 +213,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) { col += ((tid % 4) / 2) * 2; } +inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) { + const int tg = tid / 2; + + // FIXME wrong!!! + row = 0; + col = tid; +} + inline constexpr void map_c(const int tid, int &row, int &col) { if constexpr (NUM_THREADS == 32) { map_c_32lanes(tid, row, col); } else if constexpr (NUM_THREADS == 8) { - map_c_8lanes(tid, row, col); + if constexpr (TENSOR_HOPPER) { + map_c_8lanes_hopper(tid, row, col); + } else { + map_c_8lanes(tid, row, col); + } } else { // FIXME: not allowed } @@ -233,10 +245,11 @@ inline void vx_wmma(const int dest_reg) { } } -inline void vx_wgmma() { +inline void vx_wgmma(const uint32_t addr_a, const uint32_t addr_b) { // .insn r opcode6, func3, func7, rd, rs1, rs2 // https://www.rowleydownload.co.uk/arm/documentation/gnu/as/RISC_002dV_002dFormats.html#RISC_002dV_002dFormats - asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3)); + asm volatile(".insn r %0, 0, 0, x0, %1, %2" ::"i"(RISCV_CUSTOM3), "r"(addr_a), + "r"(addr_b)); } inline void vx_wgmma_wait() { @@ -503,6 +516,25 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k, asm volatile ("wmma_load_b_finish_%=:" :: ); } +template +inline void wgmma(volatile const T *smem_A, volatile const T *smem_B, + const int local_k, const int warp_row, const int warp_col) { + asm volatile("wgmma_start_%=:" ::); + + const volatile uint8_t *addr_a = + generate_smem_addr_a(smem_A, local_k, + warp_row, 0, 0); + const volatile uint8_t *addr_b = + generate_smem_addr_b( + smem_B, local_k, warp_col, 0, 0); + + vx_wgmma(reinterpret_cast(addr_a), + reinterpret_cast(addr_b)); + + asm volatile("wgmma_finish_%=:" ::); +} + // Initialize the accumulator registers to zero before starting FMA operations // with the tensor cores. template inline void initialize_accum_regs() { @@ -527,6 +559,41 @@ template inline void initialize_accum_regs() { } } +inline void initialize_all_regs() { + asm volatile("fmv.w.x f0, x0"); + asm volatile("fmv.w.x f1, x0"); + asm volatile("fmv.w.x f2, x0"); + asm volatile("fmv.w.x f3, x0"); + asm volatile("fmv.w.x f4, x0"); + asm volatile("fmv.w.x f5, x0"); + asm volatile("fmv.w.x f6, x0"); + asm volatile("fmv.w.x f7, x0"); + asm volatile("fmv.w.x f8, x0"); + asm volatile("fmv.w.x f9, x0"); + asm volatile("fmv.w.x f10, x0"); + asm volatile("fmv.w.x f11, x0"); + asm volatile("fmv.w.x f12, x0"); + asm volatile("fmv.w.x f13, x0"); + asm volatile("fmv.w.x f14, x0"); + asm volatile("fmv.w.x f15, x0"); + asm volatile("fmv.w.x f16, x0"); + asm volatile("fmv.w.x f17, x0"); + asm volatile("fmv.w.x f18, x0"); + asm volatile("fmv.w.x f19, x0"); + asm volatile("fmv.w.x f20, x0"); + asm volatile("fmv.w.x f21, x0"); + asm volatile("fmv.w.x f22, x0"); + asm volatile("fmv.w.x f23, x0"); + asm volatile("fmv.w.x f24, x0"); + asm volatile("fmv.w.x f25, x0"); + asm volatile("fmv.w.x f26, x0"); + asm volatile("fmv.w.x f27, x0"); + asm volatile("fmv.w.x f28, x0"); + asm volatile("fmv.w.x f29, x0"); + asm volatile("fmv.w.x f30, x0"); + asm volatile("fmv.w.x f31, x0"); +} + // `C` is expected to be in N-major layout. __attribute__((always_inline)) inline void wmma_load_accum(const int thread_in_warp, const int warp_col, @@ -622,6 +689,66 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row, asm volatile ("wmma_store_finish_%=:" :: ); } +// Write out the matrix data stored in RF to memory +__attribute__((always_inline)) inline void +wgmma_store(const int thread_in_warp, const int warp_col, const int warp_row, + const int dim_n, float *write_addr) { + asm volatile ("wgmma_store_start_%=:" :: ); + + const int tid = thread_in_warp; + + // these are [0, TCM/TCN) + int tid_row = 0; + int tid_col = 0; + map_c(tid, tid_row, tid_col); + + // FIXME: WM and WN might be swapped here + int local_row = WM * warp_row + tid_row; + int local_col = WN * warp_col + tid_col; + + // FIXME: this is storing in M-major format + volatile uint8_t *addr = reinterpret_cast( + &write_addr[dim_n * (local_row + 0) + (local_col + 0)]); + volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float); + asm volatile("fsw f0, %0(%1)" ::"i"(0 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f1, %0(%1)" ::"i"(1 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f2, %0(%1)" ::"i"(2 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f3, %0(%1)" ::"i"(3 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f4, %0(%1)" ::"i"(4 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f5, %0(%1)" ::"i"(5 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f6, %0(%1)" ::"i"(6 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f7, %0(%1)" ::"i"(7 * WM * sizeof(float)), "r"(addr)); + + asm volatile("fsw f8, %0(%1)" ::"i"( 8 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f9, %0(%1)" ::"i"( 9 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f10, %0(%1)" ::"i"(10 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f11, %0(%1)" ::"i"(11 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f12, %0(%1)" ::"i"(12 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f13, %0(%1)" ::"i"(13 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f14, %0(%1)" ::"i"(14 * WM * sizeof(float)), "r"(addr)); + asm volatile("fsw f15, %0(%1)" ::"i"(15 * WM * sizeof(float)), "r"(addr)); + + asm volatile("fsw f16, %0(%1)" ::"i"((0 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f17, %0(%1)" ::"i"((1 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f18, %0(%1)" ::"i"((2 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f19, %0(%1)" ::"i"((3 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f20, %0(%1)" ::"i"((4 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f21, %0(%1)" ::"i"((5 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f22, %0(%1)" ::"i"((6 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f23, %0(%1)" ::"i"((7 * WM + 8) * sizeof(float)), "r"(addr)); + + asm volatile("fsw f24, %0(%1)" ::"i"(( 8 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f25, %0(%1)" ::"i"(( 9 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f26, %0(%1)" ::"i"((10 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f27, %0(%1)" ::"i"((11 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f28, %0(%1)" ::"i"((12 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f29, %0(%1)" ::"i"((13 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f30, %0(%1)" ::"i"((14 * WM + 8) * sizeof(float)), "r"(addr)); + asm volatile("fsw f31, %0(%1)" ::"i"((15 * WM + 8) * sizeof(float)), "r"(addr)); + + asm volatile ("wgmma_store_finish_%=:" :: ); +} + __attribute__((convergent)) inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) { asm volatile("" ::: "memory"); @@ -859,13 +986,14 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threadblocks_per_cluster; +#if 1 if constexpr (TENSOR_HOPPER) { #pragma GCC unroll 1 for (int i = 0; i < BK_LOOP; i++) { #pragma GCC unroll 4 - for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) { - // FIXME: use local_a and local_b here - vx_wgmma(); + for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) { + wgmma( + local_a, local_b, local_k, warp_row, warp_col); } } } else { @@ -912,6 +1040,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile( } } } +#endif // if constexpr (GEMMINI_DMA) { // // Call gemmini fence at the end of the loop to overlap dma & wmma. @@ -1011,9 +1140,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) { asm volatile ("loop_mn_start_%=:" :: ); - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + if constexpr (TENSOR_HOPPER) { + initialize_all_regs(); + } else { + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + } if constexpr (GEMMINI_DMA) { // pipeline initiation @@ -1216,13 +1349,22 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } if constexpr (write_to_gmem) { + if constexpr (TENSOR_HOPPER) { + // wait until all results are accumulated into the RF + vx_wgmma_wait(); + + float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n; + wgmma_store(tid_in_warp, warp_col, warp_row, dim_n, global_offset_C); + } else { #pragma GCC unroll - for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { + for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) { #pragma GCC unroll - for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { - float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n; - wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n, - global_offset_C); + for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) { + float *global_offset_C = + C + (BM * block_m) * dim_n + BN * block_n; + wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, + dim_n, global_offset_C); + } } } }