flash: Fix single-tile GEMM for warp-specialized

With 4 warps, we can only do 32x64 GEMM; serialize 64x64 into 2 32x64
GEMM calls by split by the row.
This commit is contained in:
Hansung Kim
2024-08-30 17:12:46 -07:00
parent 72b6004e24
commit 986d507223

View File

@@ -635,22 +635,55 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL,
// /*load_accum=*/true,
// /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
// threads_per_threadblock, threadblocks_per_cluster,
// threadblock_id_in_cluster);
if constexpr (!DOUBLE_BUF) {
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
// FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL,
// /*load_accum=*/true,
// /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_threadblock,
// threads_per_threadblock, threadblocks_per_cluster,
// threadblock_id_in_cluster);
} else {
// when warp-specialized, there's only enough warps to do 64x32 tile size
// so we need to do 2 GEMM calls
static_assert(B_ROW / 2 == 32,
"tile size assumption for warp-specialization not met");
// assumes smem_P is K-major
float *smem_P0 = smem_P;
float *smem_P1 = smem_P + (B_ROW / 2) * B_COL;
float *smem_O0 = smem_O;
float *smem_O1 = smem_O + (B_ROW / 2) * HEADDIM;
// split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P0, smem_V, smem_O0 /*load accum*/, smem_O0, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P1, smem_V, smem_O1 /*load accum*/, smem_O1, tid_in_threadblock,
threads_per_threadblock, threadblocks_per_cluster,
threadblock_id_in_cluster);
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);