flash: Fix single-tile GEMM for warp-specialized
With 4 warps, we can only do 32x64 GEMM; serialize 64x64 into 2 32x64 GEMM calls by split by the row.
This commit is contained in:
@@ -635,22 +635,55 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
// FIXME: wrong but fast
|
||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
// MemLayout::MN_major,
|
||||
// B_ROW, HEADDIM, B_COL,
|
||||
// /*load_accum=*/true,
|
||||
// /*write_to_smem=*/true>(
|
||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||
// threads_per_threadblock, threadblocks_per_cluster,
|
||||
// threadblock_id_in_cluster);
|
||||
if constexpr (!DOUBLE_BUF) {
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
// FIXME: wrong but fast
|
||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
// MemLayout::MN_major,
|
||||
// B_ROW, HEADDIM, B_COL,
|
||||
// /*load_accum=*/true,
|
||||
// /*write_to_smem=*/true>(
|
||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
|
||||
// threads_per_threadblock, threadblocks_per_cluster,
|
||||
// threadblock_id_in_cluster);
|
||||
} else {
|
||||
// when warp-specialized, there's only enough warps to do 64x32 tile size
|
||||
// so we need to do 2 GEMM calls
|
||||
static_assert(B_ROW / 2 == 32,
|
||||
"tile size assumption for warp-specialization not met");
|
||||
|
||||
// assumes smem_P is K-major
|
||||
float *smem_P0 = smem_P;
|
||||
float *smem_P1 = smem_P + (B_ROW / 2) * B_COL;
|
||||
float *smem_O0 = smem_O;
|
||||
float *smem_O1 = smem_O + (B_ROW / 2) * HEADDIM;
|
||||
|
||||
// split by rows into 2 chunks
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P0, smem_V, smem_O0 /*load accum*/, smem_O0, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
|
||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P1, smem_V, smem_O1 /*load accum*/, smem_O1, tid_in_threadblock,
|
||||
threads_per_threadblock, threadblocks_per_cluster,
|
||||
threadblock_id_in_cluster);
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
Reference in New Issue
Block a user