From dc746272fb1a36e7508c47c5be572ce78e739f83 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 22:53:35 -0700 Subject: [PATCH] flash: Conditionally enable GEMM II fence code, fix tile_k for DEBUG --- .../flash_attention/kernel.gemmini.cpp | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 35a8cdf6..51993b21 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -8,7 +8,7 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -438,9 +438,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); +#ifdef FENCE_GEMM_II // signal that GEMM II is finished to O rescale step *smem_O_flag = 1; vx_fence(); +#endif // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) @@ -540,8 +542,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t tile_k_ = tile_k - 1; if constexpr (DEBUG) { - // verify S = Q*K + gemmini_fence(); + gemmini_fence(); + // verify S = Q*K if (warpgroup_id == 0) { if (tile_k_ == 0) { thread_block_copy_tile( @@ -588,6 +592,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } +#ifdef FENCE_GEMM_II // check flag to make sure GEMM II finished and read-after-write // dependency on O tile is settled for rescale if (tid_in_warpgroup_simt == 0) { @@ -597,6 +602,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { *smem_O_flag = 0; vx_fence(); } +#endif #if 0 if (tid_in_warpgroup == 0) { @@ -612,15 +618,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { #endif if constexpr (DEBUG) { - // gemmini_fence(); - if (warpgroup_id == 0) { + gemmini_fence(); + gemmini_fence(); + // O after PV - if (tile_k_ == 0) { + if (tile_k_ == 1 /*wait until GEMM II finshes */) { thread_block_copy_tile( smem_O, gmem_tmp_d6, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k_ == 1) { + } else if (tile_k_ == 2) { thread_block_copy_tile( smem_O, gmem_tmp_d7, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt);