From 90e03894fc69cbf6559f00597a424e84a3a3501d Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 10 Sep 2024 13:37:32 -0700 Subject: [PATCH] flash: Add flag in SMEM for dependency check on O TODO: results unverified. Stalls O rescale until GEMM II finishes. --- .../regression/flash_attention/flash_impl.hpp | 14 ++++++- .../flash_attention/kernel.gemmini.cpp | 39 +++++++++++++------ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index bd4aee9d..eb1a43bb 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; + // if the number of warps doesn't exactly divide the number of rows, + // early-exit to prevent out-of-bounds access + // if (row >= B_ROW) { + // // WARNING: the number of barrier calls have to exactly match that in the + // // outside of the branch to prevent stalls!! FIXME better proof this. + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // threadblock_barrier(1, 7); + // continue; + // } const uint32_t first_thread_offset = B_COL * row; // rowmax @@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, // warps_per_threadblock_per_core); threadblock_barrier(1, 7); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 884762d7..35a8cdf6 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -112,6 +112,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float); constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float); constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float); + // reversed! constexpr uint32_t smem_O0_offset = smem_P1_offset + smem_QK_size * sizeof(float); constexpr uint32_t smem_O1_offset = smem_P0_offset + smem_QK_size * sizeof(float); // unused @@ -158,6 +159,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; smem_cursor += smem_scratchpad_size; + uint32_t *smem_O_flag = reinterpret_cast(smem_cursor); + smem_cursor += 1 /* 4Byte */; static_assert(sizeof(elem_t) == sizeof(float)); constexpr uint32_t spad_addr_factor = DIM * sizeof(elem_t); @@ -332,7 +335,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles + 2 /*pipeline latency*/; + for (uint32_t tile_k = 0; + tile_k < + (1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -371,16 +376,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const auto spad_addr_O = spad_addr_O0; // NOTE: there's only single O tile asm volatile ("dbuf_sel_end_%=:" :: ); - // GEMM II: O = O + P*V - // -------------------- - // This is done *before* GEMM I in the software pipeline, working on the - // online softmax result tile from the previous iteration - if (vx_warp_id() == 0 /* warp 0 in every core */) { if (tile_k >= 2) // delay by 2 iters for pipelining { const uint32_t tile_k_ = tile_k - 2; + // GEMM II: O = O + P*V + // -------------------- + // This is done *before* GEMM I in the software pipeline, working on the + // online softmax result tile from the previous iteration + asm volatile("gemm_pv_start_%=:" ::); if (tid_in_warpgroup == 0) { @@ -427,11 +432,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { asm volatile("gemm_qk_start_%=:" ::); if (tid_in_warpgroup == 0) { + // fence to GEMM II completion gemmini_fence(); gemmini_fence(); gemmini_fence(); gemmini_fence(); + // signal that GEMM II is finished to O rescale step + *smem_O_flag = 1; + vx_fence(); + // 0,2,.: opcode 0 (quartile 0/2, no accum) // 1,3,.: opcode 3 (quartile 1/3, no accum) // const uint32_t opcode = 3 * (tile_k & 1); @@ -448,7 +458,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); - } // // reconverge after mmio // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -534,11 +543,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // verify S = Q*K if (warpgroup_id == 0) { - if (tile_k == 0) { + if (tile_k_ == 0) { thread_block_copy_tile( smem_S_produce, gmem_tmp_d0, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); - } else if (tile_k == 1) { + } else if (tile_k_ == 1) { thread_block_copy_tile( smem_S_produce, gmem_tmp_d1, tid_in_warpgroup_simt, threads_per_warpgroup_simt, warpgroup_id_simt); @@ -579,9 +588,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } - // FIXME: put synchronization with GEMM II here + // check flag to make sure GEMM II finished and read-after-write + // dependency on O tile is settled for rescale + if (tid_in_warpgroup_simt == 0) { + while ((*smem_O_flag) != 1) + ; + // set it back to 0 for the next tile iteration + *smem_O_flag = 0; + vx_fence(); + } + #if 0 - // fence GEMM II to make sure dependency on O tile is settled if (tid_in_warpgroup == 0) { gemmini_fence(); gemmini_fence();