flash: Compile-time flag for skipping GEMM
This commit is contained in:
@@ -260,26 +260,34 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
uint8_t *sharedmem_scratchpad =
|
uint8_t *sharedmem_scratchpad =
|
||||||
sharedmem_rowmax - sharedmem_scratchpad_size;
|
sharedmem_rowmax - sharedmem_scratchpad_size;
|
||||||
|
|
||||||
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
|
NUM_WARPS / threads_per_threadblock;
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
thread_block_init_sharedmem(tid_in_threadblock, threads_per_threadblock,
|
||||||
(float *)sharedmem_scratchpad,
|
(float *)sharedmem_scratchpad,
|
||||||
(float *)sharedmem_rowmax,
|
(float *)sharedmem_rowmax,
|
||||||
(float *)sharedmem_rowsum);
|
(float *)sharedmem_rowsum);
|
||||||
|
|
||||||
// thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
#define SKIP_GEMM
|
||||||
// (const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
#ifndef SKIP_GEMM
|
||||||
// (float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
|
thread_block_gemm<float_type, /*write_to_gmem=*/true>(
|
||||||
// arg->dim_k, tid_in_threadblock, threads_per_threadblock,
|
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||||
// threadblocks_per_cluster, threadblock_id_in_cluster,
|
(float *)smem_S /*write result to SMEM */, arg->dim_m, arg->dim_n,
|
||||||
// sharedmem_per_threadblock);
|
arg->dim_k, tid_in_threadblock, threads_per_threadblock,
|
||||||
|
threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
|
sharedmem_per_threadblock);
|
||||||
|
|
||||||
// protect writes of GEMM results before softmax
|
// protect writes of GEMM results before softmax
|
||||||
const uint32_t warps_per_threadblock_per_core =
|
|
||||||
NUM_WARPS / threads_per_threadblock;
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
thread_block_flashattn((float *)arg->addr_a /* smem_S, */, tid_in_threadblock,
|
float *tile_S = (float *)smem_S;
|
||||||
|
#else
|
||||||
|
float *tile_S = (float *)arg->addr_a;
|
||||||
|
#endif
|
||||||
|
|
||||||
|
thread_block_flashattn(tile_S, tid_in_threadblock,
|
||||||
threads_per_threadblock, threadblock_id_in_cluster,
|
threads_per_threadblock, threadblock_id_in_cluster,
|
||||||
(float *)sharedmem_scratchpad,
|
(float *)sharedmem_scratchpad,
|
||||||
(float *)sharedmem_rowmax, (float *)sharedmem_rowsum);
|
(float *)sharedmem_rowmax, (float *)sharedmem_rowsum);
|
||||||
|
|||||||
Reference in New Issue
Block a user