diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index afc37542..13226a76 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -1150,19 +1150,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, if constexpr (GEMMINI_DMA) { // pipeline initiation - if (tid_in_threadblock == 0) { - // configure dma gmem address to load from - ROCC_INSTRUCTION_RS1_RS2( - XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK), - (uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN), - k_LOOP_WS_CONFIG_ADDRS_AB) - // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB - GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); - gemmini_fence(); + if (block_m == 0 && block_n == 0) { + if (tid_in_threadblock == 0) { + // configure dma gmem address to load from + ROCC_INSTRUCTION_RS1_RS2( + XCUSTOM_ACC, + (uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK), + (uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN), + k_LOOP_WS_CONFIG_ADDRS_AB) + // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB + GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); + gemmini_fence(); - GEMMINI_CISC_CMD_I(10); - gemmini_fence(); + GEMMINI_CISC_CMD_I(10); + gemmini_fence(); #if 0 // sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS @@ -1181,10 +1182,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips) gemmini_fence(); #endif - } + } - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } #pragma GCC unroll 1 @@ -1197,12 +1199,27 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // this is either done using DMA or SIMT cores depending on GEMMINI_DMA #if (GEMMINI_DMA == 1) - if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) { + if (tid_in_threadblock == 0) { + asm volatile("next_index_start_%=:" ::); + + const uint32_t next_block_k = + ((block_k + 1) * BK == dim_k) ? 0 : block_k + 1; + const uint32_t next_block_n = + (next_block_k == 0) + ? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1) + : block_n; + const uint32_t next_block_m = + (next_block_n == 0) + ? ((block_m == block_m_end) ? 0 : block_n + 1) + : block_m; + + asm volatile("next_index_end_%=:" ::); + // configure dma gmem address to load from ROCC_INSTRUCTION_RS1_RS2( XCUSTOM_ACC, - (uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK), - (uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN), + (uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK), + (uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN), k_LOOP_WS_CONFIG_ADDRS_AB) // GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8); @@ -1210,6 +1227,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, // block_k is even: opcode 11 (write to local_a_buf) // block_k is odd: opcode 10 (write to local_a) + // + // FIXME: This depends on (dim_k / BK) being an even number, since + // the last iteration of the k-loop is prefetching for the first + // iteration of the n-loop. The ping-poing indexing has to match for + // the two loop end to connect. const uint32_t opcode = 11 - (block_k & 1); GEMMINI_CISC_CMD_I(opcode); // // TODO: branch is probably slow @@ -1349,6 +1371,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } if constexpr (write_to_gmem) { + asm volatile("move_out_start_%=:" ::); + if constexpr (TENSOR_HOPPER) { // wait until all results are accumulated into the RF vx_wgmma_wait(); @@ -1367,6 +1391,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C, } } } + + asm volatile("move_out_end_%=:" ::); } } asm volatile("loop_mn_end_%=:" ::);