From 53dfc690b9b05da4c27dc2699c5d28ffed9cc4c8 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 14 Aug 2024 21:50:20 -0700 Subject: [PATCH] flash: Allocate smem properly for rowsum and scratch --- tests/regression/flash_attention/kernel.cpp | 48 +++++++++++++++------ 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 5a6e71e7..64d0c302 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -10,15 +10,17 @@ // using float_type = float; using float_type = float16_t; +#define B_ROW BM +#define B_COL BN + inline void thread_block_flashattn(float *S, float *gmem, const uint32_t tid_in_threadblock, const uint32_t threads_per_threadblock, const uint32_t threadblock_id_in_cluster, - uint8_t *sharedmem_per_threadblock) { + float *sharedmem_scratchpad, + float *sharedmem_row_max_sum) { asm volatile("thread_block_flashattn_start_%=:" ::); - constexpr uint32_t Brow = BM; // FIXME - constexpr uint32_t Bcol = BN; // FIXME const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; const uint32_t warp_id = tid_in_threadblock / NUM_THREADS; const uint32_t warps_in_threadblock = threads_per_threadblock / NUM_THREADS; @@ -34,17 +36,19 @@ inline void thread_block_flashattn(float *S, float *gmem, // asm volatile("fmv.s %0, f21" : "=f"(ft[5])); // asm volatile("fmv.s %0, f22" : "=f"(ft[6])); // asm volatile("fmv.s %0, f23" : "=f"(ft[7])); + + // row-max // // one warp handles one row in tile; iterate enough times to cover all the // rows - for (int warp_offset = 0; warp_offset < Brow; + for (int warp_offset = 0; warp_offset < B_ROW; warp_offset += warps_in_threadblock) { const uint32_t row = warp_offset + warp_id; - const uint32_t first_thread_offset = Bcol * row; + const uint32_t first_thread_offset = B_COL * row; uint32_t thread_offset = first_thread_offset + tid_in_warp; float curr_max = S[first_thread_offset]; - constexpr uint32_t load_iter = Bcol / NUM_THREADS; + constexpr uint32_t load_iter = B_COL / NUM_THREADS; #pragma GCC unroll for (int iter = 0; iter < load_iter; iter++) { asm volatile("fmax.s %0, %1, %2" @@ -53,7 +57,8 @@ inline void thread_block_flashattn(float *S, float *gmem, thread_offset += NUM_THREADS; } // get max value across the same-warp threads using smem - float *warp_smem = S + (2 * Brow * Bcol) + (row * NUM_THREADS); + // NOTE: be careful with out-of-bounds + float *warp_smem = sharedmem_scratchpad + (row * NUM_THREADS); warp_smem[tid_in_warp] = curr_max; // sync writes to warp_smem @@ -68,10 +73,18 @@ inline void thread_block_flashattn(float *S, float *gmem, : "=f"(curr_max) : "f"(curr_max), "f"(other)); } - gmem[row] = curr_max; + sharedmem_row_max_sum[row] = curr_max; } } + // exponential + // + // FIXME: placeholder for proper exp + constexpr uint32_t exp_elem_per_thread = 1; + // B_ROW / (B_ROW * B_COL / (exp_elem * threads_per_threadblock)) + const uint32_t row_stride = + (exp_elem_per_thread * threads_per_threadblock) / B_COL; + asm volatile("thread_block_flashattn_finish_%=:" ::); } @@ -116,6 +129,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { (2 * BM * BK) * threadblock_id_in_cluster); uint8_t *smem_S = sharedmem_per_threadblock; + constexpr uint32_t sharedmem_row_max_sum_size = 2 * sizeof(float) * B_ROW; + // sharedmem area to store rowmax/rowsum values in softmax + uint8_t *sharedmem_row_max_sum = + reinterpret_cast(SMEM_ADDR_END) - sharedmem_row_max_sum_size; + // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction + // in rowsum + // FIXME: size is arbitrary, and out-of bounds is not checked + constexpr uint32_t sharedmem_scratchpad_size = 0x1000; + uint8_t *sharedmem_scratchpad = + sharedmem_row_max_sum - sharedmem_scratchpad_size; thread_block_gemm( (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, @@ -124,15 +147,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblocks_per_cluster, threadblock_id_in_cluster, sharedmem_per_threadblock); - // sync writes of GEMM results before softmax + // protect writes of GEMM results before softmax const uint32_t warps_per_threadblock_per_core = NUM_WARPS / threads_per_threadblock; threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - thread_block_flashattn((float *)smem_S, (float *)arg->addr_c, - tid_in_threadblock, threads_per_threadblock, - threadblock_id_in_cluster, sharedmem_per_threadblock); + thread_block_flashattn( + (float *)smem_S, (float *)arg->addr_c, tid_in_threadblock, + threads_per_threadblock, threadblock_id_in_cluster, + (float *)sharedmem_scratchpad_size, (float *)sharedmem_row_max_sum); } int main() {