flash: Rename nowarpspec to default

This commit is contained in:
Hansung Kim
2024-11-09 19:58:45 -08:00
parent 8fe6d918f2
commit 1c9b022156
4 changed files with 155 additions and 154 deletions

View File

@@ -3,8 +3,8 @@ PROJECT = flash_attention
SRCS = main.cpp common.h
# VX_SRCS = kernel.cpp
# VX_SRCS = kernel.gemmini.cpp
VX_SRCS = kernel.gemmini.nowarpspec.cpp
# VX_SRCS = kernel.gemmini.warpspec.cpp
VX_SRCS = kernel.gemmini.cpp
VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp
OPTS ?= -n16

View File

@@ -15,6 +15,7 @@
constexpr uint32_t ROWMAX_SETS = 3;
// constexpr bool WARP_SPECIALIZED = true;
// constexpr bool GEMMINI_WARP_SPECIALIZED = false;
// constexpr bool TENSOR_CORE = true;
constexpr bool WARP_SPECIALIZED = false;
constexpr bool GEMMINI_WARP_SPECIALIZED = false;

View File

@@ -342,16 +342,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// }
constexpr uint32_t threads_per_warpgroup_simt =
threads_per_warpgroup -
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
constexpr uint32_t warpgroup_id_simt = 1;
constexpr uint32_t barrier_id_simt = 1;
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
const uint32_t tid_in_warpgroup_simt =
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
asm volatile ("tile_loop_start_%=:" :: );
// "inner loop" along the columns of K^T
@@ -411,8 +401,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile
asm volatile ("dbuf_sel_end_%=:" :: );
if (vx_warp_id() == 0 /* warp 0 in every core */) {
if (tile_k >= 2) // delay by 2 iters for pipelining
{
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 2;
@@ -457,16 +447,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// fence to GEMM II completion
gemmini_fence();
// FIXME: remove
// // fence to GEMM II completion
// gemmini_fence();
#ifdef FENCE_GEMM_II
asm volatile("rescale_fence_write_start_%=:" ::);
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
asm volatile("rescale_fence_write_end_%=:" ::);
#endif
// #ifdef FENCE_GEMM_II
// asm volatile("rescale_fence_write_start_%=:" ::);
// // signal that GEMM II is finished to O rescale step
// *smem_O_flag = 1;
// vx_fence();
// asm volatile("rescale_fence_write_end_%=:" ::);
// #endif
// Kick off GEMM I
//
@@ -499,14 +490,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const float *gmem_V_tile =
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
#if 0
// fence mvout S to SMEM
gemmini_fence();
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
#endif
// do DMA
if (tile_k == 0) {
// // configure address strides for the DMA
@@ -545,24 +528,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
}
// fence everything before going to the next tile
gemmini_fence();
}
// threadblock_barrier(warpgroup_id_in_cluster,
// warps_per_warpgroup_per_core);
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
asm volatile("move_k_v_finish_%=:" ::);
// FIXME: remove for nowarpspec
//
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1);
}
} else /* warp_id != 0 */ {
{
if (tile_k >= 1) // delay online softmax by 1 iters
{
const uint32_t tile_k_ = tile_k - 1;
@@ -572,46 +555,49 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
smem_S_consume, smem_P_produce, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id, smem_scratchpad,
smem_rowmax, smem_rowsum, smem_O_row_scale);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_rowmax(
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_copy_rowmax(
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
@@ -619,7 +605,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("rescale_fence_read_start_%=:" ::);
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) {
if (tid_in_warpgroup == 0) {
while ((*smem_O_flag) != 1)
;
// set it back to 0 for the next tile iteration
@@ -643,74 +629,66 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#endif
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (warpgroup_id_in_cluster == 0) {
gemmini_fence();
gemmini_fence();
// O after PV
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 2) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
// Oi rescale
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
// rescale-to-PV-GEMM barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (warpgroup_id_in_cluster == 0) {
// O before PV
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
}
threadblock_barrier(barrier_id_simt, barrier_count_simt);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
}
}
}
#if 0
// fence GEMM I after Oi rescale
if (tid_in_warpgroup == 0) {
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// reconverge from mmio divergence
// intra-warpgroup barrier
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
#endif
// intra-warpgroup barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
// fence everything before going to the next tile
gemmini_fence();
}
}

View File

@@ -342,6 +342,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// }
constexpr uint32_t threads_per_warpgroup_simt =
threads_per_warpgroup -
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
constexpr uint32_t warpgroup_id_simt = 1;
constexpr uint32_t barrier_id_simt = 1;
constexpr uint32_t barrier_count_simt = NUM_WARPS - 1;
const uint32_t tid_in_warpgroup_simt =
tid_in_warpgroup - (CORES_PER_CLUSTER * NUM_THREADS);
static_assert(barrier_id_simt == 1 && barrier_count_simt == 7);
asm volatile ("tile_loop_start_%=:" :: );
// "inner loop" along the columns of K^T
@@ -401,8 +411,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const auto spad_hex_O = smem_O0_hexadecile; // NOTE: there's only single O tile
asm volatile ("dbuf_sel_end_%=:" :: );
{
if (tile_k >= 2) // delay GEMM II by 2 iters for pipelining
if (vx_warp_id() == 0 /* warp 0 in every core */) {
if (tile_k >= 2) // delay by 2 iters for pipelining
{
const uint32_t tile_k_ = tile_k - 2;
@@ -447,17 +457,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("gemm_qk_start_%=:" ::);
if (tid_in_warpgroup == 0) {
// FIXME: remove
// // fence to GEMM II completion
// gemmini_fence();
// fence to GEMM II completion
gemmini_fence();
// #ifdef FENCE_GEMM_II
// asm volatile("rescale_fence_write_start_%=:" ::);
// // signal that GEMM II is finished to O rescale step
// *smem_O_flag = 1;
// vx_fence();
// asm volatile("rescale_fence_write_end_%=:" ::);
// #endif
#ifdef FENCE_GEMM_II
asm volatile("rescale_fence_write_start_%=:" ::);
// signal that GEMM II is finished to O rescale step
*smem_O_flag = 1;
vx_fence();
asm volatile("rescale_fence_write_end_%=:" ::);
#endif
// Kick off GEMM I
//
@@ -490,6 +499,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
const float *gmem_V_tile =
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
#if 0
// fence mvout S to SMEM
gemmini_fence();
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
(uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
#endif
// do DMA
if (tile_k == 0) {
// // configure address strides for the DMA
@@ -528,24 +545,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif
}
// fence everything before going to the next tile
gemmini_fence();
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
// threadblock_barrier(warpgroup_id_in_cluster,
// warps_per_warpgroup_per_core);
asm volatile("move_k_v_finish_%=:" ::);
// FIXME: remove for nowarpspec
//
// NOTE: cannot put barrier here; thread 1-7 in warp 0 will skip the
// branch and call this barrier earlier than when thread 0 finishes.
// Since tmask is not considered, that will be a barrier resolve done too
// early
// threadblock_barrier(0, 1);
}
{
} else /* warp_id != 0 */ {
if (tile_k >= 1) // delay online softmax by 1 iters
{
const uint32_t tile_k_ = tile_k - 1;
@@ -555,49 +572,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id);
smem_S_consume, gmem_tmp_d0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id);
smem_S_consume, gmem_tmp_d1, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Online softmax
//
thread_block_online_softmax</*block_row_major=*/GEMMINI_DMA>(
smem_S_consume, smem_P_produce, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id, smem_scratchpad,
smem_S_consume, smem_P_produce, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt, smem_scratchpad,
smem_rowmax, smem_rowsum, smem_O_row_scale);
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
if constexpr (DEBUG) {
if (warpgroup_id == 0) {
if (tile_k_ == 0) {
thread_block_copy_rowmax(
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
smem_rowmax, gmem_tmp_e0, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_rowmax(
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
smem_rowsum, gmem_tmp_e2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3,
tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
@@ -605,7 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
asm volatile("rescale_fence_read_start_%=:" ::);
// check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale
if (tid_in_warpgroup == 0) {
if (tid_in_warpgroup_simt == 0) {
while ((*smem_O_flag) != 1)
;
// set it back to 0 for the next tile iteration
@@ -629,66 +643,74 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
#endif
if constexpr (DEBUG) {
if (warpgroup_id_in_cluster == 0) {
if (warpgroup_id == 0) {
gemmini_fence();
gemmini_fence();
// O after PV
if (tile_k_ == 1 /*wait until GEMM II finshes */) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
} else if (tile_k_ == 2) {
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
// Oi rescale
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
smem_O, smem_O /*in-place*/, smem_O_row_scale,
tid_in_warpgroup_simt, threads_per_warpgroup_simt,
warpgroup_id_simt);
// rescale-to-PV-GEMM barrier
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
if constexpr (DEBUG) {
if (warpgroup_id_in_cluster == 0) {
if (warpgroup_id == 0) {
// O before PV
if (tile_k_ == 0) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
smem_P_produce, gmem_tmp_d2, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
smem_O, gmem_tmp_d4, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
} else if (tile_k_ == 1) {
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup,
threads_per_warpgroup, warpgroup_id_in_cluster);
smem_P_produce, gmem_tmp_d3, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup,
warpgroup_id_in_cluster);
smem_O, gmem_tmp_d5, tid_in_warpgroup_simt,
threads_per_warpgroup_simt, warpgroup_id_simt);
}
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}
}
// intra-warpgroup barrier
#if 0
// fence GEMM I after Oi rescale
if (tid_in_warpgroup == 0) {
gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
}
// reconverge from mmio divergence
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
#endif
// fence everything before going to the next tile
gemmini_fence();
// intra-warpgroup barrier
threadblock_barrier(barrier_id_simt, barrier_count_simt);
}
}