diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index e57eb16b..2b1fea33 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -523,13 +523,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_Q = reinterpret_cast(smem_per_threadblock); float *smem_K = smem_Q + smem_Q_size; - // in-place multiplication of QK into Q float *smem_S = reinterpret_cast(smem_per_threadblock); - float *smem_P0 = smem_S; // in-place update from S to P + float *smem_O = smem_S + smem_QK_size; + float *smem_P0 = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR); float *smem_P1 = smem_P0 + smem_QK_size; - float *smem_O = smem_P1 + smem_QK_size; - float *smem_V0 = - reinterpret_cast(DEV_FAKE_SMEM_START_ADDR) + smem_QK_size; + float *smem_V0 = smem_P1 + smem_QK_size; float *smem_V1 = smem_V0 + smem_QK_size; // allocate rowmax/rowsum storage at the end of the sharedmem address space @@ -566,6 +564,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_d3 = reinterpret_cast(0xd3000000UL); float *gmem_tmp_d4 = reinterpret_cast(0xd4000000UL); float *gmem_tmp_d5 = reinterpret_cast(0xd5000000UL); + float *gmem_tmp_d6 = reinterpret_cast(0xd6000000UL); + float *gmem_tmp_d7 = reinterpret_cast(0xd7000000UL); float *gmem_tmp_e0 = reinterpret_cast(0xe0000000UL); float *gmem_tmp_e1 = reinterpret_cast(0xe1000000UL); float *gmem_tmp_e2 = reinterpret_cast(0xe2000000UL); @@ -662,19 +662,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { if (tile_k_ == 0) { - thread_block_copy_tile(smem_P_produce, gmem_tmp_d0, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + // thread_block_copy_tile(smem_P_produce, gmem_tmp_d0, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroup_id_in_cluster); thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e0, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); thread_block_copy_rowmax(smem_rowsum, gmem_tmp_e2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k_ == k_tiles - 1) { - thread_block_copy_tile(smem_P_produce, gmem_tmp_d1, - tid_in_warpgroup, threads_per_warpgroup, - warpgroup_id_in_cluster); + } else if (tile_k_ == 1) { + // thread_block_copy_tile(smem_P_produce, gmem_tmp_d1, + // tid_in_warpgroup, threads_per_warpgroup, + // warpgroup_id_in_cluster); thread_block_copy_rowmax(smem_rowmax, gmem_tmp_e1, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); @@ -698,10 +698,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // GEMM II: O = O + P*V - // clear out accumulators - initialize_accum_regs<0>(); - initialize_accum_regs<1>(); - // V dimension is [seqlen, headdim], stored N(headdim)-major load_tile_to_smem( @@ -724,10 +720,22 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if constexpr (DEBUG) { // O before PV if (tile_k_ == 0) { + thread_block_copy_tile(smem_P_consume, gmem_tmp_d0, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile(smem_V_consume, gmem_tmp_d6, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); thread_block_copy_tile(smem_O, gmem_tmp_d2, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); - } else if (tile_k_ == k_tiles - 1) { + } else if (tile_k_ == 1) { + thread_block_copy_tile(smem_P_consume, gmem_tmp_d1, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); + thread_block_copy_tile(smem_V_consume, gmem_tmp_d7, + tid_in_warpgroup, threads_per_warpgroup, + warpgroup_id_in_cluster); thread_block_copy_tile(smem_O, gmem_tmp_d3, tid_in_warpgroup, threads_per_warpgroup, warpgroup_id_in_cluster); @@ -738,6 +746,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { } if constexpr (!DOUBLE_BUF) { + // clear out accumulators + initialize_accum_regs<0>(); + initialize_accum_regs<1>(); + thread_block_gemm_single_tile(); + initialize_accum_regs<1>(); + // split by rows into 2 chunks thread_block_gemm_single_tile