diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index ef16a6ee..063bc468 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -216,10 +216,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { if (tid_in_warpgroup == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); - // configure DMA for the full Q matrix + // configure DMA with GMEM address strides + // Q matrix gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0); - // configure DMA for the full K matrix + // K matrix gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1); // configure DMA for Q*K store @@ -344,16 +345,13 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *smem_S_produce = (tile_k & 1) ? smem_S0 : smem_S1; float *smem_P_consume = (tile_k & 1) ? smem_P1 : smem_P0; float *smem_P_produce = (tile_k & 1) ? smem_P0 : smem_P1; - // O tile is sequentially updated at every iteration; no ping-pong - // necessary + // O, rowmax/rowsum etc. are sequentially updated at every iteration; no + // ping-pong necessary float *smem_O = smem_O0; - // FIXME: O_row_scale/rowmax/rowsum/spad shouldn't really need ping-pong - float *smem_O_row_scale = - (tile_k & 1) ? smem_O_row_scale_1 : smem_O_row_scale_0; - float *smem_rowmax = (tile_k & 1) ? smem_rowmax_1 : smem_rowmax_0; - float *smem_rowsum = (tile_k & 1) ? smem_rowsum_1 : smem_rowsum_0; - float *smem_scratchpad = - (tile_k & 1) ? smem_scratchpad_1 : smem_scratchpad_0; + float *smem_O_row_scale = smem_O_row_scale_0; + float *smem_rowmax = smem_rowmax_0; + float *smem_rowsum = smem_rowsum_0; + float *smem_scratchpad = smem_scratchpad_0; const auto spad_addr_Q = spad_addr_Q0; const auto spad_addr_K_consume = (tile_k & 1) ? 
spad_addr_K1 : spad_addr_K0; @@ -394,6 +392,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // do matmul // among other things, this also configures CONFIG_BOUNDS so that the // DMA knows the full matrix dimensions + // FIXME: perf: prevent GMEM->SMEM load for O tile gemmini_fence(); sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, @@ -401,7 +400,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); #endif gemmini_fence();