From c51dc4902d5060202a3bbbdfe3ba25de2572f138 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 23:21:28 -0700 Subject: [PATCH] flash: Fix online softmax for DMA layout --- tests/regression/flash_attention/Makefile | 2 +- .../regression/flash_attention/flash_impl.hpp | 37 ++++++----- .../flash_attention/kernel.gemmini.cpp | 63 +++++++++---------- 3 files changed, 54 insertions(+), 48 deletions(-) diff --git a/tests/regression/flash_attention/Makefile b/tests/regression/flash_attention/Makefile index 4d4fcad1..0456e983 100644 --- a/tests/regression/flash_attention/Makefile +++ b/tests/regression/flash_attention/Makefile @@ -3,7 +3,7 @@ PROJECT = flash_attention SRCS = main.cpp common.h VX_SRCS = kernel.gemmini.cpp -VX_INCLUDES = ../sgemm_tcore/sgemm_impl.hpp +VX_INCLUDES = flash_impl.hpp ../sgemm_tcore/sgemm_impl.hpp OPTS ?= -n16 diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 423ebd69..48a0068f 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -152,6 +152,7 @@ inline float exponential_taylor_term(const float x) { return res; } +template __attribute__((always_inline)) inline void thread_block_online_softmax( const float *smem_S, float *smem_P, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, @@ -180,7 +181,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // one warp handles one row in tile constexpr uint32_t per_row_iter = B_COL / NUM_THREADS; - uint32_t thread_offset = first_thread_offset + tid_in_warp; // FIXME: threadblock_id needs to be in here too float *warp_smem = smem_scratchpad + (warp_id * NUM_THREADS); @@ -219,11 +219,16 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float per_thread_max = FLT_MIN; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - const float next = smem_S[thread_offset]; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + const float next = smem_S[offset]; asm volatile("fmax.s %0, %1, %2" : "=f"(per_thread_max) : "f"(per_thread_max), "f"(next)); - thread_offset += NUM_THREADS; } // stage per-thread max value in smem warp_smem[tid_in_warp] = per_thread_max; @@ -299,10 +304,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( asm volatile("flashattn_exp_p_start_%=:" ::); - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - float f0 = smem_S[thread_offset]; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + float f0 = smem_S[offset]; f0 -= rowmax_new; @@ -313,9 +323,7 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // Store S transposed to the shared memory - smem_P[thread_offset] = exp; - - thread_offset += NUM_THREADS; + smem_P[offset] = exp; } asm volatile("flashattn_exp_p_end_%=:" ::); @@ -332,11 +340,15 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( float per_thread_sum = 0.0f; - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { - per_thread_sum += smem_P[thread_offset]; - thread_offset += NUM_THREADS; + const uint32_t col_offset = NUM_THREADS * i; + const uint32_t col = col_offset + tid_in_warp; + const auto [smem_row, smem_col] = + remap_to_gemmini_dma_layout(row, col); + const uint32_t offset = B_COL * smem_row + smem_col; + + per_thread_sum += smem_P[offset]; } // stage per-thread sum value in smem // FIXME: threadblock_id needs to be in here too @@ -381,7 +393,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // asm volatile("flashattn_rescale_factor_start_%=:" ::); - thread_offset = first_thread_offset + tid_in_warp; #pragma GCC unroll for (int i = 0; i < per_row_iter; i++) { const float mi_prev = rowmax_prev; @@ -395,8 +406,6 @@ __attribute__((always_inline)) inline void thread_block_online_softmax( // @perf: div vs. expansion on e(-x)? smem_O_row_scale[row] = 1.0f / exp; - - thread_offset += NUM_THREADS; } asm volatile("flashattn_rescale_factor_end_%=:" ::); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 0df0cf87..4572c921 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -319,8 +319,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 -// weight-stationary matmul loop +#if 0 // TODO +// loop_ws variant that skips configuring strides #define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ { \ ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ @@ -373,17 +373,40 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); -#if 0 // Online softmax // - thread_block_online_softmax(smem_S, smem_P, tid_in_warpgroup, - threads_per_warpgroup, warpgroup_id_in_cluster, - smem_scratchpad, smem_rowmax, smem_rowsum, - smem_O_row_scale); + thread_block_online_softmax( + smem_S, smem_P, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum, + smem_O_row_scale); // FIXME: unnecessary? threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, + threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + +#if 0 // data movement for K and V // // Q stays in SMEM for the entire loop @@ -434,32 +457,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } asm volatile("move_k_v_finish_%=:" ::); - // protect write to SMEM - threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - - if constexpr (DEBUG) { - if (warpgroup_id == 0) { - if (tile_k == 0) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } else if (tile_k == 1) { - thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e3, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - } - - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - } - } - // inter-warpgroup barrier before GEMM II threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);