From 88760596cb5e8bfd73a8c37ef296408c45b21099 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 9 Sep 2024 17:18:59 -0700 Subject: [PATCH] flash: Remove bogus mvout to SMEM code --- .../flash_attention/kernel.gemmini.cpp | 27 ++----------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index f85755e1..884762d7 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -449,18 +449,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); -#if 0 // TODO: speed up mvout to SMEM - // loop_ws variant that skips configuring strides -#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \ - { \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \ - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \ - } -#endif } // // reconverge after mmio // threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); @@ -487,6 +475,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { const float *gmem_V_tile = gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/)); +#if 0 // fence mvout S to SMEM gemmini_fence(); ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile), @@ -496,6 +485,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // FIXME: unnecessary? GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) | 8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/); +#endif gemmini_fence(); // do DMA @@ -663,19 +653,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { gemmini_fence(); gemmini_fence(); gemmini_fence(); - -#if 0 - // mvout to SMEM - // GEMMINI_CISC_CMD_I(9); - sp_tiled_matmul_full_spad_ws( - /*spad_A=*/spad_addr_Q /*bogus*/, - /*spad_B=*/spad_addr_K_consume /*bogus*/, - /*spad_D=*/0, /*spad_C=*/spad_addr_S_produce, - /*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM), - /*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0, - /*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0, - /*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad); -#endif } // reconverge from mmio divergence