From be15cffbf39451099937519a8078ab7b7db5f233 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Thu, 12 Sep 2024 14:25:33 -0700 Subject: [PATCH] flash: Revert to gemmini config, remove DEBUG and unnecessary checks --- tests/regression/flash_attention/flash_impl.hpp | 4 ++-- tests/regression/flash_attention/kernel.gemmini.cpp | 8 ++------ 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/regression/flash_attention/flash_impl.hpp b/tests/regression/flash_attention/flash_impl.hpp index 47e21c70..93dc3cc9 100644 --- a/tests/regression/flash_attention/flash_impl.hpp +++ b/tests/regression/flash_attention/flash_impl.hpp @@ -11,8 +11,8 @@ #define ROW_REMAINDER_LOGIC constexpr uint32_t ROWMAX_SETS = 3; -constexpr bool WARP_SPECIALIZED = true; -constexpr bool TENSOR_CORE = true; +constexpr bool WARP_SPECIALIZED = false; +constexpr bool TENSOR_CORE = false; // temporary safety stop for wrong configs static_assert(NUM_CORES == 4); diff --git a/tests/regression/flash_attention/kernel.gemmini.cpp b/tests/regression/flash_attention/kernel.gemmini.cpp index 63d3bd56..ac3788d4 100644 --- a/tests/regression/flash_attention/kernel.gemmini.cpp +++ b/tests/regression/flash_attention/kernel.gemmini.cpp @@ -10,7 +10,7 @@ #define FENCE_GEMM_II -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; static_assert(GEMMINI_DMA && !WARP_SPECIALIZED, "GEMMINI_DMA should be set and WARP_SPECIALIZED unset"); @@ -192,9 +192,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); - static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR, - "DMA code assumes Q matrix is stored K-major"); - // skip everything except DMA in the loop FSM constexpr uint32_t skips = loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1, @@ -339,8 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // "inner loop" along the columns of K^T const uint32_t k_tiles = (dim_seqlen / B_COL); for (uint32_t tile_k = 0; - tile_k < - (4 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; + tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/; tile_k++) { if constexpr (DEBUG || true) { threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);