From ac44633b39d65934378d83a65f059292bd73cb66 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 15 Aug 2024 17:40:32 -0700 Subject: [PATCH] flash: Compile time flag for skipping GEMM --- tests/regression/flash_attention/kernel.cpp | 26 ++++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 918f3607..888db94a 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -260,26 +260,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { uint8_t *sharedmem_scratchpad = sharedmem_rowmax - sharedmem_scratchpad_size; + const uint32_t warps_per_threadblock_per_core = + NUM_WARPS / threads_per_threadblock; + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, (float *)sharedmem_scratchpad, (float *)sharedmem_rowmax, (float *)sharedmem_rowsum); - // thread_block_gemm( - // (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, - // (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, - // arg->dim_k, tid_in_threadblock, threads_per_threadblock, - // threadblocks_per_cluster, threadblock_id_in_cluster, - // sharedmem_per_threadblock); +#define SKIP_GEMM +#ifndef SKIP_GEMM + thread_block_gemm( + (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, + (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, + arg->dim_k, tid_in_threadblock, threads_per_threadblock, + threadblocks_per_cluster, threadblock_id_in_cluster, + sharedmem_per_threadblock); // protect writes of GEMM results before softmax - const uint32_t warps_per_threadblock_per_core = - NUM_WARPS / threads_per_threadblock; threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock, + float *tile_S = (float *)smem_S; +#else + float *tile_S = (float *)arg->addr_a; +#endif + + thread_block_flashattn(tile_S, tid_in_threadblock, threads_per_threadblock, threadblock_id_in_cluster, (float *)sharedmem_scratchpad, (float *)sharedmem_rowmax, (float *)sharedmem_rowsum);