flash: Add flag in SMEM for dependency check on O

TODO: results unverified.
Stalls O rescale until GEMM II finishes.
This commit is contained in:
Hansung Kim
2024-09-10 13:37:32 -07:00
parent 88760596cb
commit 90e03894fc
2 changed files with 41 additions and 12 deletions

View File

@@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
for (int row_offset = 0; row_offset < B_ROW;
row_offset += warps_in_threadblock) {
const uint32_t row = row_offset + warp_id;
// if the number of warps doesn't exactly divide the number of rows,
// early-exit to prevent out-of-bounds access
// if (row >= B_ROW) {
// // WARNING: the number of barrier calls have to exactly match that in the
// // outside of the branch to prevent stalls!! FIXME better proof this.
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// threadblock_barrier(1, 7);
// continue;
// }
const uint32_t first_thread_offset = B_COL * row;
// rowmax
@@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_end_%=:" ::);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);