From 986d5072239cd565123794976fb25b8c8aa6311b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Fri, 30 Aug 2024 17:12:46 -0700 Subject: [PATCH] flash: Fix single-tile GEMM for warp-specialized With 4 warps, we can only do 32x64 GEMM; serialize 64x64 into 2 32x64 GEMM calls by split by the row. --- tests/regression/flash_attention/kernel.cpp | 65 ++++++++++++++++----- 1 file changed, 49 insertions(+), 16 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 75b20803..87f5749a 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -635,22 +635,55 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - thread_block_gemm_single_tile( - smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); - // FIXME: wrong but fast - // thread_block_gemm_single_tile( - // smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, - // threads_per_threadblock, threadblocks_per_cluster, - // threadblock_id_in_cluster); + if constexpr (!DOUBLE_BUF) { + thread_block_gemm_single_tile( + smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + + // FIXME: wrong but fast + // thread_block_gemm_single_tile( + // smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock, + // threads_per_threadblock, threadblocks_per_cluster, + // threadblock_id_in_cluster); + } else { + // when warp-specialized, there's only enough warps to do 64x32 tile size + // so we need to do 2 GEMM calls + static_assert(B_ROW / 2 == 32, + "tile size assumption for warp-specialization not met"); + + // assumes smem_P is K-major + float *smem_P0 = smem_P; + float *smem_P1 = smem_P + (B_ROW / 2) * B_COL; + float *smem_O0 = smem_O; + float *smem_O1 = smem_O + (B_ROW / 2) * HEADDIM; + + // split by rows into 2 chunks + thread_block_gemm_single_tile( + smem_P0, smem_V, smem_O0 /*load accum*/, smem_O0, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + + thread_block_gemm_single_tile( + smem_P1, smem_V, smem_O1 /*load accum*/, smem_O1, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + } threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core);