From adcd0a9d497488e2a5ad2645c96991bebffb5a8b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sun, 8 Sep 2024 02:23:51 -0700 Subject: [PATCH] sgemm_impl: Fix wrong smem address for fp16 Verified results for fp16 256x256. --- tests/regression/sgemm_tcore/kernel.cpp | 9 +++++---- tests/regression/sgemm_tcore/sgemm_impl.hpp | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/tests/regression/sgemm_tcore/kernel.cpp b/tests/regression/sgemm_tcore/kernel.cpp index 59fd7194..bc77ac2a 100644 --- a/tests/regression/sgemm_tcore/kernel.cpp +++ b/tests/regression/sgemm_tcore/kernel.cpp @@ -7,7 +7,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -constexpr bool DEBUG = true; +constexpr bool DEBUG = false; template inline void thread_block_copy_tile(const float *src, float *dest, @@ -91,8 +91,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { /*write_to_gmem=*/true, /*smem_a_offset=*/0, /*smem_a_dbuf_offset=*/0, - /*smem_b_offset=*/2 * BM * BK * sizeof(float), - /*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>( + /*smem_b_offset=*/2 * BM * BK * sizeof(float_type), + /*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float_type)>( (const float_type *)arg->addr_a, (const float_type *)arg->addr_b, (float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k, tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster, @@ -102,7 +102,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) { float *gmem_tmp_d1 = reinterpret_cast(0xd1000000UL); const float *smem_A = reinterpret_cast(sharedmem_per_threadblock); - const float *smem_B = smem_A + 2 * BM * BK; + const float *smem_B = reinterpret_cast( + sharedmem_per_threadblock + 2 * BM * BK * sizeof(float_type)); if constexpr (DEBUG) { threadblock_barrier(threadblock_id_in_cluster, diff --git a/tests/regression/sgemm_tcore/sgemm_impl.hpp b/tests/regression/sgemm_tcore/sgemm_impl.hpp index 1bb7b893..0c6274a2 100644 --- a/tests/regression/sgemm_tcore/sgemm_impl.hpp +++ b/tests/regression/sgemm_tcore/sgemm_impl.hpp @@ -6,7 +6,7 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" -#define FP_SIZE 32 +#define FP_SIZE 16 // "fake" fp16 type that only has the correct data width. using float16_t = uint16_t; @@ -29,7 +29,7 @@ using float_type = float16_t; // (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER // * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields // BM <= BK*TM*TN -#define BM 64 +#define BM 128 #define BN 64 #if (FP_SIZE == 32) #define BK 64 @@ -72,7 +72,7 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER == #define TRANSPOSE_AT_PRODUCE 0 #define TRANSPOSE_AT_CONSUME 0 -#define GEMMINI_DMA 1 +#define GEMMINI_DMA 0 #define GEMMINI_DMA_FLEXIBLE_LAYOUT 0 #if SMEM_SIZE == 0x4000 #define SMEM_ADDR_Q0 ((float * const) 0xff000000) @@ -847,9 +847,9 @@ template < uint32_t smem_a_dbuf_offset = 0, // byte offset of A // double-buffer tile in shared // memory - uint32_t smem_b_offset = sizeof(float) * BM * BK, // byte offset of B tile + uint32_t smem_b_offset = sizeof(T) * BM * BK, // byte offset of B tile // in shared memory - uint32_t smem_b_dbuf_offset = sizeof(float) * BM * + uint32_t smem_b_dbuf_offset = sizeof(T) * BM * BK // byte offset of B double-buffer // tile in shared memory >