flash: Compile time flag for skipping GEMM
This commit is contained in:
@@ -260,26 +260,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
uint8_t *sharedmem_scratchpad =
|
||||
sharedmem_rowmax - sharedmem_scratchpad_size;
|
||||
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
|
||||
// initialize rowmax/rowsum values in sharedmem
|
||||
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
||||
(float *)sharedmem_scratchpad,
|
||||
(float *)sharedmem_rowmax,
|
||||
(float *)sharedmem_rowsum);
|
||||
|
||||
// thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||
// (const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||
// (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
|
||||
// arg->dim_k, tid_in_threadblock, threads_per_threadblock,
|
||||
// threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||
// sharedmem_per_threadblock);
|
||||
#define SKIP_GEMM
|
||||
#ifndef SKIP_GEMM
|
||||
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||
(float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
|
||||
arg->dim_k, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||
sharedmem_per_threadblock);
|
||||
|
||||
// protect writes of GEMM results before softmax
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threads_per_threadblock;
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock,
|
||||
float *tile_S = (float *)smem_S;
|
||||
#else
|
||||
float *tile_S = (float *)arg->addr_a;
|
||||
#endif
|
||||
|
||||
thread_block_flashattn(tile_S, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblock_id_in_cluster,
|
||||
(float *)sharedmem_scratchpad,
|
||||
(float *)sharedmem_rowmax, (float *)sharedmem_rowsum);
|
||||
|
||||
Reference in New Issue
Block a user