flash: Restructure to use delayed fences for better concurrency
Verified up to O_before_PV of the 2nd iteration; O_after_PV needs a preload fix. FIXME: stalls at the barrier when DEBUG is not set.
This commit is contained in:
@@ -389,7 +389,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
GEMMINI_CISC_CMD_I(1);
|
GEMMINI_CISC_CMD_I(1);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
// do matmul
|
// kickoff matmul
|
||||||
// among other things, this also configures CONFIG_BOUNDS so that the
|
// among other things, this also configures CONFIG_BOUNDS so that the
|
||||||
// DMA knows the full matrix dimensions
|
// DMA knows the full matrix dimensions
|
||||||
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
// FIXME: perf: prevent GMEM->SMEM load for O tile
|
||||||
@@ -402,27 +402,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
|
|
||||||
// mvout to SMEM
|
|
||||||
// GEMMINI_CISC_CMD_I(9);
|
|
||||||
sp_tiled_matmul_full_spad_ws(
|
|
||||||
/*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume,
|
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
|
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
|
||||||
gemmini_fence();
|
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
|
||||||
// for copy-out to GMEM
|
|
||||||
gemmini_fence();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// reconverge from mmio divergence
|
// reconverge from mmio divergence
|
||||||
@@ -431,99 +410,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
asm volatile("gemm_pv_finish_%=:" ::);
|
asm volatile("gemm_pv_finish_%=:" ::);
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
|
||||||
if (warpgroup_id == 0) {
|
|
||||||
// O after PV
|
|
||||||
if (tile_k_ == 0) {
|
|
||||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
|
||||||
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
} else if (tile_k_ == 1) {
|
|
||||||
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
|
||||||
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster,
|
|
||||||
warps_per_warpgroup_per_core);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GEMM I: S = Q*K
|
|
||||||
//
|
|
||||||
asm volatile("gemm_qk_start_%=:" ::);
|
|
||||||
|
|
||||||
if (tid_in_warpgroup == 0) {
|
|
||||||
gemmini_fence();
|
|
||||||
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
|
||||||
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
|
||||||
const uint32_t opcode = 3 * (tile_k & 1);
|
|
||||||
//GEMMINI_CISC_CMD_I(opcode);
|
|
||||||
sp_tiled_matmul_full_spad_ws(
|
|
||||||
spad_addr_Q, spad_addr_K_consume,
|
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
|
||||||
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
gemmini_fence();
|
|
||||||
|
|
||||||
#if 0 // TODO: speed up mvout to SMEM
|
|
||||||
// loop_ws variant that skips configuring strides
|
|
||||||
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
|
|
||||||
{ \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \
|
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// mvout to SMEM
|
|
||||||
// GEMMINI_CISC_CMD_I(9);
|
|
||||||
sp_tiled_matmul_full_spad_ws(
|
|
||||||
/*spad_A=*/spad_addr_Q /*bogus*/,
|
|
||||||
/*spad_B=*/spad_addr_K_consume /*bogus*/,
|
|
||||||
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
|
||||||
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
|
||||||
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
|
||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
|
||||||
gemmini_fence();
|
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
|
||||||
// for copy-out to GMEM
|
|
||||||
gemmini_fence();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// reconverge from mmio divergence
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
|
||||||
|
|
||||||
asm volatile("gemm_qk_finish_%=:" ::);
|
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
|
||||||
if (warpgroup_id == 0) {
|
|
||||||
if (tile_k == 0) {
|
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
|
||||||
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
} else if (tile_k == 1) {
|
|
||||||
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
|
||||||
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup, threads_per_warpgroup,
|
|
||||||
warpgroup_id_in_cluster);
|
|
||||||
}
|
|
||||||
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster,
|
|
||||||
warps_per_warpgroup_per_core);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tile_k >= 1) // delay by 1 iters for pipelining
|
if (tile_k >= 1) // delay by 1 iters for pipelining
|
||||||
@@ -563,7 +449,89 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: put a synchronization here with GEMM-II
|
if (tid_in_warpgroup == 0) {
|
||||||
|
// fence GEMM-II to make sure dependency on O tile is settled
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// mvout to SMEM
|
||||||
|
// GEMMINI_CISC_CMD_I(9);
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
/*spad_A=*/spad_addr_P_consume, /*spad_B=*/spad_addr_V_consume,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_O,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(HEADDIM / DIM), /*K=*/(B_COL / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
// O after PV
|
||||||
|
if (tile_k_ == 0) {
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d6, tid_in_warpgroup, threads_per_warpgroup,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else if (tile_k_ == 1) {
|
||||||
|
thread_block_copy_tile<B_ROW, HEADDIM, GEMMINI_DMA>(
|
||||||
|
smem_O, gmem_tmp_d7, tid_in_warpgroup, threads_per_warpgroup,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// GEMM I: S = Q*K
|
||||||
|
//
|
||||||
|
// kick off asynchronously; fence later
|
||||||
|
asm volatile("gemm_qk_start_%=:" ::);
|
||||||
|
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
// 0,2,.: opcode 0 (quartile 0/2, no accum)
|
||||||
|
// 1,3,.: opcode 3 (quartile 1/3, no accum)
|
||||||
|
const uint32_t opcode = 3 * (tile_k & 1);
|
||||||
|
//GEMMINI_CISC_CMD_I(opcode);
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
spad_addr_Q, spad_addr_K_consume,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
|
|
||||||
|
#if 0 // TODO: speed up mvout to SMEM
|
||||||
|
// loop_ws variant that skips configuring strides
|
||||||
|
#define gemmini_loop_ws(I, J, K, pad_I, pad_J, pad_K, A, B, D, C, A_stride, B_stride, D_stride, C_stride, A_transpose, B_transpose, full_C, low_D, ex_accumulate, act, a_spad_id, b_spad_id, is_resadd) \
|
||||||
|
{ \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(pad_K) << 32) | ((uint64_t)(pad_J) << 16) | (uint64_t)(pad_I), ((uint64_t)(K) << 32) | ((uint64_t)(J) << 16) | (uint64_t)(I), k_LOOP_WS_CONFIG_BOUNDS) \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A, B, k_LOOP_WS_CONFIG_ADDRS_AB) \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D, C, k_LOOP_WS_CONFIG_ADDRS_DC) \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, A_stride, B_stride, k_LOOP_WS_CONFIG_STRIDES_AB) \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, D_stride, C_stride, k_LOOP_WS_CONFIG_STRIDES_DC) \
|
||||||
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, ((uint64_t)(a_spad_id) << 18) | ((uint64_t)(b_spad_id) << 16) | ((uint64_t)(act) << 8) | ((low_D) << 2) | ((full_C) << 1) | (ex_accumulate), ((is_resadd) << 2) | ((B_transpose) << 1) | (A_transpose), k_LOOP_WS) \
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
asm volatile("gemm_qk_finish_%=:" ::);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tile_k >= 1) // delay by 1 iters for pipelining
|
||||||
|
{
|
||||||
|
const uint32_t tile_k_ = tile_k - 1;
|
||||||
|
|
||||||
// Oi rescale
|
// Oi rescale
|
||||||
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
thread_block_O_rescale</*block_row_major=*/GEMMINI_DMA>(
|
||||||
@@ -599,6 +567,46 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fence GEMM I after Oi rescale
|
||||||
|
if (tid_in_warpgroup == 0) {
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
gemmini_fence();
|
||||||
|
|
||||||
|
// mvout to SMEM
|
||||||
|
// GEMMINI_CISC_CMD_I(9);
|
||||||
|
sp_tiled_matmul_full_spad_ws(
|
||||||
|
/*spad_A=*/spad_addr_Q /*bogus*/,
|
||||||
|
/*spad_B=*/spad_addr_K_consume /*bogus*/,
|
||||||
|
/*spad_D=*/0, /*spad_C=*/spad_addr_S_produce,
|
||||||
|
/*I=*/(B_ROW / DIM), /*J=*/(B_COL / DIM), /*K=*/(HEADDIM / DIM),
|
||||||
|
/*pad_I=*/0, /*pad_J=*/0, /*pad_K=*/0,
|
||||||
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_mvout_spad);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// reconverge from mmio divergence
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
if constexpr (DEBUG) {
|
||||||
|
if (warpgroup_id == 0) {
|
||||||
|
if (tile_k == 0) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_produce, gmem_tmp_d0, tid_in_warpgroup,
|
||||||
|
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
} else if (tile_k == 1) {
|
||||||
|
thread_block_copy_tile<B_ROW, B_COL, GEMMINI_DMA>(
|
||||||
|
smem_S_produce, gmem_tmp_d1, tid_in_warpgroup,
|
||||||
|
threads_per_warpgroup, warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
|
|
||||||
|
threadblock_barrier(warpgroup_id_in_cluster,
|
||||||
|
warps_per_warpgroup_per_core);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// data move for K and V
|
// data move for K and V
|
||||||
//
|
//
|
||||||
// Q stays in SMEM for the entire loop
|
// Q stays in SMEM for the entire loop
|
||||||
@@ -616,6 +624,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// iterations later
|
// iterations later
|
||||||
const float *gmem_V_tile =
|
const float *gmem_V_tile =
|
||||||
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
|
gmem_V + (HEADDIM * B_COL * (tile_k - 1 /*dragbehind*/));
|
||||||
|
|
||||||
|
// fence mvout S to SMEM
|
||||||
|
gemmini_fence();
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, (uint64_t)(gmem_K_tile),
|
||||||
(uint64_t)(gmem_V_tile),
|
(uint64_t)(gmem_V_tile),
|
||||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||||
|
|||||||
Reference in New Issue
Block a user