flash: Reduce fence calls to improve util

This commit is contained in:
Hansung Kim
2024-11-09 16:44:17 -08:00
parent 6990fcc1e6
commit ad75561efe

View File

@@ -10,7 +10,9 @@
#define FENCE_GEMM_II #define FENCE_GEMM_II
#define GEMMINI_NEW_CISC #define GEMMINI_NEW_CISC 1
static_assert(GEMMINI_NEW_CISC, "NOTE: old non-CISC code is untested; look for "
"any misalignment of fields in ciscArgs.");
constexpr bool DEBUG = false; constexpr bool DEBUG = false;
@@ -282,6 +284,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif #endif
// block until DMA complete
gemmini_fence(); gemmini_fence();
// also move Q to spad_addr_Q1 for the second iteration // also move Q to spad_addr_Q1 for the second iteration
@@ -309,12 +313,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
#endif #endif
gemmini_fence(); // block until DMA complete
gemmini_fence();
gemmini_fence();
gemmini_fence(); gemmini_fence();
// re-configure DMA for K and V load that will later happen in the loop // re-configure DMA for K and V load that will later happen in the loop
// FIXME: not sure necessary with new CISC
//
// GMEM addr stride for K // GMEM addr stride for K
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
MVIN_SCALE_IDENTITY, false, 0); MVIN_SCALE_IDENTITY, false, 0);
@@ -424,9 +428,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// FIXME: perf: prevent GMEM->SMEM load for O tile // FIXME: perf: prevent GMEM->SMEM load for O tile
gemmini_fence(); gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
#ifdef GEMMINI_NEW_CISC #ifdef GEMMINI_NEW_CISC
gemmini_tile_compute</*store_to_spad=*/true>( gemmini_tile_compute</*store_to_spad=*/true>(
spad_hex_P_consume, spad_hex_V_consume, spad_hex_O, spad_hex_P_consume, spad_hex_V_consume, spad_hex_O,
@@ -458,16 +459,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (tid_in_warpgroup == 0) { if (tid_in_warpgroup == 0) {
// fence to GEMM II completion // fence to GEMM II completion
gemmini_fence(); gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
#ifdef FENCE_GEMM_II #ifdef FENCE_GEMM_II
asm volatile("rescale_fence_write_start_%=:" ::);
// signal that GEMM II is finished to O rescale step // signal that GEMM II is finished to O rescale step
*smem_O_flag = 1; *smem_O_flag = 1;
vx_fence(); vx_fence();
asm volatile("rescale_fence_write_end_%=:" ::);
#endif #endif
// Kick off GEMM I
//
// 0,2,.: opcode 0 (quartile 0/2, no accum) // 0,2,.: opcode 0 (quartile 0/2, no accum)
// 1,3,.: opcode 3 (quartile 1/3, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum)
// const uint32_t opcode = 3 * (tile_k & 1); // const uint32_t opcode = 3 * (tile_k & 1);
@@ -485,10 +487,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
#endif #endif
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
// gemmini_fence();
asm volatile("gemm_qk_finish_%=:" ::); asm volatile("gemm_qk_finish_%=:" ::);
// data move for K and V // data move for K and V
@@ -511,7 +509,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
(uint64_t)(gmem_V_tile), (uint64_t)(gmem_V_tile),
k_LOOP_WS_CONFIG_ADDRS_AB) k_LOOP_WS_CONFIG_ADDRS_AB)
#endif #endif
// gemmini_fence();
// do DMA // do DMA
if (tile_k == 0) { if (tile_k == 0) {
@@ -554,9 +551,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// fence everything before going to the next tile // fence everything before going to the next tile
gemmini_fence(); gemmini_fence();
gemmini_fence();
gemmini_fence();
gemmini_fence();
} }
// threadblock_barrier(warpgroup_id_in_cluster, // threadblock_barrier(warpgroup_id_in_cluster,
@@ -625,6 +619,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
} }
#ifdef FENCE_GEMM_II #ifdef FENCE_GEMM_II
asm volatile("rescale_fence_read_start_%=:" ::);
// check flag to make sure GEMM II finished and read-after-write // check flag to make sure GEMM II finished and read-after-write
// dependency on O tile is settled for rescale // dependency on O tile is settled for rescale
if (tid_in_warpgroup_simt == 0) { if (tid_in_warpgroup_simt == 0) {
@@ -634,6 +629,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
*smem_O_flag = 0; *smem_O_flag = 0;
vx_fence(); vx_fence();
} }
asm volatile("rescale_fence_read_end_%=:" ::);
#endif #endif
#if 0 #if 0