From 9cabe3413b496f6c246ffee8cde271f05b968b6a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 14 Aug 2024 21:09:47 -0700 Subject: [PATCH] Fix overlapping smem in rowmax --- tests/regression/flash_attention/kernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index bc212ade..5a6e71e7 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -43,8 +43,8 @@ inline void thread_block_flashattn(float *S, float *gmem, const uint32_t first_thread_offset = Bcol * row; uint32_t thread_offset = first_thread_offset + tid_in_warp; - constexpr uint32_t load_iter = Bcol / NUM_THREADS; float curr_max = S[first_thread_offset]; + constexpr uint32_t load_iter = Bcol / NUM_THREADS; #pragma GCC unroll for (int iter = 0; iter < load_iter; iter++) { asm volatile("fmax.s %0, %1, %2" @@ -53,7 +53,7 @@ inline void thread_block_flashattn(float *S, float *gmem, thread_offset += NUM_THREADS; } // get max value across the same-warp threads using smem - float *warp_smem = S + (row * NUM_THREADS); + float *warp_smem = S + (2 * Brow * Bcol) + (row * NUM_THREADS); warp_smem[tid_in_warp] = curr_max; // sync writes to warp_smem