flash: Add flag in SMEM for dependency check on O
TODO: results unverified. Stalls O rescale until GEMM II finishes.
This commit is contained in:
@@ -176,6 +176,19 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
for (int row_offset = 0; row_offset < B_ROW;
|
||||
row_offset += warps_in_threadblock) {
|
||||
const uint32_t row = row_offset + warp_id;
|
||||
// if the number of warps doesn't exactly divide the number of rows,
|
||||
// early-exit to prevent out-of-bounds access
|
||||
// if (row >= B_ROW) {
|
||||
// // WARNING: the number of barrier calls have to exactly match that in the
|
||||
// // outside of the branch to prevent stalls!! FIXME better proof this.
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// threadblock_barrier(1, 7);
|
||||
// continue;
|
||||
// }
|
||||
const uint32_t first_thread_offset = B_COL * row;
|
||||
|
||||
// rowmax
|
||||
@@ -334,7 +347,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
|
||||
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
Reference in New Issue
Block a user