diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 1c9b015d..1d88b4de 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -672,9 +672,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { static_assert(B_ROW / 2 == 32, "tile size assumption for warp-specialization not met"); - // assumes smem_P is K-major float *smem_P_half0 = smem_P; - float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL; + float *smem_P_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA) + ? smem_P + (B_ROW / 2) * B_COL + : smem_P + (B_ROW / 2); float *smem_O_half0 = smem_O; float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM; @@ -707,7 +708,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); } - } else { + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -716,6 +717,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); } initialize_accum_regs<0>(); @@ -745,7 +755,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); } - } else { + } else if constexpr (Q_IS_K_MAJOR) { thread_block_gemm_single_tile< float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0, @@ -754,6 +764,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); + } else { + thread_block_gemm_single_tile< + float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, + B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( + smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, + tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, + warpgroup_id_in_cluster); } }