From 6990fcc1e66820935f4d1244191c9cf505bee4d7 Mon Sep 17 00:00:00 2001
From: Hansung Kim
Date: Sat, 9 Nov 2024 16:43:45 -0800
Subject: [PATCH] Add compute-and-mvout-to-spad API

---
NOTE(review): this patch was recovered from a copy in which line breaks and
angle-bracket text had been stripped.
 * The template parameter list on gemmini_tile_compute is reconstructed from
   the `if constexpr (!store_to_spad)` use in its body; the `= false` default
   is an assumption -- confirm against the original commit.
 * Indentation, #define spacing and the two blank context lines at the top of
   the first hunk are reconstructed; verify against the tree or apply with
   `git apply --ignore-whitespace`.
 * The author e-mail address was lost as well and is not restored here.

What the change does, as visible in the hunks below:
 * store_to_spad == false: issues GEMMINI_CISC_COMPUTE_HEXADECILES with
   `accumulate` placed at bit 24, b_hexadecile at bit 16, a_hexadecile at
   bit 8 (same command word as before the patch; d_hexadecile is unused).
 * store_to_spad == true: issues GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD with
   d_hexadecile placed at bit 24 instead; `accumulate` is not encoded in
   this branch.
 * The function gains a d_hexadecile parameter, so existing three-argument
   callers must be updated.

 kernel/include/gemmini_mmio.h | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/kernel/include/gemmini_mmio.h b/kernel/include/gemmini_mmio.h
index 6043e0be..f1ca0e77 100644
--- a/kernel/include/gemmini_mmio.h
+++ b/kernel/include/gemmini_mmio.h
@@ -87,6 +87,7 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
 
 
 #define GEMMINI_CISC_COMPUTE_HEXADECILES 0
+#define GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD 1
 #define GEMMINI_CISC_SET_AB_STRIDE 8
 #define GEMMINI_CISC_STORE_TO_SPAD 9
 #define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
@@ -107,8 +108,18 @@ inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * con
   GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
 }
 
-inline void gemmini_tile_compute(const uint32_t a_hexadecile, const uint32_t b_hexadecile, const bool accumulate) {
-  GEMMINI_CISC_CMD_R((accumulate << 24) | (b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_HEXADECILES);
+template <bool store_to_spad = false>
+inline void gemmini_tile_compute(const uint32_t a_hexadecile,
+                                 const uint32_t b_hexadecile,
+                                 const uint32_t d_hexadecile,
+                                 const bool accumulate) {
+  if constexpr (!store_to_spad) {
+    GEMMINI_CISC_CMD_R((accumulate << 24) | (b_hexadecile << 16) |
+                       (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_HEXADECILES);
+  } else {
+    GEMMINI_CISC_CMD_R((d_hexadecile << 24) | (b_hexadecile << 16) |
+                       (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD);
+  }
 }
 
 inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,