From 18cf0e73cd6ed9a8b14031e4afe9faf0f48cd651 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 00:56:09 -0700 Subject: [PATCH] flash: Add early return for warp-indivisible row iter --- .../regression/flash_attention/flash_impl.hpp | 53 +++++++++++++++---- 1 file changed, 42 insertions(+), 11 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index eb1a43bb..410c5f4f 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -8,6 +8,8 @@ #define B_COL 64 #define HEADDIM 64 +#define ROW_REMAINDER_LOGIC + constexpr uint32_t ROWMAX_SETS = 3; constexpr bool WARP_SPECIALIZED = false; @@ -56,6 +58,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, for (int row_offset = 0; row_offset < B_COL; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls has to exactly match the number + // outside of the branch to prevent stalls! FIXME better proof this. + continue; + } +#endif + uint32_t thread_offset = HEADDIM * row + tid_in_warp; constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; const float one = 0.0f; @@ -114,6 +124,14 @@ inline void thread_block_copy_tile(const float *src, float *dest, for (int row_offset = 0; row_offset < dim_row; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls has to exactly match the number + // outside of the branch to prevent stalls! FIXME better proof this.
+ threadblock_barrier(1, 7); + continue; + } +#endif constexpr uint32_t per_row_iter = dim_col / NUM_THREADS; #pragma GCC unroll @@ -176,19 +194,21 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC // if the number of warps doesn't exactly divide the number of rows, // early-exit to prevent out-of-bounds access - // if (row >= B_ROW) { - // // WARNING: the number of barrier calls have to exactly match that in the - // // outside of the branch to prevent stalls!! FIXME better proof this. - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // threadblock_barrier(1, 7); - // continue; - // } + if (row >= B_ROW) { + // WARNING: the number of barrier calls has to exactly match the number + // outside of the branch to prevent stalls! FIXME better proof this. + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + continue; + } +#endif const uint32_t first_thread_offset = B_COL * row; // rowmax @@ -456,6 +476,14 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( for (int row_offset = 0; row_offset < B_ROW; row_offset += warps_in_threadblock) { const uint32_t row = row_offset + warp_id; +#ifdef ROW_REMAINDER_LOGIC + if (row >= B_ROW) { + // WARNING: the number of barrier calls has to exactly match the number + // outside of the branch to prevent stalls! FIXME better proof this.
+ continue; + } +#endif + constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS; // Oi rescale @@ -474,6 +502,9 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( } } + // reconverge after warp divergence + threadblock_barrier(1, 7); + asm volatile("thread_block_O_rescale_finish_%=:" ::); }