flash: Correct schedule with inter-warpgroup barriers
This commit is contained in:
@@ -601,6 +601,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||
|
||||
// delay warpgroup 1 by 1 iteration to do ping-pong scheduling
|
||||
if (warpgroup_id == 1) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
}
|
||||
|
||||
asm volatile ("tile_loop_start_%=:" :: );
|
||||
|
||||
// "inner loop" along the columns of K^T
|
||||
@@ -636,10 +641,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
|
||||
constexpr bool skip_gemm_qk = true;
|
||||
if constexpr (!skip_gemm_qk) {
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
static_assert(B_ROW == B_COL, "currently only supports square tiles");
|
||||
|
||||
// load Q
|
||||
@@ -659,6 +660,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
threadblock_barrier(warpgroup_id_in_cluster,
|
||||
warps_per_warpgroup_per_core);
|
||||
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
// GEMM I: S = Q*K
|
||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||
MemLayout::MN_major, B_ROW, B_COL, HEADDIM,
|
||||
@@ -678,6 +683,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// protect GEMM result writes (smem_S) before softmax
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
// inter-warpgroup barrier before online softmax
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
// Online softmax
|
||||
@@ -687,7 +693,14 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_scratchpad, smem_rowmax, smem_rowsum,
|
||||
smem_O_row_scale);
|
||||
|
||||
// FIXME unnecessary?
|
||||
// TODO: put the data movement for QKV here for inter-warpgroup
|
||||
//
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V,
|
||||
tid_in_warpgroup);
|
||||
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
if constexpr (DEBUG) {
|
||||
@@ -719,17 +732,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
}
|
||||
|
||||
// inter-warpgroup barrier before GEMM II
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
|
||||
// GEMM II: O = O + P*V
|
||||
|
||||
// V dimension is [seqlen, headdim], stored N(headdim)-major
|
||||
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_COL,
|
||||
HEADDIM, threads_per_warpgroup>(
|
||||
HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k_,
|
||||
gmem_V, smem_V, tid_in_warpgroup);
|
||||
|
||||
// FIXME: should be removable
|
||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||
|
||||
// Oi rescale
|
||||
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
||||
smem_O_row_scale, tid_in_warpgroup,
|
||||
@@ -769,7 +776,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
|
||||
if constexpr (!WARP_SPECIALIZED) {
|
||||
// clear out accumulators
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
@@ -802,7 +809,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *smem_O_half0 = smem_O;
|
||||
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
|
||||
|
||||
// clear out accumulators
|
||||
// clear out accumulators before GEMM
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
|
||||
@@ -855,6 +862,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
}
|
||||
|
||||
asm volatile ("tile_loop_finish_%=:" :: );
|
||||
|
||||
// wait for warpgroup 1 to finish, which called the global barrier before
|
||||
// entering the loop
|
||||
if (warpgroup_id == 0) {
|
||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||
}
|
||||
}
|
||||
|
||||
int main() {
|
||||
|
||||
Reference in New Issue
Block a user