Fix overlapping smem in rowmax
This commit is contained in:
@@ -43,8 +43,8 @@ inline void thread_block_flashattn(float *S, float *gmem,
|
|||||||
const uint32_t first_thread_offset = Bcol * row;
|
const uint32_t first_thread_offset = Bcol * row;
|
||||||
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
uint32_t thread_offset = first_thread_offset + tid_in_warp;
|
||||||
|
|
||||||
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
|
|
||||||
float curr_max = S[first_thread_offset];
|
float curr_max = S[first_thread_offset];
|
||||||
|
constexpr uint32_t load_iter = Bcol / NUM_THREADS;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
for (int iter = 0; iter < load_iter; iter++) {
|
for (int iter = 0; iter < load_iter; iter++) {
|
||||||
asm volatile("fmax.s %0, %1, %2"
|
asm volatile("fmax.s %0, %1, %2"
|
||||||
@@ -53,7 +53,7 @@ inline void thread_block_flashattn(float *S, float *gmem,
|
|||||||
thread_offset += NUM_THREADS;
|
thread_offset += NUM_THREADS;
|
||||||
}
|
}
|
||||||
// get max value across the same-warp threads using smem
|
// get max value across the same-warp threads using smem
|
||||||
float *warp_smem = S + (row * NUM_THREADS);
|
float *warp_smem = S + (2 * Brow * Bcol) + (row * NUM_THREADS);
|
||||||
warp_smem[tid_in_warp] = curr_max;
|
warp_smem[tid_in_warp] = curr_max;
|
||||||
|
|
||||||
// sync writes to warp_smem
|
// sync writes to warp_smem
|
||||||
|
|||||||
Reference in New Issue
Block a user