From 33bc084c37a2d2e8b03ff4a85844bdf1f8936c34 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 7 Sep 2024 19:50:04 -0700 Subject: [PATCH] flash: Fix DMA layout for GEMM II --- tests/regression/flash_attention/kernel.cpp | 20 +++++++++++--------- tests/regression/sgemm_tcore/sgemm_impl.hpp | 3 --- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/tests/regression/flash_attention/kernel.cpp b/tests/regression/flash_attention/kernel.cpp index 8da963a6..10d8f555 100644 --- a/tests/regression/flash_attention/kernel.cpp +++ b/tests/regression/flash_attention/kernel.cpp @@ -578,7 +578,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS; constexpr uint32_t smem_rowsum_size = B_ROW; constexpr uint32_t smem_O_row_scale_size = B_ROW; - // smem_cursor = reinterpret_cast(DEV_FAKE_SMEM_START_ADDR + SMEM_SIZE); + // FIXME: dangerous smem_cursor = reinterpret_cast(0xff038000); float *smem_rowmax_0 = smem_cursor; @@ -599,8 +599,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // NOTE: out-of bounds is not checked // TODO: reduce this from B_ROW to NUM_WARPS constexpr uint32_t smem_scratchpad_size = - B_ROW * NUM_THREADS * 2 /*arbitrary slack*/; - // threads_per_warpgroup * 2 /*arbitrary slack*/; + threads_per_warpgroup * 2 /*arbitrary slack*/; float *smem_scratchpad_0 = smem_cursor; smem_cursor += smem_scratchpad_size; float *smem_scratchpad_1 = smem_cursor; @@ -1013,12 +1012,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { initialize_accum_regs<1>(); if constexpr (GEMMINI_DMA) { - thread_block_gemm_single_tile( + thread_block_gemm_single_tile< + float, MemLayout::K_major /* P matrix is row-major */, + MemLayout::block_row_major, B_ROW, HEADDIM, B_COL, + /*leading_dim_a=*/0, /*leading_dim_b=*/0, + /*load_accum=*/true, + /*write_to_smem=*/true>( smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, warpgroup_id_in_cluster); @@ -1045,6 +1044,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { // warpgroups_per_cluster, warpgroup_id_in_cluster); } } else { + static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA, + "warp specialization unimplemented for dma"); + // when warp-specialized, there's only enough warps to do 64x32 tile // size so we need to do 2 GEMM calls static_assert(B_ROW / 2 == 32, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index e563f23c..1bb7b893 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -254,9 +254,6 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k, constexpr int packed_factor = (std::is_same_v ? 2 : 1); const int local_k_adjusted = local_k / packed_factor; - static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) || - GEMMINI_DMA_FLEXIBLE_LAYOUT, - "wrong memory layout selected for DMA"); static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32), "fp16 is not really tested for K-major A layout");