From d8d5df64e6512034bef2f5a188a3b9f945951bd2 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Tue, 20 Aug 2024 14:34:09 -0700 Subject: [PATCH] flash: Fix load addr for V tile; test with seqlen=128 --- tests/regression/flash_attention/kernel.cpp | 7 +++++-- tests/regression/flash_attention/main.cpp | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index dd6cb4f3..b1086e18 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -475,8 +475,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<0>(); initialize_accum_regs<1>(); - load_tile_to_smem( - B_COL, 0, 0, gmem_V, smem_V, tid_in_threadblock); + // V dimension is [seqlen, headdim], stored N(headdim)-major + load_tile_to_smem( + HEADDIM, 0 /* 0 because always reads the full N-dimension */, + tile_k * B_COL, gmem_V, smem_V, tid_in_threadblock); threadblock_barrier(threadblock_id_in_cluster, warps_per_threadblock_per_core); diff --git a/tests/regression/flash_attention/main.cpp b/tests/regression/flash_attention/main.cpp index 3747bbb6..b1b8d522 100644 --- a/tests/regression/flash_attention/main.cpp +++ b/tests/regression/flash_attention/main.cpp @@ -99,7 +99,7 @@ int main(int argc, char *argv[]) { std::cout << "open device connection" << std::endl; RT_CHECK(vx_dev_open(&device)); - uint32_t dim_seqlen = 64; + uint32_t dim_seqlen = 128; uint32_t dim_headdim = 64; using float_type = half;