flash: Complete Q_IS_K_MAJOR code for GEMM II
This commit is contained in:
@@ -672,9 +672,10 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
static_assert(B_ROW / 2 == 32,
|
||||
"tile size assumption for warp-specialization not met");
|
||||
|
||||
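// smem_P layout decides where the second half starts: K-major when Q_IS_K_MAJOR || GEMMINI_DMA, otherwise presumably MN-major (the fallback GEMM uses MemLayout::MN_major with leading_dim_a = B_ROW)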
|
||||
float *smem_P_half0 = smem_P;
|
||||
float *smem_P_half1 = smem_P + (B_ROW / 2) * B_COL;
|
||||
float *smem_P_half1 = (Q_IS_K_MAJOR || GEMMINI_DMA)
|
||||
? smem_P + (B_ROW / 2) * B_COL
|
||||
: smem_P + (B_ROW / 2);
|
||||
float *smem_O_half0 = smem_O;
|
||||
float *smem_O_half1 = smem_O + (B_ROW / 2) * HEADDIM;
|
||||
|
||||
@@ -707,7 +708,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
} else if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
@@ -716,6 +717,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half0, smem_V, smem_O_half0 /*load accum*/, smem_O_half0,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
|
||||
initialize_accum_regs<0>();
|
||||
@@ -745,7 +755,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
} else {
|
||||
} else if constexpr (Q_IS_K_MAJOR) {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::K_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||
@@ -754,6 +764,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
} else {
|
||||
thread_block_gemm_single_tile<
|
||||
float, MemLayout::MN_major, MemLayout::MN_major, B_ROW / 2, HEADDIM,
|
||||
B_COL, /*leading_dim_a=*/B_ROW, /*leading_dim_b=*/0,
|
||||
/*load_accum=*/true,
|
||||
/*write_to_smem=*/true>(
|
||||
smem_P_half1, smem_V, smem_O_half1 /*load accum*/, smem_O_half1,
|
||||
tid_in_warpgroup, threads_per_warpgroup, warpgroups_per_cluster,
|
||||
warpgroup_id_in_cluster);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user