From b5916f3f0718eb0232c5c104bd9985ffea1ad6d5 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 11 Sep 2024 22:08:06 -0700 Subject: [PATCH] flash: Fix hardcoded barrier for tcore; move tcore-specific flags --- .../regression/flash_attention/flash_impl.hpp | 122 ++++++++++++------ tests/regression/flash_attention/kernel.cpp | 7 +- 2 files changed, 91 insertions(+), 38 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 410c5f4f..47e21c70 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,11 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = false; - -constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; - -constexpr bool Q_IS_K_MAJOR = true; +constexpr bool WARP_SPECIALIZED = true; +constexpr bool TENSOR_CORE = true; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); @@ -99,9 +96,12 @@ inline void thread_block_copy_rowmax(const float *src, float *dest, dest[offset] = src[offset]; } - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } asm volatile("threadblock_copy_rowmax_finish_%=:" ::); } @@ -128,7 +128,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, if (row >= B_ROW) { // WARNING: the number of barrier calls have to exactly match that in the // outside of the branch to prevent stalls!! FIXME better proof this. - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } continue; } #endif @@ -146,9 +151,12 @@ inline void thread_block_copy_tile(const float *src, float *dest, dest[gmem_offset] = src[smem_offset]; } - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } asm volatile("threadblock_copy_tile_finish_%=:" ::); @@ -200,12 +208,28 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( if (row >= B_ROW) { // WARNING: the number of barrier calls have to exactly match that in the // outside of the branch to prevent stalls!! FIXME better proof this. - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + continue; } #endif @@ -271,9 +295,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_max; // sync writes to warp_smem - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // #define PARALLEL_ROWMAX #ifndef PARALLEL_ROWMAX @@ -323,9 +350,13 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( #endif // PARALLEL_ROWMAX #endif // DUMB_ROWMAX - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } + // broadcast prev rowmax to all threads in the warp // NOTE: memory consistency is a little sketchy here @@ -367,9 +398,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // rowsum // @@ -395,9 +429,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( warp_smem[tid_in_warp] = per_thread_sum; // sync writes to warp_smem - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // 0-th thread collects all other thread's values in the warp if (tid_in_warp == 0) { @@ -425,9 +462,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rowsum_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } // compute Oi rescale factor // FIXME: parallelize this across threads @@ -451,9 +491,12 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_rescale_factor_end_%=:" ::); - // threadblock_barrier(threadblock_id_in_cluster, - // warps_per_threadblock_per_core); - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } } asm volatile("thread_block_online_softmax_finish_%=:" ::); @@ -503,7 +546,12 @@ __attribute__((always_inline)) inline void thread_block_O_rescale( } // reconverge after warp divergence - threadblock_barrier(1, 7); + if constexpr (!TENSOR_CORE) { + threadblock_barrier(1, 7); + } else { + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + } asm volatile("thread_block_O_rescale_finish_%=:" ::); } diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 9eee2b60..1c9b015d 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -8,6 +8,9 @@ #include "gemmini_mmio.h" #include "flash_impl.hpp" +constexpr bool DEBUG = false; +constexpr bool Q_IS_K_MAJOR = true; + void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // @perf: All threads are running these compute whose result is mostly same // across the threadblock @@ -88,6 +91,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *smem_per_threadblock = reinterpret_cast( DEV_SMEM_START_ADDR); float *smem_cursor = reinterpret_cast(smem_per_threadblock); + // constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; // float *smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); float *smem_Q0 = smem_cursor; smem_cursor += smem_Q_size; @@ -310,7 +314,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { + for (uint32_t tile_k = 0; tile_k < (4 /* for perf measurement */ * k_tiles); + tile_k++) { // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; // float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1;