From 1c9b02215652a9d4c9c59556a9b28c822867afdf Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 9 Nov 2024 19:58:45 -0800 Subject: [PATCH] flash: Rename nowarpspec to default --- tests/regression/flash_attention/Makefile | 4 +- .../regression/flash_attention/flash_impl.hpp | 1 + .../flash_attention/kernel.gemmini.cpp | 152 ++++++++---------- ...rpspec.cpp => kernel.gemmini.warpspec.cpp} | 152 ++++++++++-------- 4 files changed, 155 insertions(+), 154 deletions(-) rename tests/regression/flash_attention/{kernel.gemmini.nowarpspec.cpp => kernel.gemmini.warpspec.cpp} (87%) diff --git a/tests/regression/flash_attention/Makefile b/tests/regression/flash_attention/Makefile index 2f2c1c6a..3a25e4f3 100644 --- a/tests/regression/flash_attention/Makefile +++ b/tests/regression/flash_attention/Makefile @@ -3,8 +3,8 @@ PROJECT = flash_attention SRCS = main.cpp common.h # VX_SRCS = kernel.cpp -# VX_SRCS = kernel.gemmini.cpp -VX_SRCS = kernel.gemmini.nowarpspec.cpp +# VX_SRCS = kernel.gemmini.warpspec.cpp +VX_SRCS = kernel.gemmini.cpp VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 096333e5..6f210b75 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -15,6 +15,7 @@ constexpr uint32_t ROWMAX_SETS = 3; // constexpr bool WARP_SPECIALIZED = true; +// constexpr bool GEMMINI_WARP_SPECIALIZED = false; // constexpr bool TENSOR_CORE = true; constexpr bool WARP_SPECIALIZED = false; constexpr bool GEMMINI_WARP_SPECIALIZED = false; diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index f2a2e471..79079811 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -342,16 +342,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } - constexpr uint32_t threads_per_warpgroup_simt = - threads_per_warpgroup - - CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; - constexpr uint32_t warpgroup_id_simt = 1; - constexpr uint32_t barrier_id_simt = 1; - constexpr uint32_t barrier_count_simt = NUM_WARPS - 1; - const uint32_t tid_in_warpgroup_simt = - tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS); - static_assert(barrier_id_simt == 1 && barrier_count_simt == 7); - asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -411,8 +401,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile asm volatile ("dbuf_sel_end_%=:" :: ); - if (vx_warp_id() == 0 /* warp 0 in every core */) { - if (tile_k >= 2) // delay by 2 iters for pipelining + { + if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining { const uint32_t tile_k_ = tile_k - 2; @@ -457,16 +447,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { - // fence to GEMM II completion - gemmini_fence(); + // FIXME: remove + // // fence to GEMM II completion + // gemmini_fence(); -#ifdef FENCE_GEMM_II - asm volatile("rescale_fence_write_start_%=:" ::); - // signal that GEMM II is finished to O rescale step - *smem_O_flag = 1; - vx_fence(); - asm volatile("rescale_fence_write_end_%=:" ::); -#endif +// #ifdef FENCE_GEMM_II +// asm volatile("rescale_fence_write_start_%=:" ::); +// // signal that GEMM II is finished to O rescale step +// *smem_O_flag = 1; +// vx_fence(); +// asm volatile("rescale_fence_write_end_%=:" ::); +// #endif // Kick off GEMM I // @@ -499,14 +490,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); -#if 0 - // fence mvout S to SMEM - gemmini_fence(); - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), - (uint64_t)(gmem_V_tile), - k_LOOP_WS_CONFIG_ADDRS_AB) -#endif - // do DMA if (tile_k == 0) { // // configure address strides for the DMA @@ -545,24 +528,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif } - - // fence everything before going to the next tile - gemmini_fence(); } - // threadblock_barrier(warpgroup_id_in_cluster, - // warps_per_warpgroup_per_core); + // reconverge from mmio divergence + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); asm volatile("move_k_v_finish_%=:" ::); + // FIXME: remove for nowarpspec + // // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the // branch and call this barrier earlier than when thread 0 finishes. // Since tmask is not considered, that will be a barrier resolve done too // early // threadblock_barrier(0, 1); + } - } else /* warp_id != 0 */ { - + { if (tile_k >= 1) // delay online softmax by 1 iters { const uint32_t tile_k_ = tile_k - 1; @@ -572,46 +555,49 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( - smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_S_consume, gmem_tmp_d0, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_S_consume, gmem_tmp_d1, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id); } - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } } // Online softmax // thread_block_online_softmax( - smem_S_consume, smem_P_produce, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad, + smem_S_consume, smem_P_produce, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id, smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_rowmax( - smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); thread_block_copy_rowmax( - smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); } else if (tile_k_ == 1) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, - tid_in_warpgroup_simt, threads_per_warpgroup_simt, - warpgroup_id_simt); + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, - tid_in_warpgroup_simt, threads_per_warpgroup_simt, - warpgroup_id_simt); + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } } @@ -619,7 +605,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("rescale_fence_read_start_%=:" ::); // check flag to make sure GEMM II finished and read-after-write // dependency on O tile is settled for rescale - if (tid_in_warpgroup_simt == 0) { + if (tid_in_warpgroup == 0) { while ((*smem_O_flag) != 1) ; // set it back to 0 for the next tile iteration @@ -643,74 +629,66 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif if constexpr (DEBUG) { - if (warpgroup_id == 0) { + if (warpgroup_id_in_cluster == 0) { gemmini_fence(); gemmini_fence(); // O after PV if (tile_k_ == 1 /*wait until GEMM II finshes */) { thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, - warpgroup_id_simt); + smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k_ == 2) { thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, - warpgroup_id_simt); + smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } } // Oi rescale thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, - tid_in_warpgroup_simt, threads_per_warpgroup_simt, - warpgroup_id_simt); + smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); // rescale-to-PV-GEMM barrier - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); if constexpr (DEBUG) { - if (warpgroup_id == 0) { + if (warpgroup_id_in_cluster == 0) { // O before PV if (tile_k_ == 0) { thread_block_copy_tile( - smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup_simt, - threads_per_warpgroup_simt, warpgroup_id_simt); + smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); } - threadblock_barrier(barrier_id_simt, barrier_count_simt); + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); } } } -#if 0 - // fence GEMM I after Oi rescale - if (tid_in_warpgroup == 0) { - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - gemmini_fence(); - } - - // reconverge from mmio divergence + // intra-warpgroup barrier threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); -#endif - // intra-warpgroup barrier - threadblock_barrier(barrier_id_simt, barrier_count_simt); + // fence everything before going to the next tile + gemmini_fence(); } } diff --git a/tests/regression/flash_attention/kernel.gemmini.nowarpspec.cpp b/tests/regression/flash_attention/kernel.gemmini.warpspec.cpp similarity index 87% rename from tests/regression/flash_attention/kernel.gemmini.nowarpspec.cpp rename to tests/regression/flash_attention/kernel.gemmini.warpspec.cpp index 79079811..f2a2e471 100644 --- a/tests/regression/flash_attention/kernel.gemmini.nowarpspec.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.warpspec.cpp @@ -342,6 +342,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); // } + constexpr uint32_t threads_per_warpgroup_simt = + threads_per_warpgroup - + CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/; + constexpr uint32_t warpgroup_id_simt = 1; + constexpr uint32_t barrier_id_simt = 1; + constexpr uint32_t barrier_count_simt = NUM_WARPS - 1; + const uint32_t tid_in_warpgroup_simt = + tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS); + static_assert(barrier_id_simt == 1 && barrier_count_simt == 7); + asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -401,8 +411,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile asm volatile ("dbuf_sel_end_%=:" :: ); - { - if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining + if (vx_warp_id() == 0 /* warp 0 in every core */) { + if (tile_k >= 2) // delay by 2 iters for pipelining { const uint32_t tile_k_ = tile_k - 2; @@ -447,17 +457,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { - // FIXME: remove - // // fence to GEMM II completion - // gemmini_fence(); + // fence to GEMM II completion + gemmini_fence(); -// #ifdef FENCE_GEMM_II -// asm volatile("rescale_fence_write_start_%=:" ::); -// // signal that GEMM II is finished to O rescale step -// *smem_O_flag = 1; -// vx_fence(); -// asm volatile("rescale_fence_write_end_%=:" ::); -// #endif +#ifdef FENCE_GEMM_II + asm volatile("rescale_fence_write_start_%=:" ::); + // signal that GEMM II is finished to O rescale step + *smem_O_flag = 1; + vx_fence(); + asm volatile("rescale_fence_write_end_%=:" ::); +#endif // Kick off GEMM I // @@ -490,6 +499,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); +#if 0 + // fence mvout S to SMEM + gemmini_fence(); + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), + (uint64_t)(gmem_V_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) +#endif + // do DMA if (tile_k == 0) { // // configure address strides for the DMA @@ -528,24 +545,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); #endif } + + // fence everything before going to the next tile + gemmini_fence(); } - // reconverge from mmio divergence - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + // threadblock_barrier(warpgroup_id_in_cluster, + // warps_per_warpgroup_per_core); asm volatile("move_k_v_finish_%=:" ::); - // FIXME: remove for nowarpspec - // // NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the // branch and call this barrier earlier than when thread 0 finishes. // Since tmask is not considered, that will be a barrier resolve done too // early // threadblock_barrier(0, 1); - } - { + } else /* warp_id != 0 */ { + if (tile_k >= 1) // delay online softmax by 1 iters { const uint32_t tile_k_ = tile_k - 1; @@ -555,49 +572,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( - smem_S_consume, gmem_tmp_d0, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id); + smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_S_consume, gmem_tmp_d1, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id); + smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); } } // Online softmax // thread_block_online_softmax( - smem_S_consume, smem_P_produce, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id, smem_scratchpad, + smem_S_consume, smem_P_produce, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); if constexpr (DEBUG) { if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_rowmax( - smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); thread_block_copy_rowmax( - smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); } } @@ -605,7 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("rescale_fence_read_start_%=:" ::); // check flag to make sure GEMM II finished and read-after-write // dependency on O tile is settled for rescale - if (tid_in_warpgroup == 0) { + if (tid_in_warpgroup_simt == 0) { while ((*smem_O_flag) != 1) ; // set it back to 0 for the next tile iteration @@ -629,66 +643,74 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif if constexpr (DEBUG) { - if (warpgroup_id_in_cluster == 0) { + if (warpgroup_id == 0) { gemmini_fence(); gemmini_fence(); // O after PV if (tile_k_ == 1 /*wait until GEMM II finshes */) { thread_block_copy_tile( - smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); } else if (tile_k_ == 2) { thread_block_copy_tile( - smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); } } // Oi rescale thread_block_O_rescale( - smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + smem_O, smem_O /*in-place*/, smem_O_row_scale, + tid_in_warpgroup_simt, threads_per_warpgroup_simt, + warpgroup_id_simt); // rescale-to-PV-GEMM barrier - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); if constexpr (DEBUG) { - if (warpgroup_id_in_cluster == 0) { + if (warpgroup_id == 0) { // O before PV if (tile_k_ == 0) { thread_block_copy_tile( - smem_P_produce, gmem_tmp_d2, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); thread_block_copy_tile( - smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + smem_O, gmem_tmp_d4, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } else if (tile_k_ == 1) { thread_block_copy_tile( - smem_P_produce, gmem_tmp_d3, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster); + smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); thread_block_copy_tile( - smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + smem_O, gmem_tmp_d5, tid_in_warpgroup_simt, + threads_per_warpgroup_simt, warpgroup_id_simt); } - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); + threadblock_barrier(barrier_id_simt, barrier_count_simt); } } } - // intra-warpgroup barrier +#if 0 + // fence GEMM I after Oi rescale + if (tid_in_warpgroup == 0) { + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + gemmini_fence(); + } + + // reconverge from mmio divergence threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); +#endif - // fence everything before going to the next tile - gemmini_fence(); + // intra-warpgroup barrier + threadblock_barrier(barrier_id_simt, barrier_count_simt); } }