diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index d5c553d8..9e36bf83 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -184,10 +184,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary - // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - if (WARP_SPECIALIZED && warpgroup_id == 1) { - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - } + // // delay warpgroup 0 by 1 iteration to do ping-pong scheduling + // if (WARP_SPECIALIZED && warpgroup_id == 1) { + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // } static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, "DMA code assumes Q matrix is stored K-major"); @@ -196,6 +196,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/1); + constexpr uint32_t skips_only_a = + loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/1, /*skip_ldd=*/1, + /*skip_ex=*/1, /*skip_stc=*/1); constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); @@ -248,6 +251,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #define GEMMINI_DMA_CISC #ifdef GEMMINI_DMA_CISC + // the target addresses of this should match with spad_addr_Q0 and + // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); gemmini_fence(); #else @@ -292,15 +297,30 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; + tile_k++) { + if constexpr (DEBUG) { + // barrier for debugging + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + // select the correct double buffer by tile iteration - // FIXME do correct double buffering - float *smem_Q = (tile_k & 1) ? smem_Q1 : smem_Q0; - float *smem_K = (tile_k & 1) ? smem_K1 : smem_K0; - float *smem_V = (tile_k & 1) ? smem_V1 : smem_V0; - float *smem_S = (tile_k & 1) ? smem_S1 : smem_S0; - float *smem_P = (tile_k & 1) ? smem_P1 : smem_P0; - float *smem_O = (tile_k & 1) ? smem_O1 : smem_O0; + // all iterations work on the same Q row tile; no ping-pong necessary + asm volatile ("dbuf_sel_start_%=:" :: ); + // FIXME speedup by doing arithmetic + float *smem_Q = smem_Q0; + float *smem_K_consume = (tile_k & 1) ? smem_K1 : smem_K0; + float *smem_K_produce = (tile_k & 1) ? smem_K0 : smem_K1; + float *smem_V_consume = (tile_k & 1) ? smem_V1 : smem_V0; + float *smem_V_produce = (tile_k & 1) ? smem_V0 : smem_V1; + float *smem_S_consume = (tile_k & 1) ? smem_S1 : smem_S0; + float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; + float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; + float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1; + // O tile is sequentially updated at every iteration; no ping-pong + // necessary + float *smem_O = smem_O0; + // FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong float *smem_O_row_scale = (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; @@ -308,28 +328,111 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0; - const auto spad_addr_Q = (tile_k & 1) ? spad_addr_Q1 : spad_addr_Q0; - const auto spad_addr_K = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; - const auto spad_addr_V = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; - const auto spad_addr_S = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; - const auto spad_addr_P = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_Q = spad_addr_Q0; + const auto spad_addr_K_consume = (tile_k & 1) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_K_produce = (tile_k & 1) ? spad_addr_K0 : spad_addr_K1; + const auto spad_addr_V_consume = (tile_k & 1) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_V_produce = (tile_k & 1) ? spad_addr_V0 : spad_addr_V1; + const auto spad_addr_S_consume = (tile_k & 1) ? spad_addr_S1 : spad_addr_S0; + const auto spad_addr_S_produce = (tile_k & 1) ? spad_addr_S0 : spad_addr_S1; + const auto spad_addr_P_consume = (tile_k & 1) ? spad_addr_P1 : spad_addr_P0; + const auto spad_addr_P_produce = (tile_k & 1) ? spad_addr_P0 : spad_addr_P1; const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile + asm volatile ("dbuf_sel_end_%=:" :: ); - // GEMM I: S = Q*K - // - asm volatile("gemm_qk_start_%=:" ::); + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on the + // online softmax result tile from the previous iteration - if (tid_in_warpgroup == 0) { - if (tile_k == 0) { + if (tile_k >= 2) // delay by 2 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 2; + + asm volatile("gemm_pv_start_%=:" ::); + + if (tid_in_warpgroup == 0) { +#if 0 + if (tile_k_ == 0) { gemmini_fence(); GEMMINI_CISC_CMD_I(0); - } else if (tile_k & 1) { + } else if (tile_k_ & 1) { gemmini_fence(); GEMMINI_CISC_CMD_I(2); } else { gemmini_fence(); GEMMINI_CISC_CMD_I(1); } +#else + // do matmul + // among other things, this also configures CONFIG_BOUNDS so that the + // DMA knows the full matrix dimensions + sp_tiled_matmul_full_spad_ws( + spad_addr_P_consume, spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); +#endif + + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + + // mvout to SMEM + // GEMMINI_CISC_CMD_I(9); + sp_tiled_matmul_full_spad_ws( + /*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); + gemmini_fence(); + + if constexpr (DEBUG) { + // for copy-out to GMEM + gemmini_fence(); + } + } + + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + asm volatile("gemm_pv_finish_%=:" ::); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O after PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + } + + // GEMM I: S = Q*K + // + asm volatile("gemm_qk_start_%=:" ::); + + if (tid_in_warpgroup == 0) { + // 0,2,.: opcode 0 (quartile 0/2, no accum) + // 1,3,.: opcode 3 (quartile 1/3, no accum) + const uint32_t opcode = 3 * (tile_k & 1); + gemmini_fence(); + GEMMINI_CISC_CMD_I(opcode); gemmini_fence(); gemmini_fence(); @@ -352,8 +455,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -375,11 +478,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { if (tile_k == 0) { thread_block_copy_tile( - smem_S, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, + smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k == 1) { thread_block_copy_tile( - smem_S, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, + smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -388,39 +491,76 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // inter-warpgroup barrier before online softmax - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + if (tile_k >= 1) // delay by 1 iters for pipelining + { + const uint32_t tile_k_ = tile_k - 1; - // Online softmax - // - thread_block_online_softmax( - smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum, - smem_O_row_scale); + // Online softmax + // + thread_block_online_softmax( + smem_S_consume, smem_P_produce, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster, smem_scratchpad, + smem_rowmax, smem_rowsum, smem_O_row_scale); - // FIXME: unnecessary? - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k_ == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } + } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + // TODO: put a synchronization here with GEMM-II + + // Oi rescale + thread_block_O_rescale( + smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + + // rescale-to-PV-GEMM barrier + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + // O before PV + if (tile_k_ == 0) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + thread_block_copy_tile( + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); + thread_block_copy_tile( + smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } } } @@ -428,171 +568,64 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // // Q stays in SMEM for the entire loop asm volatile("move_k_v_start_%=:" ::); - if constexpr (GEMMINI_DMA) { - // NOTE: Beware of race conditions; with warp specialization, we need to - // make sure below command code to DMA is not executed simultaneously - // from the two warpgroups (which will result in hardware fault). - // Currently the ping-pong scheduling scheme prevents that. - if (tid_in_warpgroup == 0) { - // configure GMEM addresses for K and V tiles - // load K for the next iteration - const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); - // load V for the current iteration - const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * tile_k); - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), - (uint64_t)(gmem_V_tile), - k_LOOP_WS_CONFIG_ADDRS_AB) - // configure address strides for the DMA - // FIXME: unnecessary? - GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | - 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); - gemmini_fence(); - // do DMA + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { + // configure GMEM addresses for K and V tiles + // load K for the next iteration + const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/)); + // load V for the *previous* iteration; this will be consumed 2 + // iterations later + const float *gmem_V_tile = + gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) + // configure address strides for the DMA + // FIXME: unnecessary? + GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + + // do DMA + if (tile_k == 0) { + // we load (k-1)th tile for V; skip V for the 1st iteration, sp_tiled_matmul_full_spad_ws( - spad_addr_K, spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_S, + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, + /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), + /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, + /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_only_a); + } else { + sp_tiled_matmul_full_spad_ws( + spad_addr_K_produce, spad_addr_V_produce, + /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce /*FIXME:bogus*/, /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); - gemmini_fence(); } - } else { - // load K for the next iteration - load_tile_to_smem( - dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, - tid_in_warpgroup); - - // load V for the current iteration - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, - tid_in_warpgroup); + gemmini_fence(); } + + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile("move_k_v_finish_%=:" ::); - - // inter-warpgroup barrier before GEMM II - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - - // Oi rescale - thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); - - // rescale-to-PV-GEMM barrier - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O before PV - if (tile_k == 0) { - thread_block_copy_tile( - smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - - // GEMM II: O = O + P*V - - asm volatile("gemm_pv_start_%=:" ::); - - if (tid_in_warpgroup == 0) { -#if 0 - if (tile_k == 0) { - gemmini_fence(); - GEMMINI_CISC_CMD_I(0); - } else if (tile_k & 1) { - gemmini_fence(); - GEMMINI_CISC_CMD_I(2); - } else { - gemmini_fence(); - GEMMINI_CISC_CMD_I(1); - } -#else - // do matmul - // among other things, this also configures CONFIG_BOUNDS so that the - // DMA knows the full matrix dimensions - sp_tiled_matmul_full_spad_ws( - spad_addr_P, spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); -#endif - - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_P, /*spad_B=*/spad_addr_V, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, - /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL/ DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); - gemmini_fence(); - - if constexpr (DEBUG) { - // for copy-out to GMEM - gemmini_fence(); - } - } - - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - asm volatile("gemm_pv_finish_%=:" ::); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - // O after PV - if (tile_k == 0) { - thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } #if 0 #endif } asm volatile ("tile_loop_finish_%=:" :: ); - // wait for warpgroup 1 to finish, which called the global barrier before - // entering the loop - if (warpgroup_id == 0) { - threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - } + // // wait for warpgroup 1 to finish, which called the global barrier before + // // entering the loop + // if (warpgroup_id == 0) { + // threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // } } int main() { diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index d24c61d6..05692308 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -73,7 +73,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_CONSUME 0 #define GEMMINI_DMA 1 -#define GEMMINI_DMA_FAST 1 +#define GEMMINI_DMA_FAST 0 #define GEMMINI_DMA_FLEXIBLE_LAYOUT 1 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000)