sgemm_impl: Use 12-bit cmd interface, allow DIM=16
This commit is contained in:
@@ -207,7 +207,7 @@ template <bool use_dma, uint32_t dim_col>
|
|||||||
inline constexpr std::pair<uint32_t, uint32_t>
|
inline constexpr std::pair<uint32_t, uint32_t>
|
||||||
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
remap_to_gemmini_dma_layout(const uint32_t logical_row,
|
||||||
const uint32_t logical_col) {
|
const uint32_t logical_col) {
|
||||||
static_assert(DIM == 8,
|
static_assert(GEMMINI_DMA_FLEXIBLE_LAYOUT || DIM == 8,
|
||||||
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||||
|
|
||||||
if constexpr (use_dma) {
|
if constexpr (use_dma) {
|
||||||
@@ -905,6 +905,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
for (uint32_t block_m = block_m_start; block_m < block_m_end; block_m++) {
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
||||||
|
asm volatile ("loop_mn_start_%=:" :: );
|
||||||
|
|
||||||
// clear out accumulators
|
// clear out accumulators
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
@@ -920,7 +922,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
|
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
GEMMINI_CISC_CMD_I(10);
|
GEMMINI_CISC_CMD_I(10);
|
||||||
@@ -951,6 +953,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
||||||
|
asm volatile("loop_k_start_%=:" ::);
|
||||||
|
|
||||||
// producer code: GMEM->SMEM memory movement
|
// producer code: GMEM->SMEM memory movement
|
||||||
// ---------------------------------------------------------------------
|
// ---------------------------------------------------------------------
|
||||||
@@ -967,8 +970,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
|
||||||
// gemmini_fence();
|
gemmini_fence();
|
||||||
|
|
||||||
// block_k is even: opcode 11 (write to local_a_buf)
|
// block_k is even: opcode 11 (write to local_a_buf)
|
||||||
// block_k is odd: opcode 10 (write to local_a)
|
// block_k is odd: opcode 10 (write to local_a)
|
||||||
@@ -1043,6 +1046,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
// consumer code: SMEM->RF and compute
|
// consumer code: SMEM->RF and compute
|
||||||
// ----------------------------------------------------------------------
|
// ----------------------------------------------------------------------
|
||||||
// @perf: this loop spills to stack a lot because of all the flws in
|
// @perf: this loop spills to stack a lot because of all the flws in
|
||||||
|
asm volatile("dbuf_sel_start_%=:" ::);
|
||||||
const T *local_a_consume;
|
const T *local_a_consume;
|
||||||
const T *local_b_consume;
|
const T *local_b_consume;
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
@@ -1064,6 +1068,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
local_a_consume = local_a;
|
local_a_consume = local_a;
|
||||||
local_b_consume = local_b;
|
local_b_consume = local_b;
|
||||||
}
|
}
|
||||||
|
asm volatile("dbuf_sel_end_%=:" ::);
|
||||||
|
|
||||||
constexpr MemLayout layout_a =
|
constexpr MemLayout layout_a =
|
||||||
GEMMINI_DMA ? MemLayout::block_row_major
|
GEMMINI_DMA ? MemLayout::block_row_major
|
||||||
@@ -1092,6 +1097,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
|
|
||||||
threadblock_barrier(threadblock_id_in_cluster,
|
threadblock_barrier(threadblock_id_in_cluster,
|
||||||
warps_per_threadblock_per_core);
|
warps_per_threadblock_per_core);
|
||||||
|
|
||||||
|
asm volatile("loop_k_end_%=:" ::);
|
||||||
}
|
}
|
||||||
|
|
||||||
if constexpr (write_to_gmem) {
|
if constexpr (write_to_gmem) {
|
||||||
@@ -1106,6 +1113,7 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
asm volatile("loop_mn_end_%=:" ::);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user