From 602fe4a400401d71a7e295ecb04e287784424007 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 1 Sep 2024 22:06:46 -0700 Subject: [PATCH] flash: Change timing for QKV move Verified with warp_specialized false; true remains to be fixed. --- tests/regression/flash_attention/kernel.cpp | 269 ++++++++++++-------- 1 file changed, 162 insertions(+), 107 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index f58f3000..12abcd8e 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -15,7 +15,7 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = true; +constexpr bool WARP_SPECIALIZED = false; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; @@ -490,8 +490,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threads_per_threadblock / NUM_THREADS; // warpgroup context - constexpr uint32_t threads_per_warpgroup = threads_per_threadblock / 2; - constexpr uint32_t warpgroups_per_cluster = threadblocks_per_cluster * 2; + constexpr uint32_t threads_per_warpgroup = + threads_per_threadblock / (WARP_SPECIALIZED ? 2 : 1); + constexpr uint32_t warpgroups_per_cluster = + threadblocks_per_cluster * (WARP_SPECIALIZED ? 2 : 1); const uint32_t warps_per_warpgroup_per_core = NUM_WARPS / warpgroups_per_cluster; const uint32_t warpgroup_id = task_id / threads_per_warpgroup; @@ -507,6 +509,25 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; + // get global memory addresses from kernel arguments + const float *gmem_Q = reinterpret_cast(arg->addr_q); + const float *gmem_K = reinterpret_cast(arg->addr_k); + const float *gmem_V = reinterpret_cast(arg->addr_v); + float *gmem_O = reinterpret_cast(arg->addr_o); + + float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); + float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); + float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); + float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); + float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); + float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); + float *gmem_tmp_d6 = reinterpret_cast(0xd6000000UL); + float *gmem_tmp_d7 = reinterpret_cast(0xd7000000UL); + float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); + float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); + float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); + float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + // static shared memory allocation constexpr uint32_t smem_Q_size = B_ROW * HEADDIM; constexpr uint32_t smem_K_size = B_COL * HEADDIM; @@ -572,32 +593,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_cursor -= smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; + // select the correct buffer by warpgroup + float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; + float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0; + float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0; + float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0; + float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0; + float *smem_P = smem_S; + float *smem_O_row_scale = + (warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0; + float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0; + float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0; + float *smem_scratchpad = + (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; + // initialize rowmax/rowsum values in sharedmem - if (warpgroup_id == 0) { - thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0, - smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0); - } else { - thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1, - smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1); - } - - const float *gmem_Q = reinterpret_cast(arg->addr_q); - const float *gmem_K = reinterpret_cast(arg->addr_k); - const float *gmem_V = reinterpret_cast(arg->addr_v); - float *gmem_O = reinterpret_cast(arg->addr_o); - - float *gmem_tmp_d0 = reinterpret_cast(0xd0000000UL); - float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); - float *gmem_tmp_d2 = reinterpret_cast(0xd2000000UL); - float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); - float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); - float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); - float *gmem_tmp_d6 = reinterpret_cast(0xd6000000UL); - float *gmem_tmp_d7 = reinterpret_cast(0xd7000000UL); - float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); - float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); - float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); - float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, + smem_rowmax, smem_rowsum, smem_O_row_scale); constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary @@ -606,13 +618,38 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } + // read Q and K into SMEM before the loop starts + // + static_assert(B_ROW == B_COL, "currently only supports square tiles"); + + // load Q; this stays in SMEM for the entire loop + if constexpr (!WARP_SPECIALIZED) { + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } else { + // FIXME: transpose to K-major in SMEM for correctness + load_tile_to_smem( + dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, + tid_in_warpgroup); + } + + // load K + load_tile_to_smem( + dim_seqlen, /*tile_k=*/0, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // protect write to SMEM + threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + asm volatile ("tile_loop_start_%=:" :: ); // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { - asm volatile ("buf_select_start_%=:" :: ); - // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; // float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1; @@ -622,67 +659,87 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // float *smem_O_row_scale_consume = // (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0; - float *smem_Q = (warpgroup_id % 2) ? smem_Q1 : smem_Q0; - float *smem_K = (warpgroup_id % 2) ? smem_K1 : smem_K0; - float *smem_V = (warpgroup_id % 2) ? smem_V1 : smem_V0; - float *smem_S = (warpgroup_id % 2) ? smem_S1 : smem_S0; - float *smem_O = (warpgroup_id % 2) ? smem_O1 : smem_O0; - float *smem_P = smem_S; - float *smem_O_row_scale = - (warpgroup_id % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0; - float *smem_rowmax = (warpgroup_id % 2) ? smem_rowmax_1 : smem_rowmax_0; - float *smem_rowsum = (warpgroup_id % 2) ? smem_rowsum_1 : smem_rowsum_0; - float *smem_scratchpad = - (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; - - asm volatile ("buf_select_finish_%=:" :: ); - - const uint32_t tile_k_ = tile_k; - - constexpr bool skip_gemm_qk = true; + constexpr bool skip_gemm_qk = false; if constexpr (!skip_gemm_qk) { - static_assert(B_ROW == B_COL, "currently only supports square tiles"); - - // load Q - load_tile_to_smem( - dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, - 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, - tid_in_warpgroup); - - // load K - load_tile_to_smem( - dim_seqlen, tile_k_, 0 /* always 0 because dim_k == headdim */, - gmem_K, smem_K, tid_in_warpgroup); - - // GMEM->SMEM and compute barrier - threadblock_barrier(warpgroup_id_in_cluster, - warps_per_warpgroup_per_core); - - // clear out accumulators before GEMM - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - // GEMM I: S = Q*K - thread_block_gemm_single_tile( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, - threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + // + // FIXME: deduplicate this between GEMM II + if constexpr (!WARP_SPECIALIZED) { + // clear out accumulators before GEMM + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + thread_block_gemm_single_tile( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, + threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + // when warp-specialized, there's only enough warps to do 64x32 tile + // size so we need to do 2 GEMM calls + static_assert(B_ROW / 2 == 32, + "tile size assumption for warp-specialization not met"); + + // assumes smem_Q is K-major + // FIXME: fix this to MN-major + float *smem_Q_half0 = smem_Q; + float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; + float *smem_S_half0 = smem_S; + float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; + + // clear out accumulators before GEMM + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + + // split by rows into 2 chunks + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } else { // load Q*K load_tile_to_smem( - dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k_, - gmem_Q /*=gmem_S*/, smem_S, tid_in_warpgroup); + dim_seqlen, warpgroup_id /* parallelize across rows */, tile_k, + gmem_Q /*contains S*/, smem_S, tid_in_warpgroup); } - // protect GEMM result writes (smem_S) before softmax + // protect write to SMEM (smem_S) before softmax threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); + if constexpr (DEBUG) { + if (warpgroup_id == 0) { + if (tile_k == 0) { + thread_block_copy_tile(smem_S, gmem_tmp_d0, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } else if (tile_k == 1) { + thread_block_copy_tile(smem_S, gmem_tmp_d1, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + } + + threadblock_barrier(warpgroup_id_in_cluster, + warps_per_warpgroup_per_core); + } + } + // inter-warpgroup barrier before online softmax threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); @@ -693,32 +750,36 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_scratchpad, smem_rowmax, smem_rowsum, smem_O_row_scale); - // TODO: put the data movement for QKV here for inter-warpgroup + // data movement for K and V // + // Q stays in SMEM for the entire loop + // + // load K for the next iteration + load_tile_to_smem( + dim_seqlen, tile_k + 1, 0 /* dim_k == headdim */, gmem_K, smem_K, + tid_in_warpgroup); + + // load V for the current iteration // V dimension is [seqlen, headdim], stored N(headdim)-major load_tile_to_smem( - HEADDIM, 0 /* full N-dimension */, tile_k_, gmem_V, smem_V, + HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, tid_in_warpgroup); + // protect write to SMEM threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); if constexpr (DEBUG) { if (warpgroup_id == 0) { - if (tile_k_ == 0) { - // thread_block_copy_tile(smem_P, gmem_tmp_d0, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroup_id_in_cluster); + if (tile_k == 0) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - // thread_block_copy_tile(smem_P, gmem_tmp_d1, - // tid_in_warpgroup, threads_per_warpgroup, - // warpgroup_id_in_cluster); + } else if (tile_k == 1) { thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); @@ -748,24 +809,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { // O before PV - if (tile_k_ == 0) { - thread_block_copy_tile(smem_P, gmem_tmp_d0, tid_in_warpgroup, + if (tile_k == 0) { + thread_block_copy_tile(smem_P, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_V, gmem_tmp_d6, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup, + } else if (tile_k == 1) { + thread_block_copy_tile(smem_P, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile(smem_P, gmem_tmp_d1, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_V, gmem_tmp_d7, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster); - thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup, + thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); } @@ -838,12 +893,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (warpgroup_id == 0) { // O after PV - if (tile_k_ == 0) { - thread_block_copy_tile(smem_O, gmem_tmp_d4, tid_in_warpgroup, + if (tile_k == 0) { + thread_block_copy_tile(smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k_ == 1) { - thread_block_copy_tile(smem_O, gmem_tmp_d5, tid_in_warpgroup, + } else if (tile_k == 1) { + thread_block_copy_tile(smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); }