From 6547e927577c4df7b0bfc44aac2643ed6c5b3f8a Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 19:47:55 -0700 Subject: [PATCH] flash: Load Q to both quartiles; preload O for acc --- .../flash_attention/kernel.gemmini.cpp | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 9e36bf83..e934a1e1 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -202,8 +202,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t skips_mvout_spad = loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, /*skip_ex=*/1, /*skip_stc=*/0); - constexpr uint32_t skips_matmul = - loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1, + constexpr uint32_t skips_matmul_preload = + loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0, /*skip_ex=*/0, /*skip_stc=*/1); if constexpr (GEMMINI_DMA) { @@ -255,6 +255,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // spad_addr_K0 set in this kernel GEMMINI_CISC_CMD_I(10); gemmini_fence(); + + // need to also move to Q1 for the next iteration + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile), + (uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB) + GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) | + 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); + gemmini_fence(); + GEMMINI_CISC_CMD_I(11); + gemmini_fence(); #else // do DMA // @@ -369,11 +378,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // DMA knows the full matrix dimensions sp_tiled_matmul_full_spad_ws( spad_addr_P_consume, spad_addr_V_consume, - /*spad_D=*/0, /*spad_C=*/spad_addr_O, + /*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O, /*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul); + /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload); #endif gemmini_fence(); @@ -455,7 +464,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // mvout to SMEM // GEMMINI_CISC_CMD_I(9); sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume, + /*spad_A=*/spad_addr_Q /*bogus*/, + /*spad_B=*/spad_addr_K_consume /*bogus*/, /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,