flash: Fix DMA layout for GEMM II
This commit is contained in:
@@ -578,7 +578,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
constexpr uint32_t smem_rowmax_size = B_ROW * ROWMAX_SETS;
|
||||||
constexpr uint32_t smem_rowsum_size = B_ROW;
|
constexpr uint32_t smem_rowsum_size = B_ROW;
|
||||||
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
constexpr uint32_t smem_O_row_scale_size = B_ROW;
|
||||||
// smem_cursor = reinterpret_cast<float *>(DEV_FAKE_SMEM_START_ADDR + SMEM_SIZE);
|
// FIXME: dangerous
|
||||||
smem_cursor = reinterpret_cast<float *>(0xff038000);
|
smem_cursor = reinterpret_cast<float *>(0xff038000);
|
||||||
|
|
||||||
float *smem_rowmax_0 = smem_cursor;
|
float *smem_rowmax_0 = smem_cursor;
|
||||||
@@ -599,8 +599,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// NOTE: out-of bounds is not checked
|
// NOTE: out-of bounds is not checked
|
||||||
// TODO: reduce this from B_ROW to NUM_WARPS
|
// TODO: reduce this from B_ROW to NUM_WARPS
|
||||||
constexpr uint32_t smem_scratchpad_size =
|
constexpr uint32_t smem_scratchpad_size =
|
||||||
B_ROW * NUM_THREADS * 2 /*arbitrary slack*/;
|
threads_per_warpgroup * 2 /*arbitrary slack*/;
|
||||||
// threads_per_warpgroup * 2 /*arbitrary slack*/;
|
|
||||||
float *smem_scratchpad_0 = smem_cursor;
|
float *smem_scratchpad_0 = smem_cursor;
|
||||||
smem_cursor += smem_scratchpad_size;
|
smem_cursor += smem_scratchpad_size;
|
||||||
float *smem_scratchpad_1 = smem_cursor;
|
float *smem_scratchpad_1 = smem_cursor;
|
||||||
@@ -1013,12 +1012,12 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
initialize_accum_regs<1>();
|
initialize_accum_regs<1>();
|
||||||
|
|
||||||
if constexpr (GEMMINI_DMA) {
|
if constexpr (GEMMINI_DMA) {
|
||||||
thread_block_gemm_single_tile<float, MemLayout::block_row_major,
|
thread_block_gemm_single_tile<
|
||||||
MemLayout::block_row_major, B_ROW,
|
float, MemLayout::K_major /* P matrix is row-major */,
|
||||||
HEADDIM, B_COL,
|
MemLayout::block_row_major, B_ROW, HEADDIM, B_COL,
|
||||||
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
/*leading_dim_a=*/0, /*leading_dim_b=*/0,
|
||||||
/*load_accum=*/true,
|
/*load_accum=*/true,
|
||||||
/*write_to_smem=*/true>(
|
/*write_to_smem=*/true>(
|
||||||
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
smem_P, smem_V, smem_O /*load accum*/, smem_O, tid_in_warpgroup,
|
||||||
threads_per_warpgroup, warpgroups_per_cluster,
|
threads_per_warpgroup, warpgroups_per_cluster,
|
||||||
warpgroup_id_in_cluster);
|
warpgroup_id_in_cluster);
|
||||||
@@ -1045,6 +1044,9 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
// warpgroups_per_cluster, warpgroup_id_in_cluster);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
static_assert(!WARP_SPECIALIZED || !GEMMINI_DMA,
|
||||||
|
"warp specialization unimplemented for dma");
|
||||||
|
|
||||||
// when warp-specialized, there's only enough warps to do 64x32 tile
|
// when warp-specialized, there's only enough warps to do 64x32 tile
|
||||||
// size so we need to do 2 GEMM calls
|
// size so we need to do 2 GEMM calls
|
||||||
static_assert(B_ROW / 2 == 32,
|
static_assert(B_ROW / 2 == 32,
|
||||||
|
|||||||
@@ -254,9 +254,6 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
static_assert(!GEMMINI_DMA || (layout == MemLayout::block_row_major) ||
|
|
||||||
GEMMINI_DMA_FLEXIBLE_LAYOUT,
|
|
||||||
"wrong memory layout selected for DMA");
|
|
||||||
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
static_assert((layout != MemLayout::K_major) || (FP_SIZE == 32),
|
||||||
"fp16 is not really tested for K-major A layout");
|
"fp16 is not really tested for K-major A layout");
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user