flash: Compile-time flag for skipping GEMM

This commit is contained in:
Hansung Kim
2024-08-15 17:40:32 -07:00
parent f844d96eea
commit ac44633b39

View File

@@ -260,26 +260,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
uint8_t *sharedmem_scratchpad = uint8_t *sharedmem_scratchpad =
sharedmem_rowmax - sharedmem_scratchpad_size; sharedmem_rowmax - sharedmem_scratchpad_size;
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
// initialize rowmax/rowsum values in sharedmem // initialize rowmax/rowsum values in sharedmem
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock, thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
(float *)sharedmem_scratchpad, (float *)sharedmem_scratchpad,
(float *)sharedmem_rowmax, (float *)sharedmem_rowmax,
(float *)sharedmem_rowsum); (float *)sharedmem_rowsum);
// thread_block_gemm<float_type, /*write_to_gmem=*/true>( #define SKIP_GEMM
// (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, #ifndef SKIP_GEMM
// (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n, thread_block_gemm<float_type, /*write_to_gmem=*/true>(
// arg->dim_k, tid_in_threadblock, threads_per_threadblock, (const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
// threadblocks_per_cluster, threadblock_id_in_cluster, (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
// sharedmem_per_threadblock); arg->dim_k, tid_in_threadblock, threads_per_threadblock,
threadblocks_per_cluster, threadblock_id_in_cluster,
sharedmem_per_threadblock);
// protect writes of GEMM results before softmax // protect writes of GEMM results before softmax
const uint32_t warps_per_threadblock_per_core =
NUM_WARPS / threads_per_threadblock;
threadblock_barrier(threadblock_id_in_cluster, threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core); warps_per_threadblock_per_core);
thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock, float *tile_S = (float *)smem_S;
#else
float *tile_S = (float *)arg->addr_a;
#endif
thread_block_flashattn(tile_S, tid_in_threadblock,
threads_per_threadblock, threadblock_id_in_cluster, threads_per_threadblock, threadblock_id_in_cluster,
(float *)sharedmem_scratchpad, (float *)sharedmem_scratchpad,
(float *)sharedmem_rowmax, (float *)sharedmem_rowsum); (float *)sharedmem_rowmax, (float *)sharedmem_rowsum);