flash: Fix hardcoded barrier for tcore; move tcore-specific flags

This commit is contained in:
Hansung Kim
2024-09-11 22:08:06 -07:00
parent d69707f686
commit b5916f3f07
2 changed files with 91 additions and 38 deletions

View File

@@ -11,11 +11,8 @@
#define ROW_REMAINDER_LOGIC
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool WARP_SPECIALIZED = false;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
constexpr bool Q_IS_K_MAJOR = true;
constexpr bool WARP_SPECIALIZED = true;
constexpr bool TENSOR_CORE = true;
// temporary safety stop for wrong configs
static_assert(NUM_CORES == 4);
@@ -99,9 +96,12 @@ inline void thread_block_copy_rowmax(const float *src, float *dest,
dest[offset] = src[offset];
}
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
asm volatile("threadblock_copy_rowmax_finish_%=:" ::);
}
@@ -128,7 +128,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
continue;
}
#endif
@@ -146,9 +151,12 @@ inline void thread_block_copy_tile(const float *src, float *dest,
dest[gmem_offset] = src[smem_offset];
}
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
asm volatile("threadblock_copy_tile_finish_%=:" ::);
@@ -200,12 +208,28 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
if (row >= B_ROW) {
// WARNING: the number of barrier calls have to exactly match that in the
// outside of the branch to prevent stalls!! FIXME better proof this.
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
continue;
}
#endif
@@ -271,9 +295,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_max;
// sync writes to warp_smem
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// #define PARALLEL_ROWMAX
#ifndef PARALLEL_ROWMAX
@@ -323,9 +350,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
#endif // PARALLEL_ROWMAX
#endif // DUMB_ROWMAX
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// broadcast prev rowmax to all threads in the warp
// NOTE: memory consistency is a little sketchy here
@@ -367,9 +398,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_exp_p_end_%=:" ::);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// rowsum
//
@@ -395,9 +429,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
warp_smem[tid_in_warp] = per_thread_sum;
// sync writes to warp_smem
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// 0-th thread collects all other thread's values in the warp
if (tid_in_warp == 0) {
@@ -425,9 +462,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rowsum_end_%=:" ::);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
// compute Oi rescale factor
// FIXME: parallelize this across threads
@@ -451,9 +491,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax(
asm volatile("flashattn_rescale_factor_end_%=:" ::);
// threadblock_barrier(threadblock_id_in_cluster,
// warps_per_threadblock_per_core);
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
asm volatile("thread_block_online_softmax_finish_%=:" ::);
@@ -503,7 +546,12 @@ __attribute__((always_inline)) inline void thread_block_O_rescale(
}
// reconverge after warp divergence
threadblock_barrier(1, 7);
if constexpr (!TENSOR_CORE) {
threadblock_barrier(1, 7);
} else {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
asm volatile("thread_block_O_rescale_finish_%=:" ::);
}