flash: Load Q to both quartiles; preload O for acc

This commit is contained in:
Hansung Kim
2024-09-08 19:47:55 -07:00
parent 8efa6868ea
commit 6547e92757

View File

@@ -202,8 +202,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
constexpr uint32_t skips_mvout_spad =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/0);
constexpr uint32_t skips_matmul =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/1,
constexpr uint32_t skips_matmul_preload =
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
/*skip_ex=*/0, /*skip_stc=*/1);
if constexpr (GEMMINI_DMA) {
@@ -255,6 +255,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// spad_addr_K0 set in this kernel
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
// also need to move Q into Q1 for the next iteration
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_Q_tile),
(uint64_t)(gmem_K_tile), k_LOOP_WS_CONFIG_ADDRS_AB)
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
GEMMINI_CISC_CMD_I(11);
gemmini_fence();
#else
// do DMA
//
@@ -369,11 +378,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// DMA knows the full matrix dimensions
sp_tiled_matmul_full_spad_ws(
spad_addr_P_consume, spad_addr_V_consume,
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
/*spad_D=*/spad_addr_O, /*spad_C=*/spad_addr_O,
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul_preload);
#endif
gemmini_fence();
@@ -455,7 +464,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// mvout to SMEM
// GEMMINI_CISC_CMD_I(9);
sp_tiled_matmul_full_spad_ws(
/*spad_A=*/spad_addr_Q, /*spad_B=*/spad_addr_K_consume,
/*spad_A=*/spad_addr_Q /*bogus*/,
/*spad_B=*/spad_addr_K_consume /*bogus*/,
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,