From aea257349ad2b448368027c219647457fecafde1 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 1 Sep 2024 20:40:26 -0700 Subject: [PATCH] flash: Correct schedule with inter-warpgroup barriers --- tests/regression/flash_attention/kernel.cpp | 45 +++++++++++++-------- 1 file changed, 29 insertions(+), 16 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 79d314a3..f58f3000 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -601,6 +601,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary + // delay warpgroup 1 by 1 iteration to do ping-pong scheduling + if (warpgroup_id == 1) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } + asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T @@ -636,10 +641,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr bool skip_gemm_qk = true; if constexpr (!skip_gemm_qk) { - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - static_assert(B_ROW == B_COL, "currently only supports square tiles"); // load Q @@ -659,6 +660,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + // clear out accumulators before GEMM + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + // GEMM I: S = Q*K thread_block_gemm_single_tile( + HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V, + tid_in_warpgroup); + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); if constexpr (DEBUG) { @@ -719,17 +732,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } } + // inter-warpgroup barrier before GEMM II + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + // GEMM 
II: O = O + P*V - // V dimension is [seqlen, headdim], stored N(headdim)-major - load_tile_to_smem( - HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k_, - gmem_V, smem_V, tid_in_warpgroup); - - // FIXME: should be removable - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - // Oi rescale thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, tid_in_warpgroup, @@ -769,7 +776,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } if constexpr (!WARP_SPECIALIZED) { - // clear out accumulators + // clear out accumulators before GEMM initialize_accum_regs<0>(); initialize_accum_regs<1>(); @@ -802,7 +809,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_O_half0 = smem_O; float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; - // clear out accumulators + // clear out accumulators before GEMM initialize_accum_regs<0>(); initialize_accum_regs<1>(); @@ -855,6 +862,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile ("tile_loop_finish_%=:" :: ); + + // wait for warpgroup 1 to finish, which called the global barrier before + // entering the loop + if (warpgroup_id == 0) { + threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); + } } int main() {