flash: Write fast config for DMA

MAC utilization is 20-25% for the loop.
This commit is contained in:
Hansung Kim
2024-09-07 20:46:58 -07:00
parent 8d32a03d09
commit 03308f8033

View File

@@ -13,11 +13,12 @@
#define HEADDIM 64 #define HEADDIM 64
constexpr uint32_t ROWMAX_SETS = 3; constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool DEBUG = true; constexpr bool DEBUG = false;
constexpr bool WARP_SPECIALIZED = true; constexpr bool WARP_SPECIALIZED = true;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
constexpr bool GEMMINI_DMA_FAST = true;
constexpr bool Q_IS_K_MAJOR = true; constexpr bool Q_IS_K_MAJOR = true;
// temporary safety stop for wrong configs // temporary safety stop for wrong configs
@@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// GEMM I: S = Q*K // GEMM I: S = Q*K
// //
// FIXME: deduplicate this between GEMM II // FIXME: deduplicate this between GEMM II
asm volatile("gemm_qk_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) { if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM // clear out accumulators before GEMM
initialize_accum_regs<0>(); initialize_accum_regs<0>();
@@ -815,15 +817,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// split by rows into 2 chunks // split by rows into 2 chunks
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<float, MemLayout::block_row_major, if constexpr (GEMMINI_DMA_FAST) {
MemLayout::block_row_major, B_ROW / 2, thread_block_gemm_single_tile<float, MemLayout::MN_major,
B_COL, HEADDIM, /*leading_dim_a=*/0, MemLayout::MN_major, B_ROW / 2,
/*leading_dim_b=*/0, B_COL, HEADDIM, /*leading_dim_a=*/0,
/*load_accum=*/false, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/false,
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else if constexpr (Q_IS_K_MAJOR) { } else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile< thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
@@ -848,15 +862,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>(); initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile<float, MemLayout::block_row_major, if constexpr (GEMMINI_DMA_FAST) {
MemLayout::block_row_major, B_ROW / 2, thread_block_gemm_single_tile<float, MemLayout::MN_major,
B_COL, HEADDIM, /*leading_dim_a=*/0, MemLayout::MN_major, B_ROW / 2,
/*leading_dim_b=*/0, B_COL, HEADDIM, /*leading_dim_a=*/0,
/*load_accum=*/false, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/false,
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
MemLayout::block_row_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else if constexpr (Q_IS_K_MAJOR) { } else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile< thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
@@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// protect write to SMEM (smem_S) before softmax // protect write to SMEM (smem_S) before softmax
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_qk_finish_%=:" ::);
if constexpr (DEBUG) { if constexpr (DEBUG) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
if (tile_k == 0) { if (tile_k == 0) {
@@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// data movement for K and V // data movement for K and V
// //
// Q stays in SMEM for the entire loop // Q stays in SMEM for the entire loop
asm volatile("move_k_v_start_%=:" ::);
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
// NOTE: Beware of race conditions; with warp specialization, we need to // NOTE: Beware of race conditions; with warp specialization, we need to
// make sure below command code to DMA is not executed simultaneously // make sure below command code to DMA is not executed simultaneously
@@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V, HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
tid_in_warpgroup); tid_in_warpgroup);
} }
asm volatile("move_k_v_finish_%=:" ::);
// protect write to SMEM // protect write to SMEM
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
@@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// inter-warpgroup barrier before GEMM II // inter-warpgroup barrier before GEMM II
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core); threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
// GEMM II: O = O + P*V
// Oi rescale // Oi rescale
thread_block_O_rescale(smem_O, smem_O /*in-place*/, thread_block_O_rescale(smem_O, smem_O /*in-place*/,
smem_O_row_scale, tid_in_warpgroup, smem_O_row_scale, tid_in_warpgroup,
@@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
} }
} }
// GEMM II: O = O + P*V
asm volatile("gemm_pv_start_%=:" ::);
if constexpr (!WARP_SPECIALIZED) { if constexpr (!WARP_SPECIALIZED) {
// clear out accumulators before GEMM // clear out accumulators before GEMM
initialize_accum_regs<0>(); initialize_accum_regs<0>();
@@ -1084,16 +1116,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// split by rows into 2 chunks // split by rows into 2 chunks
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile< if constexpr (GEMMINI_DMA_FAST) {
float, MemLayout::K_major /* P matrix is row-major */, thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, MemLayout::MN_major, B_ROW / 2, HEADDIM,
/*leading_dim_a=*/0, B_COL,
/*leading_dim_b=*/0, /*leading_dim_a=*/0,
/*load_accum=*/true, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/true,
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else { } else {
thread_block_gemm_single_tile< thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
@@ -1109,16 +1154,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>(); initialize_accum_regs<1>();
if constexpr (GEMMINI_DMA) { if constexpr (GEMMINI_DMA) {
thread_block_gemm_single_tile< if constexpr (GEMMINI_DMA_FAST) {
float, MemLayout::K_major /* P matrix is row-major */, thread_block_gemm_single_tile<float, MemLayout::MN_major,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL, MemLayout::MN_major, B_ROW / 2, HEADDIM,
/*leading_dim_a=*/0, B_COL,
/*leading_dim_b=*/0, /*leading_dim_a=*/0,
/*load_accum=*/true, /*leading_dim_b=*/0,
/*write_to_smem=*/true>( /*load_accum=*/true,
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1, /*write_to_smem=*/true>(
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
warpgroup_id_in_cluster); tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::K_major /* P matrix is row-major */,
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
/*leading_dim_a=*/0,
/*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else { } else {
thread_block_gemm_single_tile< thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
@@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
asm volatile("gemm_pv_finish_%=:" ::);
if constexpr (DEBUG) { if constexpr (DEBUG) {
if (warpgroup_id == 0) { if (warpgroup_id == 0) {
// O after PV // O after PV