flash: Warp-specialize between warp 0 and 1-7

Finishes without stalls; No dependency check between O rescale and
GEMM-II.
This commit is contained in:
Hansung Kim
2024-09-09 16:42:30 -07:00
parent d31c8ffd7d
commit b652e25945
2 changed files with 261 additions and 242 deletions

View File

@@ -236,8 +236,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_max;
// sync writes to warp_smem
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
// #define PARALLEL_ROWMAX
#ifndef PARALLEL_ROWMAX
@@ -287,8 +288,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
#endif // PARALLEL_ROWMAX
#endif // DUMB_ROWMAX
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
// broadcast prev rowmax to all threads in the warp
// NOTE: memory consistency is a little sketchy here
@@ -331,8 +333,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_end_%=:" ::);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
// rowsum
//
@@ -358,8 +361,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_sum;
// sync writes to warp_smem
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
// 0-th thread collects all other thread's values in the warp
if (tid_in_warp == 0) {
@@ -387,8 +391,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rowsum_end_%=:" ::);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
// compute Oi rescale factor
// FIXME: parallelize this across threads
@@ -412,8 +417,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rescale_factor_end_%=:" ::);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
}
asm volatile("thread_block_online_softmax_finish_%=:" ::);