From b652e259451e7a51d73cc32a4a03278745ddaa27 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 16:42:30 -0700 Subject: [PATCH] flash: Warp-specialize between warp 0 and 1-7 Finishes without stalls; No dependency check between O rescale and GEMM-II. --- .../regression/flash_attention/flash_impl.hpp | 30 +- .../flash_attention/kernel.gemmini.cpp | 473 +++++++++--------- 2 files changed, 261 insertions(+), 242 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 46a62546..8aac50ab 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -236,8 +236,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // #define PARALLEL_ROWMAX #ifndef PARALLEL_ROWMAX @@ -287,8 +288,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( #endif // PARALLEL_ROWMAX #endif // DUMB_ROWMAX - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // broadcast prev rowmax to all threads in the warp // NOTE: memory consistency is a little sketchy here @@ -331,8 +333,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // rowsum // @@ -358,8 +361,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // 0-th thread collects all other thread's values in the warp if (tid_in_warp == 0) { @@ -387,8 +391,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rowsum_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); // compute Oi rescale factor // FIXME: parallelize this across threads @@ -412,8 +417,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rescale_factor_end_%=:" ::); - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(1, 7); } asm volatile("thread_block_online_softmax_finish_%=:" ::); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index b943ef0f..9d611ba2 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -58,6 +58,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t warpgroup_id_in_cluster = warpgroup_id % warpgroups_per_cluster; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; + // // warpgroup 0: warp 0 + // // warpgroup 1: warp 1~7 + // const uint32_t warpgroup_id = (warp_id != 0); const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -178,6 +181,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + static_assert(warps_per_threadblock_per_core == NUM_WARPS); static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -301,7 +305,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("dma_move_end_%=:" ::); // protect write to SMEM - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); // if constexpr (DEBUG) { // thread_block_copy_tile(smem_Q0, gmem_tmp_d0, tid_in_warpgroup, @@ -311,6 +316,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + + constexpr uint32_t threads_per_warpgroup_simt = + threads_per_warpgroup - + CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; + constexpr uint32_t warpgroup_id_simt = 1; + constexpr uint32_t barrier_id_simt = 1; + constexpr uint32_t barrier_count_simt = NUM_WARPS - 1; + const uint32_t tid_in_warpgroup_simt = + tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS); + static_assert(barrier_id_simt == 1 && barrier_count_simt == 7); + asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -318,8 +335,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { - // barrier for debugging - // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } // select the correct double buffer by tile iteration @@ -360,13 +376,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // This is done *before* GEMM I in the software pipeline, working on the // online softmax result tile from the previous iteration - if (tile_k >= 2) // delay by 2 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 2; + if (vx_warp_id() == 0 /* warp 0 in every core */) { + if (tile_k >= 2) // delay by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; - asm volatile("gemm_pv_start_%=:" ::); + asm volatile("gemm_pv_start_%=:" ::); - if (tid_in_warpgroup == 0) { + if (tid_in_warpgroup == 0) { #if 0 if (tile_k_ == 0) { gemmini_fence(); @@ -379,114 +396,31 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { GEMMINI_CISC_CMD_I(1); } #else - // kickoff matmul - // among other things, this also configures CONFIG_BOUNDS so that the - // DMA knows the full matrix dimensions - // FIXME: perf: prevent GMEM->SMEM load for O tile - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - sp_tiled_matmul_full_spad_ws( - spad_addr_P_consume, spad_addr_V_consume, - /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + // kickoff matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + // FIXME: perf: prevent GMEM->SMEM load for O tile + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + sp_tiled_matmul_full_spad_ws( + spad_addr_P_consume, spad_addr_V_consume, + /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - - asm volatile("gemm_pv_finish_%=:" ::); - - } - - if (tile_k >= 1) // delay by 1 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 1; - - // Online softmax - // - thread_block_online_softmax( - smem_S_consume, smem_P_produce, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster, smem_scratchpad, - smem_rowmax, smem_rowsum, smem_O_row_scale); - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k_ == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); } + + // // reconverge from mmio divergence + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); } - // fence GEMM II to make sure dependency on O tile is settled - if (tid_in_warpgroup == 0) { - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - -#if 0 - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); -#endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - gemmini_fence(); - - if (warpgroup_id == 0) { - // O after PV - if (tile_k_ == 0) { - thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - } - - { // GEMM I: S = Q*K // // kick off asynchronously; fence later @@ -510,6 +444,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + #if 0 // TODO: speed up mvout to SMEM // loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ @@ -523,57 +462,186 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } #endif } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // // reconverge after mmio + // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); asm volatile("gemm_qk_finish_%=:" ::); - } - if (tile_k >= 1) // delay by 1 iters for pipelining - { - const uint32_t tile_k_ = tile_k - 1; + // TODO: put synchronization here with online softmax - // Oi rescale - thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + // data move for K and V + // + // Q stays in SMEM for the entire loop + asm volatile("move_k_v_start_%=:" ::); - // rescale-to-PV-GEMM barrier - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); + // load V for the *previous* iteration; this will be consumed 2 + // iterations later + const float *gmem_V_tile = + gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O before PV - if (tile_k_ == 0) { - thread_block_copy_tile( - smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile( - smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + // fence mvout S to SMEM + gemmini_fence(); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + + // do DMA + if (tile_k == 0) { + // we load (k-1)th tile for V; skip V for the 1st iteration, + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + } else { + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); + } + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); + + asm volatile("move_k_v_finish_%=:" ::); + + // // intra-warpgroup barrier + // // FIXME hardcoded + // threadblock_barrier(0, 1); + + } else /* warp_id != 0 */ { + + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; + + // Online softmax + // + thread_block_online_softmax( + smem_S_consume, smem_P_produce, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, + tid_in_warpgroup_simt, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); } + } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); +#if 0 + // fence GEMM II to make sure dependency on O tile is settled + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); +#endif + + if constexpr (DEBUG) { + // gemmini_fence(); + + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // Oi rescale + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); + + // rescale-to-PV-GEMM barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } } } - } - // fence GEMM I after Oi rescale - if (tid_in_warpgroup == 0) { - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); +#if 0 + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); #if 0 // mvout to SMEM @@ -587,87 +655,32 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); #endif - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); } - } - // data move for K and V - // - // Q stays in SMEM for the entire loop - asm volatile("move_k_v_start_%=:" ::); - - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. - if (tid_in_warpgroup == 0) { - // configure GMEM addresses for K and V tiles - // load K for the next iteration - const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); - // load V for the *previous* iteration; this will be consumed 2 - // iterations later - const float *gmem_V_tile = - gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); - - // fence mvout S to SMEM - gemmini_fence(); - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), - (uint64_t)(gmem_V_tile), - k_LOOP_WS_CONFIG_ADDRS_AB) - // configure address strides for the DMA - // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - - // do DMA - if (tile_k == 0) { - // we load (k-1)th tile for V; skip V for the 1st iteration, - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); - } else { - sp_tiled_matmul_full_spad_ws( - spad_addr_K_produce, spad_addr_V_produce, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - } - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - } - - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("move_k_v_finish_%=:" ::); -#if 0 + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); #endif + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } else if (tile_k == 1) { + thread_block_copy_tile( + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); + } + + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } + } + + // intra-warpgroup barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); + } } asm volatile ("tile_loop_finish_%=:" :: );