From 6cc1b5ca37a91c9284667843381d32710361db98 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 1 Sep 2024 16:02:06 -0700 Subject: [PATCH] flash: Reduce smem_scratchpad alloc size --- tests/regression/flash_attention/kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 2b1fea33..5e6f5b9b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -545,7 +545,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; + threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad = smem_O_row_scale_1 - smem_scratchpad_size; // initialize rowmax/rowsum values in sharedmem