flash: Do accumulation of PV into O using the single_tile API
This commit is contained in:
@@ -371,10 +371,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// GEMM I: S = Q*K
|
||||||
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major, MemLayout::MN_major,
|
||||||
|
/*load_accum=*/false,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_Q, smem_K, smem_S, tid_in_threadblock, threads_per_threadblock,
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock,
|
||||||
threadblocks_per_cluster, threadblock_id_in_cluster);
|
threads_per_threadblock, threadblocks_per_cluster,
|
||||||
|
threadblock_id_in_cluster);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// protect GEMM result writes (smem_S) before softmax
|
// protect GEMM result writes (smem_S) before softmax
|
||||||
@@ -395,21 +398,24 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
// GEMM II: O = O + P*V
|
||||||
|
|
||||||
// clear out accumulators
|
// clear out accumulators
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||||
B_COL, 0 /*FIXME*/, 0 /*FIXME*/, gmem_V, smem_V, tid_in_threadblock);
|
B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
// FIXME: support MN_major for A for ideal performance
|
// FIXME: support MN_major for A for ideal performance
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
thread_block_gemm_single_tile<float, MemLayout::K_major, MemLayout::MN_major,
|
||||||
|
/*load_accum=*/false,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, gmem_O /*smem_O*/, tid_in_threadblock,
|
smem_P, smem_V, smem_O, gmem_O /*smem_O*/,
|
||||||
threads_per_threadblock, threadblocks_per_cluster,
|
tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster,
|
||||||
threadblock_id_in_cluster);
|
threadblock_id_in_cluster);
|
||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
|
|||||||
Reference in New Issue
Block a user