flash: Swap S1/S0 to avoid GEMM II - softmax bank conflict
+ remove spurrious fences to better overlap GEMM I and DMA
This commit is contained in:
@@ -108,8 +108,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t smem_K1_offset = smem_quart3;
|
constexpr uint32_t smem_K1_offset = smem_quart3;
|
||||||
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
constexpr uint32_t smem_V0_offset = smem_Q0_offset + smem_Q_size * sizeof(float);
|
||||||
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
constexpr uint32_t smem_V1_offset = smem_Q1_offset + smem_Q_size * sizeof(float);
|
||||||
constexpr uint32_t smem_S0_offset = smem_V0_offset + smem_V_size * sizeof(float);
|
// put S1/S0 with V0/V1 so that softmax and GEMM-II doesn't cause bank
|
||||||
constexpr uint32_t smem_S1_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
// conflicts
|
||||||
|
constexpr uint32_t smem_S0_offset = smem_V1_offset + smem_V_size * sizeof(float);
|
||||||
|
constexpr uint32_t smem_S1_offset = smem_V0_offset + smem_V_size * sizeof(float);
|
||||||
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
constexpr uint32_t smem_P0_offset = smem_K0_offset + smem_K_size * sizeof(float);
|
||||||
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
constexpr uint32_t smem_P1_offset = smem_K1_offset + smem_K_size * sizeof(float);
|
||||||
// reversed!
|
// reversed!
|
||||||
@@ -177,14 +179,16 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
constexpr uint32_t spad_addr_O0 = smem_O0_offset / spad_addr_factor;
|
||||||
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
constexpr uint32_t spad_addr_O1 = smem_O1_offset / spad_addr_factor;
|
||||||
|
|
||||||
|
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
||||||
|
static_assert(warps_per_threadblock_per_core == NUM_WARPS);
|
||||||
|
|
||||||
// initialize rowmax/rowsum values in sharedmem
|
// initialize rowmax/rowsum values in sharedmem
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O0,
|
||||||
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
smem_rowmax_0, smem_rowsum_0, smem_O_row_scale_0);
|
||||||
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
|
thread_block_init_sharedmem(tid_in_warpgroup, threads_per_warpgroup, smem_O1,
|
||||||
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
|
smem_rowmax_1, smem_rowsum_1, smem_O_row_scale_1);
|
||||||
|
|
||||||
constexpr uint32_t global_barrier_id = NUM_WARPS - 1; // arbitrary
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
static_assert(warps_per_threadblock_per_core == NUM_WARPS);
|
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
static_assert(!GEMMINI_DMA || Q_IS_K_MAJOR,
|
||||||
"DMA code assumes Q matrix is stored K-major");
|
"DMA code assumes Q matrix is stored K-major");
|
||||||
@@ -209,22 +213,19 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
|
||||||
/*skip_ex=*/0, /*skip_stc=*/1);
|
/*skip_ex=*/0, /*skip_stc=*/1);
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if (tid_in_warpgroup == 0) {
|
||||||
if (tid_in_warpgroup == 0) {
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
|
||||||
|
|
||||||
// configure DMA with GMEM address strides
|
// configure DMA with GMEM address strides
|
||||||
// Q matrix
|
// Q matrix
|
||||||
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(HEADDIM * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
||||||
false, 0);
|
false, 0);
|
||||||
// K matrix
|
// K matrix
|
||||||
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t), MVIN_SCALE_IDENTITY,
|
gemmini_extended3_config_ld(dim_seqlen * sizeof(elem_t),
|
||||||
false, 1);
|
MVIN_SCALE_IDENTITY, false, 1);
|
||||||
// configure DMA for Q*K store
|
// configure DMA for Q*K store
|
||||||
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0,
|
gemmini_extended_config_st(B_COL * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY);
|
||||||
MVIN_SCALE_IDENTITY);
|
gemmini_fence();
|
||||||
gemmini_fence();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE about barriers: Placing barriers around thread-divergent branches may
|
// NOTE about barriers: Placing barriers around thread-divergent branches may
|
||||||
@@ -319,8 +320,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
// }
|
// }
|
||||||
|
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
|
||||||
|
|
||||||
constexpr uint32_t threads_per_warpgroup_simt =
|
constexpr uint32_t threads_per_warpgroup_simt =
|
||||||
threads_per_warpgroup -
|
threads_per_warpgroup -
|
||||||
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
|
CORES_PER_CLUSTER * NUM_THREADS /*warp 0, 4, 8, 12*/;
|
||||||
@@ -337,7 +336,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
const uint32_t k_tiles = (dim_seqlen / B_COL);
|
||||||
for (uint32_t tile_k = 0;
|
for (uint32_t tile_k = 0;
|
||||||
tile_k <
|
tile_k <
|
||||||
(1 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
|
(4 /*FIXME: for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
|
||||||
tile_k++) {
|
tile_k++) {
|
||||||
if constexpr (DEBUG || true) {
|
if constexpr (DEBUG || true) {
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
@@ -456,28 +455,17 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips_matmul);
|
||||||
|
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
}
|
asm volatile("gemm_qk_finish_%=:" ::);
|
||||||
// // reconverge after mmio
|
|
||||||
// threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
|
||||||
|
|
||||||
asm volatile("gemm_qk_finish_%=:" ::);
|
// data move for K and V
|
||||||
|
//
|
||||||
|
// Q stays in SMEM for the entire loop
|
||||||
|
asm volatile("move_k_v_start_%=:" ::);
|
||||||
|
|
||||||
// TODO: put synchronization here with online softmax
|
|
||||||
|
|
||||||
// data move for K and V
|
|
||||||
//
|
|
||||||
// Q stays in SMEM for the entire loop
|
|
||||||
asm volatile("move_k_v_start_%=:" ::);
|
|
||||||
|
|
||||||
// NOTE: Beware of race conditions; with warp specialization, we need to
|
|
||||||
// make sure below command code to DMA is not executed simultaneously
|
|
||||||
// from the two warpgroups (which will result in hardware fault).
|
|
||||||
// Currently the ping-pong scheduling scheme prevents that.
|
|
||||||
if (tid_in_warpgroup == 0) {
|
|
||||||
// configure GMEM addresses for K and V tiles
|
// configure GMEM addresses for K and V tiles
|
||||||
// load K for the next iteration
|
// load K for the next iteration
|
||||||
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/));
|
const float *gmem_K_tile = gmem_K + (B_COL * (tile_k + 1 /*runahead*/));
|
||||||
@@ -497,7 +485,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
|
||||||
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
|
||||||
#endif
|
#endif
|
||||||
gemmini_fence();
|
// gemmini_fence();
|
||||||
|
|
||||||
// do DMA
|
// do DMA
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
@@ -518,6 +506,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fence everything before going to the next tile
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
|
|||||||
Reference in New Issue
Block a user