tensor: Do DMA mvin for next m/n loop at the last k iter

This increases utilization by pulling the DMA wait time out of the K-loop
wraparound (the start of the next N iteration) and overlapping it with the
last K iteration.
This commit is contained in:
Hansung Kim
2024-10-29 19:43:22 -07:00
parent 367fa927f8
commit 8dadbdd42d

View File

@@ -1150,19 +1150,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
if constexpr (GEMMINI_DMA) {
// pipeline initiation
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
if (block_m == 0 && block_n == 0) {
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK),
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
#if 0
// sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS
@@ -1181,10 +1182,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
gemmini_fence();
#endif
}
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
#pragma GCC unroll 1
@@ -1197,12 +1199,27 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// this is either done using DMA or SIMT cores depending on GEMMINI_DMA
#if (GEMMINI_DMA == 1)
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
if (tid_in_threadblock == 0) {
asm volatile("next_index_start_%=:" ::);
const uint32_t next_block_k =
((block_k + 1) * BK == dim_k) ? 0 : block_k + 1;
const uint32_t next_block_n =
(next_block_k == 0)
? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1)
: block_n;
const uint32_t next_block_m =
(next_block_n == 0)
? ((block_m == block_m_end) ? 0 : block_n + 1)
: block_m;
asm volatile("next_index_end_%=:" ::);
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
(uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK),
(uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
@@ -1210,6 +1227,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
//
// FIXME: This depends on (dim_k / BK) being an even number, since
// the last iteration of the k-loop is prefetching for the first
// iteration of the n-loop. The ping-pong indexing has to match for
// the two loop ends to connect.
const uint32_t opcode = 11 - (block_k & 1);
GEMMINI_CISC_CMD_I(opcode);
// // TODO: branch is probably slow
@@ -1349,6 +1371,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
if constexpr (write_to_gmem) {
asm volatile("move_out_start_%=:" ::);
if constexpr (TENSOR_HOPPER) {
// wait until all results are accumulated into the RF
vx_wgmma_wait();
@@ -1367,6 +1391,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
}
}
asm volatile("move_out_end_%=:" ::);
}
}
asm volatile("loop_mn_end_%=:" ::);