diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 00a5323a..e57eb16b 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -25,11 +25,11 @@ static_assert(NUM_THREADS == 8); static_assert(NUM_WARPS == 8); inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, - const uint32_t threads_per_threadblock, - float *smem_O, - float *smem_rowmax, - float *smem_rowsum, - float *smem_O_row_scale) { + const uint32_t threads_per_threadblock, + float *smem_O, float *smem_rowmax, + float *smem_rowsum, + float *smem_O_row_scale_0, + float *smem_O_row_scale_1) { asm volatile("threadblock_init_sharedmem_start_%=:" ::); const uint32_t tid_in_warp = tid_in_threadblock % NUM_THREADS; @@ -52,7 +52,8 @@ inline void thread_block_init_sharedmem(const uint32_t tid_in_threadblock, smem_rowmax[offset + i * ROWMAX_SETS] = FLT_MIN; } smem_rowsum[offset] = 0.0f; - smem_O_row_scale[offset] = 0.0f; + smem_O_row_scale_0[offset] = 0.0f; + smem_O_row_scale_1[offset] = 0.0f; } // each warp clears out a row of smem_O @@ -501,9 +502,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; // FIXME do proper software pipelining - if (DOUBLE_BUF && warpgroup_id_in_cluster != 1) { - return; - } + // if (DOUBLE_BUF && warpgroup_id_in_cluster != 1) { + // return; + // } const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -538,7 +539,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_rowmax = reinterpret_cast(SMEM_ADDR_END) - smem_rowmax_size; float *smem_rowsum = smem_rowmax - smem_rowsum_size; - float *smem_O_row_scale = smem_rowsum - smem_O_row_scale_size; + float *smem_O_row_scale_0 = smem_rowsum - smem_O_row_scale_size; + float *smem_O_row_scale_1 = smem_O_row_scale_0 - smem_O_row_scale_size; // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum @@ -546,12 +548,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - float *smem_scratchpad = smem_O_row_scale - smem_scratchpad_size; + float *smem_scratchpad = smem_O_row_scale_1 - smem_scratchpad_size; // initialize rowmax/rowsum values in sharedmem - thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, - smem_O, smem_rowmax, smem_rowsum, - smem_O_row_scale); + thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, + smem_rowmax, smem_rowsum, smem_O_row_scale_0, + smem_O_row_scale_1); const float *gmem_Q = reinterpret_cast(arg->addr_q); const float *gmem_K = reinterpret_cast(arg->addr_k); @@ -573,24 +575,28 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); - for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { - // float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; - // float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; - // float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1; - // float *smem_V_consume = (tile_k % 2) ? smem_V1 : smem_V0; - float *smem_P_produce = smem_P0; - float *smem_P_consume = smem_P0; - float *smem_V_produce = smem_V0; - float *smem_V_consume = smem_V0; + for (uint32_t tile_k = 0; tile_k < k_tiles + 1 /*pipeline latency*/; + tile_k++) { + float *smem_P_produce = (tile_k % 2) ? smem_P0 : smem_P1; + float *smem_P_consume = (tile_k % 2) ? smem_P1 : smem_P0; + float *smem_V_produce = (tile_k % 2) ? smem_V0 : smem_V1; + float *smem_V_consume = (tile_k % 2) ? smem_V1 : smem_V0; + float *smem_O_row_scale_produce = + (tile_k % 2) ? smem_O_row_scale_0 : smem_O_row_scale_1; + float *smem_O_row_scale_consume = + (tile_k % 2) ? smem_O_row_scale_1 : smem_O_row_scale_0; + // float *smem_P_produce = smem_P0; + // float *smem_P_consume = smem_P0; + // float *smem_V_produce = smem_V0; + // float *smem_V_consume = smem_V0; - // if (warpgroup_id == 0) { - { + if (warpgroup_id == 0) { // Pipeline stage 1 // // skip pipeline drain - // if (tile_k == k_tiles) { - // continue; - // } + if (tile_k == k_tiles) { + goto tile_iter_end; + } const uint32_t tile_k_ = tile_k; constexpr bool skip_gemm_qk = true; @@ -645,10 +651,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); - thread_block_online_softmax(smem_S, smem_P_produce, tid_in_warpgroup, - threads_per_warpgroup, - warpgroup_id_in_cluster, smem_scratchpad, - smem_rowmax, smem_rowsum, smem_O_row_scale); + thread_block_online_softmax( + smem_S, smem_P_produce, tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum, + smem_O_row_scale_produce); // FIXME unnecessary? threadblock_barrier(warpgroup_id_in_cluster, @@ -680,17 +686,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); } - // } else if (warpgroup_id == 1) { - } - { + } else if (warpgroup_id == 1) { // Pipeline stage 2 // // skip pipeline start - // if (tile_k == 0) { - // continue; - // } - // const uint32_t tile_k_ = tile_k - 1; - const uint32_t tile_k_ = tile_k; + if (tile_k == 0) { + goto tile_iter_end; + } + const uint32_t tile_k_ = tile_k - 1; + // const uint32_t tile_k_ = tile_k; // GEMM II: O = O + P*V @@ -709,9 +713,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); // Oi rescale - thread_block_O_rescale(smem_O, smem_O /*in-place*/, smem_O_row_scale, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + thread_block_O_rescale(smem_O, smem_O /*in-place*/, + smem_O_row_scale_consume, tid_in_warpgroup, + threads_per_warpgroup, warpgroup_id_in_cluster); // rescale-to-PV-GEMM barrier threadblock_barrier(warpgroup_id_in_cluster, @@ -803,6 +807,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warps_per_warpgroup_per_core); } } + + tile_iter_end: + // synchronize progress of two warpgroups + // threadblock_barrier(threadblock_id_in_cluster, + // warps_per_threadblock_per_core); + threadblock_barrier(3, // FIXME + 8); } asm volatile ("tile_loop_finish_%=:" :: );