flash: Add early return for warp-indivisible row iter
This commit is contained in:
@@ -8,6 +8,8 @@
|
|||||||
#define B_COL 64
|
#define B_COL 64
|
||||||
#define HEADDIM 64
|
#define HEADDIM 64
|
||||||
|
|
||||||
|
#define ROW_REMAINDER_LOGIC
|
||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool WARP_SPECIALIZED = false;
|
constexpr bool WARP_SPECIALIZED = false;
|
||||||
|
|
||||||
@@ -56,6 +58,14 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock,
|
|||||||
for (int row_offset = 0; row_offset < B_COL;
|
for (int row_offset = 0; row_offset < B_COL;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls in this branch has to exactly match the
|
||||||
|
// number outside of the branch to prevent stalls! FIXME: prove this rigorously.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
uint32_t thread_offset = HEADDIM * row + tid_in_warp;
|
||||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||||
const float one = 0.0f;
|
const float one = 0.0f;
|
||||||
@@ -114,6 +124,14 @@ inline void thread_block_copy_tile(const float *src, float *dest,
|
|||||||
for (int row_offset = 0; row_offset < dim_row;
|
for (int row_offset = 0; row_offset < dim_row;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls in this branch has to exactly match the
|
||||||
|
// number outside of the branch to prevent stalls! FIXME: prove this rigorously.
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
constexpr uint32_t per_row_iter = dim_col / NUM_THREADS;
|
||||||
#pragma GCC unroll
|
#pragma GCC unroll
|
||||||
@@ -176,19 +194,21 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
|||||||
for (int row_offset = 0; row_offset < B_ROW;
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
// if the number of warps doesn't exactly divide the number of rows,
|
// if the number of warps doesn't exactly divide the number of rows,
|
||||||
// early-exit to prevent out-of-bounds access
|
// early-exit to prevent out-of-bounds access
|
||||||
// if (row >= B_ROW) {
|
if (row >= B_ROW) {
|
||||||
// // WARNING: the number of barrier calls in this branch has to exactly match the
|
// WARNING: the number of barrier calls in this branch has to exactly match the
|
||||||
// // number outside of the branch to prevent stalls! FIXME: prove this rigorously.
|
// number outside of the branch to prevent stalls! FIXME: prove this rigorously.
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// threadblock_barrier(1, 7);
|
threadblock_barrier(1, 7);
|
||||||
// continue;
|
continue;
|
||||||
// }
|
}
|
||||||
|
#endif
|
||||||
const uint32_t first_thread_offset = B_COL * row;
|
const uint32_t first_thread_offset = B_COL * row;
|
||||||
|
|
||||||
// rowmax
|
// rowmax
|
||||||
@@ -456,6 +476,14 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
|||||||
for (int row_offset = 0; row_offset < B_ROW;
|
for (int row_offset = 0; row_offset < B_ROW;
|
||||||
row_offset += warps_in_threadblock) {
|
row_offset += warps_in_threadblock) {
|
||||||
const uint32_t row = row_offset + warp_id;
|
const uint32_t row = row_offset + warp_id;
|
||||||
|
#ifdef ROW_REMAINDER_LOGIC
|
||||||
|
if (row >= B_ROW) {
|
||||||
|
// WARNING: the number of barrier calls in this branch has to exactly match the
|
||||||
|
// number outside of the branch to prevent stalls! FIXME: prove this rigorously.
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
constexpr uint32_t per_row_iter = HEADDIM / NUM_THREADS;
|
||||||
|
|
||||||
// Oi rescale
|
// Oi rescale
|
||||||
@@ -474,6 +502,9 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// reconverge after warp divergence
|
||||||
|
threadblock_barrier(1, 7);
|
||||||
|
|
||||||
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
asm volatile("thread_block_O_rescale_finish_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user