flash: Correct schedule with inter-warpgroup barriers

This commit is contained in:
Hansung Kim
2024-09-01 20:40:26 -07:00
parent e5e65312d2
commit aea257349a

View File

@@ -601,6 +601,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
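// hold back warpgroup 1 at the global barrier before the loop so that the two
// warpgroups run out of phase with each other (ping-pong scheduling);
// warpgroup 0 issues the matching extra barrier after the loop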
if (warpgroup_id == 1) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
asm volatile ("tile_loop_start_%=:" :: );
// "inner loop" along the columns of K^T
@@ -636,10 +641,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr bool skip_gemm_qk = true;
if constexpr (!skip_gemm_qk) {
// clear out accumulators
initialize_accum_regs<0>();
initialize_accum_regs<1>();
static_assert(B_ROW == B_COL, "currently only supports square tiles");
// load Q
@@ -659,6 +660,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(warpgroup_id_in_cluster,
warps_per_warpgroup_per_core);
// clear out accumulators before GEMM
initialize_accum_regs<0>();
initialize_accum_regs<1>();
// GEMM I: S = Q*K
thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
@@ -678,6 +683,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// protect GEMM result writes (smem_S) before softmax
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// inter-warpgroup barrier before online softmax
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// Online softmax
@@ -687,7 +693,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_scratchpad, smem_rowmax, smem_rowsum,
smem_O_row_scale);
// FIXME unnecessary?
// TODO: put the data movement for QKV here for inter-warpgroup
//
// V dimension is [seqlen, headdim], stored N(headdim)-major
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM, threads_per_warpgroup>(
HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V,
tid_in_warpgroup);
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
if constexpr (DEBUG) {
@@ -719,17 +732,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
}
// inter-warpgroup barrier before GEMM II
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// GEMM II: O = O + P*V
// V dimension is [seqlen, headdim], stored N(headdim)-major
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
HEADDIM, threads_per_warpgroup>(
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k_,
gmem_V, smem_V, tid_in_warpgroup);
// FIXME: should be removable
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
// Oi rescale
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
smem_O_row_scale, tid_in_warpgroup,
@@ -769,7 +776,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators
// clear out accumulators before GEMM
initialize_accum_regs<0>();
initialize_accum_regs<1>();
@@ -802,7 +809,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
float *smem_O_half0 = smem_O;
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
// clear out accumulators
// clear out accumulators before GEMM
initialize_accum_regs<0>();
initialize_accum_regs<1>();
@@ -855,6 +862,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
asm volatile ("tile_loop_finish_%=:" :: );
// wait for warpgroup 1 to finish, which called the global barrier before
// entering the loop
if (warpgroup_id == 0) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
}
int main() {