flash: Specify leading_dim for split QK GEMM; fix uninit'd RF before GEMM
This commit is contained in:
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
constexpr uint32_t ROWMAX_SETS = 3;
|
constexpr uint32_t ROWMAX_SETS = 3;
|
||||||
constexpr bool DEBUG = true;
|
constexpr bool DEBUG = true;
|
||||||
constexpr bool WARP_SPECIALIZED = false;
|
constexpr bool WARP_SPECIALIZED = true;
|
||||||
|
|
||||||
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
constexpr uint32_t DEV_FAKE_SMEM_START_ADDR = 0xf0000000;
|
||||||
|
|
||||||
@@ -630,7 +630,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
tid_in_warpgroup);
|
tid_in_warpgroup);
|
||||||
} else {
|
} else {
|
||||||
// FIXME: transpose to K-major in SMEM for correctness
|
// FIXME: transpose to K-major in SMEM for correctness
|
||||||
load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
// load_tile_to_smem<float, MemLayout::K_major, MemLayout::K_major, B_ROW,
|
||||||
|
// HEADDIM, threads_per_warpgroup>(
|
||||||
|
// HEADDIM, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
|
// tid_in_warpgroup);
|
||||||
|
load_tile_to_smem<float, MemLayout::MN_major, MemLayout::MN_major, B_ROW,
|
||||||
HEADDIM, threads_per_warpgroup>(
|
HEADDIM, threads_per_warpgroup>(
|
||||||
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
dim_seqlen, warpgroup_id, 0 /* dim_k == headdim */, gmem_Q, smem_Q,
|
||||||
tid_in_warpgroup);
|
tid_in_warpgroup);
|
||||||
@@ -669,11 +673,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<0>();
|
initialize_accum_regs<0>();
|
||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
thread_block_gemm_single_tile<
|
||||||
MemLayout::MN_major, B_ROW, B_COL,
|
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW, B_COL,
|
||||||
HEADDIM,
|
HEADDIM, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
smem_Q, smem_K, nullptr /*ignore accum*/, smem_S, tid_in_warpgroup,
|
||||||
threads_per_warpgroup, warpgroups_per_cluster,
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
@@ -686,7 +690,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// assumes smem_Q is K-major
|
// assumes smem_Q is K-major
|
||||||
// FIXME: fix this to MN-major
|
// FIXME: fix this to MN-major
|
||||||
float *smem_Q_half0 = smem_Q;
|
float *smem_Q_half0 = smem_Q;
|
||||||
float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM;
|
float *smem_Q_half1 = smem_Q + (B_ROW / 2); // MN-major
|
||||||
|
// float *smem_Q_half1 = smem_Q + (B_ROW / 2) * HEADDIM; // K-major
|
||||||
float *smem_S_half0 = smem_S;
|
float *smem_S_half0 = smem_S;
|
||||||
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
float *smem_S_half1 = smem_S + (B_ROW / 2) * B_COL;
|
||||||
|
|
||||||
@@ -695,19 +700,23 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
thread_block_gemm_single_tile<
|
||||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||||
HEADDIM,
|
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/false,
|
/*load_accum=*/false,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
smem_Q_half0, smem_K, nullptr /*ignore accum*/, smem_S_half0,
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major /*FIXME*/,
|
|
||||||
MemLayout::MN_major, B_ROW / 2, B_COL,
|
initialize_accum_regs<0>();
|
||||||
HEADDIM,
|
initialize_accum_regs<1>();
|
||||||
/*load_accum=*/false,
|
|
||||||
/*write_to_smem=*/true>(
|
thread_block_gemm_single_tile<
|
||||||
|
float, MemLayout::MN_major /*FIXME*/, MemLayout::MN_major, B_ROW / 2,
|
||||||
|
B_COL, HEADDIM, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||||
|
/*load_accum=*/false,
|
||||||
|
/*write_to_smem=*/true>(
|
||||||
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
smem_Q_half1, smem_K, nullptr /*ignore accum*/, smem_S_half1,
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
@@ -837,16 +846,18 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
|
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
||||||
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
MemLayout::MN_major, B_ROW, HEADDIM, B_COL,
|
||||||
|
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/true,
|
/*load_accum=*/true,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
|
|
||||||
// FIXME: wrong but fast
|
// FIXME: wrong but fast
|
||||||
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
// thread_block_gemm_single_tile<float, MemLayout::MN_major,
|
||||||
// MemLayout::MN_major,
|
// MemLayout::MN_major,
|
||||||
// B_ROW, HEADDIM, B_COL,
|
// B_ROW, HEADDIM, B_COL,
|
||||||
|
// /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
// /*load_accum=*/true,
|
// /*load_accum=*/true,
|
||||||
// /*write_to_smem=*/true>(
|
// /*write_to_smem=*/true>(
|
||||||
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
// smem_P, smem_V, smem_O /*load accum*/, smem_O,
|
||||||
@@ -869,23 +880,26 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
// split by rows into 2 chunks
|
// split by rows into 2 chunks
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
thread_block_gemm_single_tile<
|
||||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
B_COL,
|
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/true,
|
/*load_accum=*/true,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/,
|
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||||
smem_O_half0, tid_in_warpgroup, threads_per_warpgroup,
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroups_per_cluster, warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
|
|
||||||
thread_block_gemm_single_tile<float, MemLayout::K_major,
|
initialize_accum_regs<0>();
|
||||||
MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
initialize_accum_regs<1>();
|
||||||
B_COL,
|
|
||||||
/*load_accum=*/true,
|
thread_block_gemm_single_tile<
|
||||||
/*write_to_smem=*/true>(
|
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/,
|
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
smem_O_half1, tid_in_warpgroup, threads_per_warpgroup,
|
/*load_accum=*/true,
|
||||||
warpgroups_per_cluster, warpgroup_id_in_cluster);
|
/*write_to_smem=*/true>(
|
||||||
|
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||||
|
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
|
warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
|
|
||||||
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
threadblock_barrier(warpgroup_id_in_cluster, warps_per_warpgroup_per_core);
|
||||||
|
|||||||
Reference in New Issue
Block a user