flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM

This commit is contained in:
Hansung Kim
2024-09-02 00:15:57 -07:00
parent bdd955836d
commit 8125192846

View File

@@ -15,7 +15,7 @@
constexpr uint32_t ROWMAX_SETS = 3; constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool DEBUG = true; constexpr bool DEBUG = true;
constexpr bool WARP_SPECIALIZED = false; constexpr bool WARP_SPECIALIZED = true;
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000; constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
@@ -630,7 +630,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
tid_in_warpgroup); tid_in_warpgroup);
} else { } else {
// FIXME: transpose to K-major in SMEM for correctness // FIXME: transpose to K-major in SMEM for correctness
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW, // load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
// HEADDIM, threads_per_warpgroup>(
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
// tid_in_warpgroup);
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
HEADDIM, threads_per_warpgroup>( HEADDIM, threads_per_warpgroup>(
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q, dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
tid_in_warpgroup); tid_in_warpgroup);
@@ -669,11 +673,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<0>(); initialize_accum_regs<0>();
initialize_accum_regs<1>(); initialize_accum_regs<1>();
thread_block_gemm_single_tile<float, MemLayout::MN_major, thread_block_gemm_single_tile<
MemLayout::MN_major, B_ROW, B_COL, float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
HEADDIM, HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/false, /*load_accum=*/false,
/*write_to_smem=*/true>( /*write_to_smem=*/true>(
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup, smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
threads_per_warpgroup, warpgroups_per_cluster, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
@@ -686,7 +690,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// assumes smem_Q is K-major // assumes smem_Q is K-major
// FIXME: fix this to MN-major // FIXME: fix this to MN-major
float *smem_Q_half0 = smem_Q; float *smem_Q_half0 = smem_Q;
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
float *smem_S_half0 = smem_S; float *smem_S_half0 = smem_S;
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL; float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
@@ -695,19 +700,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>(); initialize_accum_regs<1>();
// split by rows into 2 chunks // split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/, thread_block_gemm_single_tile<
MemLayout::MN_major, B_ROW / 2, B_COL, float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
HEADDIM, B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false, /*load_accum=*/false,
/*write_to_smem=*/true>( /*write_to_smem=*/true>(
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0, smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
MemLayout::MN_major, B_ROW / 2, B_COL, initialize_accum_regs<0>();
HEADDIM, initialize_accum_regs<1>();
/*load_accum=*/false,
/*write_to_smem=*/true>( thread_block_gemm_single_tile<
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/false,
/*write_to_smem=*/true>(
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1, smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
@@ -837,16 +846,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
thread_block_gemm_single_tile<float, MemLayout::K_major, thread_block_gemm_single_tile<float, MemLayout::K_major,
MemLayout::MN_major, B_ROW, HEADDIM, B_COL, MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true, /*load_accum=*/true,
/*write_to_smem=*/true>( /*write_to_smem=*/true>(
smem_P, smem_V, smem_O /*load accum*/, smem_O, smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster); warpgroup_id_in_cluster);
// FIXME: wrong but fast // FIXME: wrong but fast
// thread_block_gemm_single_tile<float, MemLayout::MN_major, // thread_block_gemm_single_tile<float, MemLayout::MN_major,
// MemLayout::MN_major, // MemLayout::MN_major,
// B_ROW, HEADDIM, B_COL, // B_ROW, HEADDIM, B_COL,
// /*leading_dim_a=*/0, /*leading_dim_b=*/0,
// /*load_accum=*/true, // /*load_accum=*/true,
// /*write_to_smem=*/true>( // /*write_to_smem=*/true>(
// smem_P, smem_V, smem_O /*load accum*/, smem_O, // smem_P, smem_V, smem_O /*load accum*/, smem_O,
@@ -869,23 +880,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
initialize_accum_regs<1>(); initialize_accum_regs<1>();
// split by rows into 2 chunks // split by rows into 2 chunks
thread_block_gemm_single_tile<float, MemLayout::K_major, thread_block_gemm_single_tile<
MemLayout::MN_major, B_ROW / 2, HEADDIM, float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
/*load_accum=*/true, /*load_accum=*/true,
/*write_to_smem=*/true>( /*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
smem_O_half0, tid_in_warpgroup, threads_per_warpgroup, tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroups_per_cluster, warpgroup_id_in_cluster); warpgroup_id_in_cluster);
thread_block_gemm_single_tile<float, MemLayout::K_major, initialize_accum_regs<0>();
MemLayout::MN_major, B_ROW / 2, HEADDIM, initialize_accum_regs<1>();
B_COL,
/*load_accum=*/true, thread_block_gemm_single_tile<
/*write_to_smem=*/true>( float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
smem_O_half1, tid_in_warpgroup, threads_per_warpgroup, /*load_accum=*/true,
warpgroups_per_cluster, warpgroup_id_in_cluster); /*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} }
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core); threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);