new gemm kernel

This commit is contained in:
Richard Yan
2024-11-08 20:55:27 -08:00
parent 367fa927f8
commit c114a7a4ab
3 changed files with 110 additions and 77 deletions

View File

@@ -12,7 +12,7 @@
// 128KB
// #define SMEM_SIZE 0x20000
// 256KB
#define SMEM_SIZE 0x40000
#define SMEM_SIZE 0x20000
#define SMEM_MASK (SMEM_SIZE - 1)
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
@@ -85,6 +85,51 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
#define GEMMINI_CISC_SET_AB_STRIDE 8
#define GEMMINI_CISC_STORE_TO_SPAD 9
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
#define GEMMINI_CISC_SET_DC_STRIDE 11
#define GEMMINI_CISC_STORE_TO_GMEM 12
// cisc instruction wrappers
inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * const b_addr,
const uint32_t a_hexadecile, const uint32_t b_hexadecile,
const uint32_t tile_idx_i, const uint32_t tile_idx_j, const uint32_t tile_idx_k,
const uint32_t mat_size_m, const uint32_t mat_size_n, const uint32_t mat_size_k,
const uint32_t tile_size_m, const uint32_t tile_size_n, const uint32_t tile_size_k) {
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC,
(uint64_t) (a_addr + tile_idx_i * tile_size_m * mat_size_k + tile_idx_k * tile_size_k),
(uint64_t) (b_addr + tile_idx_k * tile_size_k * mat_size_n + tile_idx_j * tile_size_n), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((mat_size_n << 20) | (mat_size_k << 8) | GEMMINI_CISC_SET_AB_STRIDE);
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
}
inline void gemmini_tile_compute(const uint32_t a_hexadecile, const uint32_t b_hexadecile, const bool accumulate) {
GEMMINI_CISC_CMD_R((accumulate << 24) | (b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_HEXADECILES);
}
inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,
const uint32_t tile_idx_i, const uint32_t tile_idx_j,
const uint32_t mat_size_m, const uint32_t mat_size_n,
const uint32_t tile_size_m, const uint32_t tile_size_n) {
elem_t * const dram_c_tile_start = c_addr + tile_idx_i * tile_size_m * mat_size_n + tile_idx_j * tile_size_n;
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (uint64_t) dram_c_tile_start, k_LOOP_WS_CONFIG_ADDRS_DC)
GEMMINI_CISC_CMD_R((mat_size_n << 20) | GEMMINI_CISC_SET_DC_STRIDE);
GEMMINI_CISC_CMD_I(GEMMINI_CISC_STORE_TO_GMEM);
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST, k_LOOP_WS_CONFIG_BOUNDS)
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, mat_size_n, k_LOOP_WS_CONFIG_STRIDES_DC)
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, loop_matmul_skips(1, 1, 1, 1, 0), k_LOOP_WS)
}
inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
}
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,