flash: Load Q to both quartiles; preload O for acc
This commit is contained in:
@@ -202,8 +202,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
constexpr uint32_t skips_mvout_spad =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
/*skip_ex=*/1, /*skip_stc=*/0);
|
||||
constexpr uint32_t skips_matmul =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
|
||||
constexpr uint32_t skips_matmul_preload =
|
||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
@@ -255,6 +255,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// spad_addr_K0 set in this kernel
|
||||
GEMMINI_CISC_CMD_I(10);
|
||||
gemmini_fence();
|
||||
|
||||
// need to also move to Q1 for the next iteration
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
|
||||
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
|
||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||
gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(11);
|
||||
gemmini_fence();
|
||||
#else
|
||||
// do DMA
|
||||
//
|
||||
@@ -369,11 +378,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// DMA knows the full matrix dimensions
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
spad_addr_P_consume, spad_addr_V_consume,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
|
||||
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload);
|
||||
#endif
|
||||
|
||||
gemmini_fence();
|
||||
@@ -455,7 +464,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// mvout to SMEM
|
||||
// GEMMINI_CISC_CMD_I(9);
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
/*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume,
|
||||
/*spad_A=*/spad_addr_Q /*bogus*/,
|
||||
/*spad_B=*/spad_addr_K_consume /*bogus*/,
|
||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||
|
||||
Reference in New Issue
Block a user