tensor: Do DMA mvin for next m/n loop at the last k iter
This increases util by pulling the DMA wait time out of the K-loop wraparound (next N) and overlapping it with the last K iter.
This commit is contained in:
@@ -1150,19 +1150,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// pipeline initiation
|
||||
if (tid_in_threadblock == 0) {
|
||||
// configure dma gmem address to load from
|
||||
ROCC_INSTRUCTION_RS1_RS2(
|
||||
XCUSTOM_ACC,
|
||||
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
|
||||
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||
gemmini_fence();
|
||||
if (block_m == 0 && block_n == 0) {
|
||||
if (tid_in_threadblock == 0) {
|
||||
// configure dma gmem address to load from
|
||||
ROCC_INSTRUCTION_RS1_RS2(
|
||||
XCUSTOM_ACC,
|
||||
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK),
|
||||
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||
gemmini_fence();
|
||||
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
gemmini_fence();
|
||||
|
||||
#if 0
|
||||
// sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS
|
||||
@@ -1181,10 +1182,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
|
||||
gemmini_fence();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma GCC unroll 1
|
||||
@@ -1197,12 +1199,27 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
// this is either done using DMA or SIMT cores depending on GEMMINI_DMA
|
||||
|
||||
#if (GEMMINI_DMA == 1)
|
||||
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
|
||||
if (tid_in_threadblock == 0) {
|
||||
asm volatile("next_index_start_%=:" ::);
|
||||
|
||||
const uint32_t next_block_k =
|
||||
((block_k + 1) * BK == dim_k) ? 0 : block_k + 1;
|
||||
const uint32_t next_block_n =
|
||||
(next_block_k == 0)
|
||||
? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1)
|
||||
: block_n;
|
||||
const uint32_t next_block_m =
|
||||
(next_block_n == 0)
|
||||
? ((block_m == block_m_end) ? 0 : block_n + 1)
|
||||
: block_m;
|
||||
|
||||
asm volatile("next_index_end_%=:" ::);
|
||||
|
||||
// configure dma gmem address to load from
|
||||
ROCC_INSTRUCTION_RS1_RS2(
|
||||
XCUSTOM_ACC,
|
||||
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
||||
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
||||
(uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK),
|
||||
(uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||
@@ -1210,6 +1227,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
|
||||
// block_k is even: opcode 11 (write to local_a_buf)
|
||||
// block_k is odd: opcode 10 (write to local_a)
|
||||
//
|
||||
// FIXME: This depends on (dim_k / BK) being an even number, since
|
||||
// the last iteration of the k-loop is prefetching for the first
|
||||
// iteration of the n-loop. The ping-poing indexing has to match for
|
||||
// the two loop end to connect.
|
||||
const uint32_t opcode = 11 - (block_k & 1);
|
||||
GEMMINI_CISC_CMD_I(opcode);
|
||||
// // TODO: branch is probably slow
|
||||
@@ -1349,6 +1371,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
}
|
||||
|
||||
if constexpr (write_to_gmem) {
|
||||
asm volatile("move_out_start_%=:" ::);
|
||||
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
// wait until all results are accumulated into the RF
|
||||
vx_wgmma_wait();
|
||||
@@ -1367,6 +1391,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
asm volatile("move_out_end_%=:" ::);
|
||||
}
|
||||
}
|
||||
asm volatile("loop_mn_end_%=:" ::);
|
||||
|
||||
Reference in New Issue
Block a user