diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index ccd20fbd..64fcf201 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -465,53 +465,65 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_e3 = reinterpret_cast(0xe3000000UL); // "inner loop" along the columns of K^T - for (uint32_t tile_k = 0; tile_k < (dim_seqlen / B_COL); tile_k++) { + const uint32_t k_tiles = (dim_seqlen / B_COL); + for (uint32_t tile_k = 0; tile_k < k_tiles; tile_k++) { -// #define SKIP_GEMM -#ifndef SKIP_GEMM - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); + const float *tile_S = nullptr; - static_assert(B_ROW == B_COL, "currently only supports square tiles"); + constexpr bool skip_gemm_qk = true; + if constexpr (!skip_gemm_qk) { + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); - // load Q - load_tile_to_smem( - dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, - 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, - tid_in_threadblock); + static_assert(B_ROW == B_COL, "currently only supports square tiles"); - // load K - load_tile_to_smem(dim_seqlen, tile_k, - 0 /* always 0 because dim_k == headdim */, - gmem_K, smem_K, tid_in_threadblock); + // load Q + load_tile_to_smem( + dim_seqlen, 0 /*FIXME: only work on first B_ROW rows of Q for now*/, + 0 /* always 0 because dim_k == headdim */, gmem_Q, smem_Q, + tid_in_threadblock); - // GMEM->SMEM and compute barrier - threadblock_barrier(threadblock_id_in_cluster, - warps_per_threadblock_per_core); + // load K + load_tile_to_smem(dim_seqlen, tile_k, + 0 /* always 0 because dim_k == headdim */, + gmem_K, smem_K, tid_in_threadblock); - // GEMM I: S = Q*K - thread_block_gemm_single_tile( - smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, - threads_per_threadblock, threadblocks_per_cluster, - threadblock_id_in_cluster); + // GMEM->SMEM and compute barrier + threadblock_barrier(threadblock_id_in_cluster, + warps_per_threadblock_per_core); + + // GEMM I: S = Q*K + thread_block_gemm_single_tile( + smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_threadblock, + threads_per_threadblock, threadblocks_per_cluster, + threadblock_id_in_cluster); + + // tile_S = smem_S; + } else { + // load Q*K + load_tile_to_smem(dim_seqlen, 0, tile_k, gmem_Q /*=gmem_S*/, + smem_S, tid_in_threadblock); + // the above should be equivalent to: + // load_tile_to_smem(dim_seqlen, tile_k, 0, gmem_Q /*=gmem_S*/, + // smem_S, tid_in_threadblock); + + // tile_S = reinterpret_cast(arg->addr_q); + } // protect GEMM result writes (smem_S) before softmax threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); - const float *tile_S = (float *)smem_S; -#else - float *tile_S = (float *)arg->addr_q; -#endif - thread_block_online_softmax( - tile_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, + smem_S, smem_O, smem_P, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, smem_scratchpad, smem_rowmax, smem_rowsum); @@ -535,7 +547,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster); - } else if (tile_k == 1) { + } else if (tile_k == k_tiles - 1) { thread_block_copy_tile( smem_P, gmem_tmp_d1, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster); @@ -565,8 +577,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // V dimension is [seqlen, headdim], stored N(headdim)-major load_tile_to_smem( - HEADDIM, 0 /* 0 because always reads the full N-dimension */, - tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock); + HEADDIM, 0 /* 0 because always reads the full N-dimension */, tile_k, + gmem_V, smem_V, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); @@ -588,7 +600,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { thread_block_copy_tile( smem_O, gmem_tmp_d4, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster); - } else if (tile_k == 1) { + } else if (tile_k == k_tiles - 1) { thread_block_copy_tile( smem_O, gmem_tmp_d5, tid_in_threadblock, threads_per_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster);