flash: Write fast config for DMA
MAC utilization is 20-25% for the loop.
This commit is contained in:
@@ -13,11 +13,12 @@
|
|||||||
#define HEADDIM 64
|
#define HEADDIM 64
|
||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = false;
|
||||||
constexpr bool WARP_SPECIALIZED = true;
|
constexpr bool WARP_SPECIALIZED = true;
|
||||||
|
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
|
|
||||||
|
constexpr bool GEMMINI_DMA_FAST = true;
|
||||||
constexpr bool Q_IS_K_MAJOR = true;
|
constexpr bool Q_IS_K_MAJOR = true;
|
||||||
|
|
||||||
// temporary safety stop for wrong configs
|
// temporary safety stop for wrong configs
|
||||||
@@ -763,6 +764,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// GEMM I: S = Q*K
|
// GEMM I: S = Q*K
|
||||||
//
|
//
|
||||||
// FIXME: deduplicate this with GEMM II
|
// FIXME: deduplicate this with GEMM II
|
||||||
|
asm volatile("gemm_qk_start_%=:" ::);
|
||||||
if constexpr (!WARP_SPECIALIZED) {
|
if constexpr (!WARP_SPECIALIZED) {
|
||||||
// clear out accumulators before GEMM
|
// clear out accumulators before GEMM
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
@@ -815,15 +817,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
if constexpr (GEMMINI_DMA_FAST) {
|
||||||
MemLayout::block_row_major, B_ROW / 2,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
MemLayout::MN_major, B_ROW / 2,
|
||||||
/*leading_dim_b=*/0,
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
/*load_accum=*/false,
|
/*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/false,
|
||||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2,
|
||||||
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else if constexpr (Q_IS_K_MAJOR) {
|
} else if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
@@ -848,15 +862,27 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
if constexpr (GEMMINI_DMA_FAST) {
|
||||||
MemLayout::block_row_major, B_ROW / 2,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
MemLayout::MN_major, B_ROW / 2,
|
||||||
/*leading_dim_b=*/0,
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
/*load_accum=*/false,
|
/*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/false,
|
||||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2,
|
||||||
|
B_COL, HEADDIM, /*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else if constexpr (Q_IS_K_MAJOR) {
|
} else if constexpr (Q_IS_K_MAJOR) {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, B_COL,
|
||||||
@@ -888,6 +914,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// protect write to SMEM (smem_S) before softmax
|
// protect write to SMEM (smem_S) before softmax
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
asm volatile("gemm_qk_finish_%=:" ::);
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
if (tile_k == 0) {
|
if (tile_k == 0) {
|
||||||
@@ -921,6 +949,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// data movement for K and V
|
// data movement for K and V
|
||||||
//
|
//
|
||||||
// Q stays in SMEM for the entire loop
|
// Q stays in SMEM for the entire loop
|
||||||
|
asm volatile("move_k_v_start_%=:" ::);
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
// NOTE: Beware of race conditions; with warp specialization, we need to
|
// NOTE: Beware of race conditions; with warp specialization, we need to
|
||||||
// make sure the DMA command code below is not executed simultaneously
|
// make sure the DMA command code below is not executed simultaneously
|
||||||
@@ -965,6 +994,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
HEADDIM, 0 /* full N-dimension */, tile_k, gmem_V, smem_V,
|
||||||
tid_in_warpgroup);
|
tid_in_warpgroup);
|
||||||
}
|
}
|
||||||
|
asm volatile("move_k_v_finish_%=:" ::);
|
||||||
|
|
||||||
// protect write to SMEM
|
// protect write to SMEM
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
@@ -995,8 +1025,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// inter-warpgroup barrier before GEMM II
|
// inter-warpgroup barrier before GEMM II
|
||||||
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
|
||||||
|
|
||||||
// GEMM II: O = O + P*V
|
|
||||||
|
|
||||||
// Oi rescale
|
// Oi rescale
|
||||||
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
thread_block_O_rescale(smem_O, smem_O /*in-place*/,
|
||||||
smem_O_row_scale, tid_in_warpgroup,
|
smem_O_row_scale, tid_in_warpgroup,
|
||||||
@@ -1029,6 +1057,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GEMM II: O = O + P*V
|
||||||
|
|
||||||
|
asm volatile("gemm_pv_start_%=:" ::);
|
||||||
|
|
||||||
if constexpr (!WARP_SPECIALIZED) {
|
if constexpr (!WARP_SPECIALIZED) {
|
||||||
// clear out accumulators before GEMM
|
// clear out accumulators before GEMM
|
||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
@@ -1084,16 +1116,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (GEMMINI_DMA_FAST) {
|
||||||
float, MemLayout::K_major /* P matrix is row-major */,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
/*leading_dim_a=*/0,
|
B_COL,
|
||||||
/*leading_dim_b=*/0,
|
/*leading_dim_a=*/0,
|
||||||
/*load_accum=*/true,
|
/*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/true,
|
||||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::K_major /* P matrix is row-major */,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||||
|
/*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/true,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
@@ -1109,16 +1154,29 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<
|
if constexpr (GEMMINI_DMA_FAST) {
|
||||||
float, MemLayout::K_major /* P matrix is row-major */,
|
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
/*leading_dim_a=*/0,
|
B_COL,
|
||||||
/*leading_dim_b=*/0,
|
/*leading_dim_a=*/0,
|
||||||
/*load_accum=*/true,
|
/*leading_dim_b=*/0,
|
||||||
/*write_to_smem=*/true>(
|
/*load_accum=*/true,
|
||||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
/*write_to_smem=*/true>(
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||||
warpgroup_id_in_cluster);
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
} else {
|
||||||
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::K_major /* P matrix is row-major */,
|
||||||
|
MemLayout::block_row_major, B_ROW / 2, HEADDIM, B_COL,
|
||||||
|
/*leading_dim_a=*/0,
|
||||||
|
/*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/true,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
thread_block_gemm_single_tile<
|
thread_block_gemm_single_tile<
|
||||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
@@ -1133,6 +1191,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|
||||||
|
asm volatile("gemm_pv_finish_%=:" ::);
|
||||||
|
|
||||||
if constexpr (DEBUG) {
|
if constexpr (DEBUG) {
|
||||||
if (warpgroup_id == 0) {
|
if (warpgroup_id == 0) {
|
||||||
// O after PV
|
// O after PV
|
||||||
|
|||||||
Reference in New Issue
Block a user