Merge branch 'new-cisc' of https://github.com/hansungk/vortex-private into new-cisc

This commit is contained in:
Richard Yan
2025-01-28 17:14:49 -08:00
9 changed files with 1269 additions and 331 deletions

View File

@@ -11,10 +11,11 @@
// #define SMEM_SIZE 0x4000
// 64KB
// #define SMEM_SIZE 0x10000
// 128KB
// #define SMEM_SIZE 0x20000
// 256KB
// 128KB (FP16 GEMM)
#define SMEM_SIZE 0x20000
// 256KB (FlashAttention)
// #define SMEM_SIZE 0x40000
// All-ones mask below SMEM_SIZE; only meaningful as a bit mask while SMEM_SIZE
// is a power of two (true for the active 0x20000 setting).
#define SMEM_MASK (SMEM_SIZE - 1)
// First address past the shared-memory window that starts at SMEM_BASE.
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
@@ -47,6 +48,7 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
// Gemmini control register addresses, given as byte offsets from GEMMINI_CTRL.
// They are dereferenced through volatile pointers by the macros in this file.
#define GEMMINI_RS2_ADDR (GEMMINI_CTRL + 0x18)
#define GEMMINI_INST_ADDR (GEMMINI_CTRL + 0x0)
// Polled as a 32-bit word by gemmini_fence(); non-zero keeps the fence spinning.
#define GEMMINI_BUSY_ADDR (GEMMINI_CTRL + 0x20)
// Polled as a 32-bit word by virgo_fence(n), which spins while the value exceeds n.
#define GEMMINI_OCCUPANCY_ADDR (GEMMINI_CTRL + 0x28)
#undef ROCC_INSTRUCTION_RS1_RS2
#define ROCC_INSTRUCTION_RS1_RS2(x, rs1, rs2, funct) { \
*((volatile uint64_t *) GEMMINI_RS1_ADDR) = (rs1); \
@@ -70,6 +72,8 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
//#define gemmini_fence() { while (gemmini_status()); }
// Spin until the 32-bit word read from GEMMINI_BUSY_ADDR is zero, executing one
// nop per poll. The volatile access forces a fresh load on every iteration.
#define gemmini_fence() { while (*((volatile uint32_t *) GEMMINI_BUSY_ADDR)) asm volatile ("nop"); }
// Spin until the 32-bit word read from GEMMINI_OCCUPANCY_ADDR is no greater
// than n, executing one nop per poll. The volatile access forces a fresh load
// on every iteration.
// n is parenthesized in the expansion so that expression arguments with
// lower precedence than '>' (e.g. `a | b`, `c ? x : y`) compare as a whole
// instead of being re-associated with the occupancy read.
#define virgo_fence(n) { while (*((volatile uint32_t *) GEMMINI_OCCUPANCY_ADDR) > (n)) asm volatile ("nop"); }
/* cisc instructions */
/* ================= */
@@ -80,6 +84,8 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
// Issue a CISC command: write the value x to CSR 0xacc with a csrw
// instruction, passing x in a general-purpose register.
#define GEMMINI_CISC_CMD_R(x) asm("csrw 0xacc, %0" :: "r" (x))
// CISC command opcodes. The helpers in this file OR one of these, unshifted,
// into the command word, with operands placed above it at shifts of 8, 16
// and 24 bits.
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
#define GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD 1
#define GEMMINI_CISC_MANUAL 2
#define GEMMINI_CISC_SET_AB_STRIDE 8
#define GEMMINI_CISC_STORE_TO_SPAD 9
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
@@ -101,8 +107,19 @@ inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * con
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
}
inline void gemmini_tile_compute(const uint32_t a_hexadecile, const uint32_t b_hexadecile, const bool accumulate) {
GEMMINI_CISC_CMD_R((accumulate << 24) | (b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_HEXADECILES);
/*
 * Issue one tile-compute CISC command.
 *
 * The command word written through GEMMINI_CISC_CMD_R is assembled as
 *   (top << 24) | (b_hexadecile << 16) | (a_hexadecile << 8) | opcode
 * where the template flag selects both `top` and `opcode`:
 *   store_to_spad == false: top = accumulate (0 or 1),
 *                           opcode = GEMMINI_CISC_COMPUTE_HEXADECILES;
 *                           d_hexadecile is ignored.
 *   store_to_spad == true:  top = d_hexadecile,
 *                           opcode = GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD;
 *                           accumulate is ignored.
 * No operand is masked here, so callers must keep each one within its field.
 */
template <bool store_to_spad = false>
inline void gemmini_tile_compute(const uint32_t a_hexadecile,
                                 const uint32_t b_hexadecile,
                                 const uint32_t d_hexadecile,
                                 const bool accumulate) {
  // The A/B operand fields are laid out identically for both variants.
  const uint32_t ab_fields = (b_hexadecile << 16) | (a_hexadecile << 8);

  // store_to_spad is a compile-time constant, so both selections fold away.
  const uint32_t top_field =
      store_to_spad ? d_hexadecile : static_cast<uint32_t>(accumulate);
  const uint32_t opcode = store_to_spad
                              ? GEMMINI_CISC_COMPUTE_AND_STORE_TO_SPAD
                              : GEMMINI_CISC_COMPUTE_HEXADECILES;

  GEMMINI_CISC_CMD_R((top_field << 24) | ab_fields | opcode);
}
inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,
@@ -125,6 +142,10 @@ inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
}
// Issue the GEMMINI_CISC_MANUAL command with no operands.
// NOTE(review): GEMMINI_CISC_CMD_I is defined outside this view; presumably it
// is the immediate-operand counterpart of GEMMINI_CISC_CMD_R — confirm.
inline void gemmini_manual_job() {
GEMMINI_CISC_CMD_I(GEMMINI_CISC_MANUAL);
}
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,