flash: Complete Q_IS_K_MAJOR code for GEMM II

This commit is contained in:
Hansung Kim
2024-09-19 20:36:03 -07:00
parent b9cafd6372
commit d0ef06cec1

View File

@@ -672,9 +672,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
static_assert(B_ROW / 2 == 32,
"tile size assumption for warp-specialization not met");
// assumes smem_P is K-major
float *smem_P_half0 = smem_P;
float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL;
float *smem_P_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA)
? smem_P + (B_ROW / 2) * B_COL
: smem_P + (B_ROW / 2);
float *smem_O_half0 = smem_O;
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
@@ -707,7 +708,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
} else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
@@ -716,6 +717,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
initialize_accum_regs<0>();
@@ -745,7 +755,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
} else {
} else if constexpr (Q_IS_K_MAJOR) {
thread_block_gemm_single_tile<
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
@@ -754,6 +764,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
} else {
thread_block_gemm_single_tile<
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
/*load_accum=*/true,
/*write_to_smem=*/true>(
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
warpgroup_id_in_cluster);
}
}