flash: Add early return for warp-indivisible row iter
This commit is contained in:
@@ -8,6 +8,8 @@
|
||||
#define B_COL 64
|
||||
#define HEADDIM 64
|
||||
|
||||
#define ROW_REMAINDER_LOGIC
|
||||
|
||||
constexpr uint32_t ROWMAX_SETS = 3;
|
||||
constexpr bool WARP_SPECIALIZED = false;
|
||||
|
||||
@@ -56,6 +58,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
||||
for (int row_offset = 0; row_offset < B_COL;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
#ifdef ROW_REMAINDER_LOGIC
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls has to exactly match that in the
|
||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
const float one = 0.0f;
|
||||
@@ -114,6 +124,14 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
||||
for (int row_offset = 0; row_offset < dim_row;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
#ifdef ROW_REMAINDER_LOGIC
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls has to exactly match that in the
|
||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||
threadblock_barrier(1, 7);
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||
#pragma GCC unroll
|
||||
@@ -176,19 +194,21 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
#ifdef ROW_REMAINDER_LOGIC
|
||||
// if the number of warps doesn't exactly divide the number of rows,
|
||||
// early-exit to prevent out-of-bounds access
|
||||
// if (row >= B_ROW) {
|
||||
// // WARNING: the number of barrier calls have to exactly match that in the
|
||||
// // outside of the branch to prevent stalls!! FIXME better proof this.
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// continue;
|
||||
// }
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls has to exactly match that in the
|
||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
threadblock_barrier(1, 7);
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
// rowmax
|
||||
@@ -456,6 +476,14 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
#ifdef ROW_REMAINDER_LOGIC
|
||||
if (row >= B_ROW) {
|
||||
// WARNING: the number of barrier calls has to exactly match that in the
|
||||
// outside of the branch to prevent stalls!! FIXME: prove this more rigorously.
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||
|
||||
// Oi rescale
|
||||
@@ -474,6 +502,9 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
||||
}
|
||||
}
|
||||
|
||||
// reconverge after warp divergence
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user