diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 10d8f555..13d743ea 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -14,7 +14,7 @@ constexpr uint32_t ROWMAX_SETS = 3; constexpr bool DEBUG = true; -constexpr bool WARP_SPECIALIZED = false; +constexpr bool WARP_SPECIALIZED = true; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; @@ -492,11 +492,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { warpgroup_id % warpgroups_per_cluster; const uint32_t tid_in_warpgroup = tid_in_threadblock % threads_per_warpgroup; - // FIXME do proper software pipelining - // if (WARP_SPECIALIZED && warpgroup_id_in_cluster != 1) { - // return; - // } - const uint32_t dim_seqlen = arg->dim_seqlen; const uint32_t dim_headdim = arg->dim_headdim; @@ -597,7 +592,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // sharedmem "scratchpad" area to put temporary data, e.g. for tree reduction // in rowsum // NOTE: out-of bounds is not checked - // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; @@ -619,6 +613,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_scratchpad = (warpgroup_id % 2) ? smem_scratchpad_1 : smem_scratchpad_0; + const auto spad_addr_Q = (warpgroup_id % 2) ? spad_addr_Q1 : spad_addr_Q0; + const auto spad_addr_K = (warpgroup_id % 2) ? spad_addr_K1 : spad_addr_K0; + const auto spad_addr_V = (warpgroup_id % 2) ? spad_addr_V1 : spad_addr_V0; + const auto spad_addr_S = (warpgroup_id % 2) ? spad_addr_S1 : spad_addr_S0; + // initialize rowmax/rowsum values in sharedmem thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O, smem_rowmax, smem_rowsum, smem_O_row_scale); @@ -626,7 +625,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary // delay warpgroup 0 by 1 iteration to do ping-pong scheduling - if (warpgroup_id == 1) { + if (WARP_SPECIALIZED && warpgroup_id == 1) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); } @@ -667,15 +666,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // static_assert(B_ROW == B_COL, "currently only supports square tiles"); - static_assert(warps_per_warpgroup_per_core == 8); // FIXME nocheckin - if constexpr (GEMMINI_DMA) { asm volatile("dma_move_start_%=:" ::); - if (tid_in_threadblock == 0) { + if (tid_in_warpgroup == 0) { + const float *gmem_Q_tile = gmem_Q + HEADDIM * B_ROW * warpgroup_id; + const float *gmem_K_tile = gmem_K; // configure the GMEM addresses for the DMA to read from - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q), - (uint64_t)(gmem_K), k_LOOP_WS_CONFIG_ADDRS_AB) + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), + k_LOOP_WS_CONFIG_ADDRS_AB) // configure address strides for the DMA GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); @@ -691,8 +691,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( - spad_addr_Q0, spad_addr_K0, - /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + spad_addr_Q, spad_addr_K, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -803,8 +803,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { "tile size assumption for warp-specialization not met"); float *smem_Q_half0 = smem_Q; - float *smem_Q_half1 = Q_IS_K_MAJOR ? smem_Q + (B_ROW / 2) * HEADDIM - : smem_Q + (B_ROW / 2); + float *smem_Q_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? smem_Q + (B_ROW / 2) * HEADDIM + : smem_Q + (B_ROW / 2); float *smem_S_half0 = smem_S; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; @@ -813,8 +814,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - // TODO: GEMMINI_DMA - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -837,8 +847,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - // TODO: GEMMINI_DMA - if constexpr (Q_IS_K_MAJOR) { + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile( + smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -903,7 +922,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // // Q stays in SMEM for the entire loop if constexpr (GEMMINI_DMA) { - if (tid_in_threadblock == 0) { + // NOTE: Beware of race conditions; with warp specialization, we need to + // make sure below command code to DMA is not executed simultaneously + // from the two warpgroups (which will result in hardware fault). + // Currently the ping-pong scheduling scheme prevents that. + if (tid_in_warpgroup == 0) { // configure GMEM addresses for K and V tiles // load K for the next iteration const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1)); @@ -920,8 +943,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do DMA sp_tiled_matmul_full_spad_ws( - spad_addr_K0, spad_addr_V0, - /*spad_D=*/0, /*spad_C=*/spad_addr_S0, + spad_addr_K, spad_addr_V, + /*spad_D=*/0, /*spad_C=*/spad_addr_S, /*I=*/(HEADDIM / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, @@ -1044,9 +1067,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warpgroups_per_cluster, warpgroup_id_in_cluster); } } else { - static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA, - "warp specialization unimplemented for dma"); - // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, @@ -1063,27 +1083,52 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); // split by rows into 2 chunks - // TODO: GEMMINI_DMA - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } initialize_accum_regs<0>(); initialize_accum_regs<1>(); - thread_block_gemm_single_tile< - float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, - B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, - /*load_accum=*/true, - /*write_to_smem=*/true>( - smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, - tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, - warpgroup_id_in_cluster); + if constexpr (GEMMINI_DMA) { + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, + /*leading_dim_a=*/0, + /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); + } } threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);