flash: Warp-specialize between warp 0 and 1-7
Finishes without stalls; No dependency check between O rescale and GEMM-II.
This commit is contained in:
@@ -236,8 +236,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
warp_smem[tid_in_warp] = per_thread_max;
|
||||
|
||||
// sync writes to warp_smem
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
// #define PARALLEL_ROWMAX
|
||||
#ifndef PARALLEL_ROWMAX
|
||||
@@ -287,8 +288,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
#endif // PARALLEL_ROWMAX
|
||||
#endif // DUMB_ROWMAX
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
// broadcast prev rowmax to all threads in the warp
|
||||
// NOTE: memory consistency is a little sketchy here
|
||||
@@ -331,8 +333,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
asm volatile("flashattn_exp_p_end_%=:" ::);
|
||||
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
// rowsum
|
||||
//
|
||||
@@ -358,8 +361,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
warp_smem[tid_in_warp] = per_thread_sum;
|
||||
|
||||
// sync writes to warp_smem
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
// 0-th thread collects all other thread's values in the warp
|
||||
if (tid_in_warp == 0) {
|
||||
@@ -387,8 +391,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_rowsum_end_%=:" ::);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
|
||||
// compute Oi rescale factor
|
||||
// FIXME: parallelize this across threads
|
||||
@@ -412,8 +417,9 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
|
||||
|
||||
asm volatile("flashattn_rescale_factor_end_%=:" ::);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
// threadblock_barrier(threadblock_id_in_cluster,
|
||||
// warps_per_threadblock_per_core);
|
||||
threadblock_barrier(1, 7);
|
||||
}
|
||||
|
||||
asm volatile("thread_block_online_softmax_finish_%=:" ::);
|
||||
|
||||
Reference in New Issue
Block a user