new gemm kernel
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
// 128KB
|
||||
// #define SMEM_SIZE 0x20000
|
||||
// 256KB
|
||||
#define SMEM_SIZE 0x40000
|
||||
#define SMEM_SIZE 0x20000
|
||||
|
||||
#define SMEM_MASK (SMEM_SIZE - 1)
|
||||
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)
|
||||
@@ -85,6 +85,51 @@ static size_t gemmini_tile_idx[NUM_THREADS * NUM_WARPS * NUM_CORES * NUM_CLUSTER
|
||||
gemmini_loop_ws_spad(I, J, K, pad_I, pad_J, pad_K, A_sp_addr_start, (B_sp_addr_start) + (K) * (J) * DIM, NULL, \
|
||||
C_dst_sp_addr_start, a_transpose, b_transpose, full_C, low_D, acc, act, 0, 0, false, skips)
|
||||
|
||||
|
||||
#define GEMMINI_CISC_COMPUTE_HEXADECILES 0
|
||||
#define GEMMINI_CISC_SET_AB_STRIDE 8
|
||||
#define GEMMINI_CISC_STORE_TO_SPAD 9
|
||||
#define GEMMINI_CISC_LOAD_TO_HEXADECILES 10
|
||||
#define GEMMINI_CISC_SET_DC_STRIDE 11
|
||||
#define GEMMINI_CISC_STORE_TO_GMEM 12
|
||||
|
||||
// cisc instruction wrappers
|
||||
inline void gemmini_tile_load_ab(const elem_t * const a_addr, const elem_t * const b_addr,
|
||||
const uint32_t a_hexadecile, const uint32_t b_hexadecile,
|
||||
const uint32_t tile_idx_i, const uint32_t tile_idx_j, const uint32_t tile_idx_k,
|
||||
const uint32_t mat_size_m, const uint32_t mat_size_n, const uint32_t mat_size_k,
|
||||
const uint32_t tile_size_m, const uint32_t tile_size_n, const uint32_t tile_size_k) {
|
||||
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC,
|
||||
(uint64_t) (a_addr + tile_idx_i * tile_size_m * mat_size_k + tile_idx_k * tile_size_k),
|
||||
(uint64_t) (b_addr + tile_idx_k * tile_size_k * mat_size_n + tile_idx_j * tile_size_n), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
GEMMINI_CISC_CMD_R((mat_size_n << 20) | (mat_size_k << 8) | GEMMINI_CISC_SET_AB_STRIDE);
|
||||
GEMMINI_CISC_CMD_R((b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_LOAD_TO_HEXADECILES);
|
||||
}
|
||||
|
||||
inline void gemmini_tile_compute(const uint32_t a_hexadecile, const uint32_t b_hexadecile, const bool accumulate) {
|
||||
GEMMINI_CISC_CMD_R((accumulate << 24) | (b_hexadecile << 16) | (a_hexadecile << 8) | GEMMINI_CISC_COMPUTE_HEXADECILES);
|
||||
}
|
||||
|
||||
inline void gemmini_tile_store_c_gmem(elem_t * const c_addr,
|
||||
const uint32_t tile_idx_i, const uint32_t tile_idx_j,
|
||||
const uint32_t mat_size_m, const uint32_t mat_size_n,
|
||||
const uint32_t tile_size_m, const uint32_t tile_size_n) {
|
||||
|
||||
elem_t * const dram_c_tile_start = c_addr + tile_idx_i * tile_size_m * mat_size_n + tile_idx_j * tile_size_n;
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (uint64_t) dram_c_tile_start, k_LOOP_WS_CONFIG_ADDRS_DC)
|
||||
|
||||
GEMMINI_CISC_CMD_R((mat_size_n << 20) | GEMMINI_CISC_SET_DC_STRIDE);
|
||||
GEMMINI_CISC_CMD_I(GEMMINI_CISC_STORE_TO_GMEM);
|
||||
|
||||
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST, k_LOOP_WS_CONFIG_BOUNDS)
|
||||
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, mat_size_n, k_LOOP_WS_CONFIG_STRIDES_DC)
|
||||
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, loop_matmul_skips(1, 1, 1, 1, 0), k_LOOP_WS)
|
||||
}
|
||||
|
||||
inline void gemmini_tile_store_c_spad(const uint32_t c_hexadecile) {
|
||||
GEMMINI_CISC_CMD_R(((uint32_t) (c_hexadecile << 8)) | GEMMINI_CISC_STORE_TO_SPAD);
|
||||
}
|
||||
/* inline static void sp_tiled_matmul_full_spad_ws(const uint32_t A_sp_addr_start, const uint32_t B_sp_addr_start,
|
||||
const uint32_t D_sp_addr_start, const uint32_t C_dst_sp_addr_start,
|
||||
size_t I, size_t J, size_t K, size_t pad_I, size_t pad_J, size_t pad_K,
|
||||
|
||||
Reference in New Issue
Block a user