sgemm_tcore: Fix correctness for GEMMINI_DMA
Remap the logical SMEM row/col coordinates to the DMA's two-level block-row-major layout.
This commit is contained in:
@@ -84,11 +84,15 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
// "static" shared memory allocation. This would determine threadblock
|
// "static" shared memory allocation. This would determine threadblock
|
||||||
// occupancy of a single cluster
|
// occupancy of a single cluster
|
||||||
uint8_t *sharedmem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *sharedmem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
DEV_SMEM_START_ADDR + sizeof(float_type) * 2 /*overkill for non-dma*/ *
|
DEV_SMEM_START_ADDR +
|
||||||
(2 * BM * BK) * threadblock_id_in_cluster);
|
sizeof(float_type) * 2 * (2 * BM * BK) * threadblock_id_in_cluster);
|
||||||
|
|
||||||
thread_block_gemm<float_type, threads_per_threadblock,
|
thread_block_gemm<float_type, threads_per_threadblock,
|
||||||
/*write_to_gmem=*/true>(
|
/*write_to_gmem=*/true,
|
||||||
|
/*smem_a_offset=*/0,
|
||||||
|
/*smem_a_dbuf_offset=*/0,
|
||||||
|
/*smem_b_offset=*/2 * BM * BK * sizeof(float),
|
||||||
|
/*smem_b_dbuf_offset=*/2 * BM * BK * sizeof(float)>(
|
||||||
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
(const float_type *)arg->addr_a, (const float_type *)arg->addr_b,
|
||||||
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
(float *)arg->addr_c, arg->dim_m, arg->dim_n, arg->dim_k,
|
||||||
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
|
tid_in_threadblock, threadblocks_per_cluster, threadblock_id_in_cluster,
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ using float_type = float16_t;
|
|||||||
// To model the case where the A matrix is already stored column-major in GMEM,
|
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||||
// set both to 0.
|
// set both to 0.
|
||||||
#define TRANSPOSE_AT_PRODUCE 0
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 1
|
||||||
|
|
||||||
#define GEMMINI_DMA 1
|
#define GEMMINI_DMA 1
|
||||||
#if SMEM_SIZE == 0x4000
|
#if SMEM_SIZE == 0x4000
|
||||||
@@ -230,19 +230,42 @@ inline void wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
|
static_assert(!GEMMINI_DMA || (layout == MemLayout::K_major),
|
||||||
|
"GEMMINI_DMA only supported for K-major A tile");
|
||||||
|
|
||||||
if constexpr (layout == MemLayout::K_major) {
|
if constexpr (layout == MemLayout::K_major) {
|
||||||
constexpr int smem_A_cols = leading_dim;
|
constexpr int smem_A_cols = leading_dim;
|
||||||
|
|
||||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
|
||||||
|
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
|
const uint32_t smem_logical_row = WM * warp_row + TCM * wm_iter + row;
|
||||||
|
const uint32_t smem_logical_col = local_k + 0; /* FIXME: adjust for fp16? */
|
||||||
|
uint32_t smem_row;
|
||||||
|
uint32_t smem_col;
|
||||||
|
if constexpr (GEMMINI_DMA) {
|
||||||
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
|
// block-row-major layout
|
||||||
|
static_assert(
|
||||||
|
DIM == 8,
|
||||||
|
"GEMMINI_DMA layout remapping code only written for DIM == 8");
|
||||||
|
constexpr int dim_blocks_in_row = (smem_A_cols / DIM);
|
||||||
|
smem_row = (smem_logical_row / dim_blocks_in_row) * DIM +
|
||||||
|
(smem_logical_col / DIM);
|
||||||
|
smem_col = (smem_logical_row % dim_blocks_in_row) * DIM +
|
||||||
|
(smem_logical_col % DIM);
|
||||||
|
} else {
|
||||||
|
smem_row = smem_logical_row;
|
||||||
|
smem_col = smem_logical_col;
|
||||||
|
}
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
&reinterpret_cast<const volatile float *>(
|
&reinterpret_cast<const volatile float *>(
|
||||||
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
|
smem_A)[smem_A_cols * smem_row + smem_col]);
|
||||||
local_k /* FIXME: adjust for fp16? */]);
|
|
||||||
// step to the next column
|
// step to the next column
|
||||||
// @perf: bank conflicts; threads read from different rows
|
// @perf: bank conflicts; threads read from different rows
|
||||||
|
// below is correct for GEMMINI_DMA; smem_col is always a multiple of 8,
|
||||||
|
// and the next 7 elements in the row are guaranteed to be consecutive in
|
||||||
|
// memory
|
||||||
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
|
||||||
@@ -325,24 +348,53 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
const int local_k_adjusted = local_k / packed_factor;
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
// B is stored N-major in smem
|
// B is stored N-major in smem
|
||||||
constexpr int smem_B_rows = tile_dim_k_adjusted;
|
|
||||||
constexpr int smem_B_cols = tile_dim_n;
|
constexpr int smem_B_cols = tile_dim_n;
|
||||||
|
|
||||||
|
const uint32_t smem_logical_row = local_k_adjusted + 0;
|
||||||
|
const uint32_t smem_logical_col = (WN * warp_col + TCN * wn_iter) + col;
|
||||||
|
uint32_t smem_row;
|
||||||
|
uint32_t smem_col;
|
||||||
|
if constexpr (GEMMINI_DMA) {
|
||||||
|
// if using Gemmini DMA, remap logical row/col to Gemmini's 2-level
|
||||||
|
// block-row-major layout
|
||||||
|
constexpr int dim_blocks_in_row = (smem_B_cols / DIM);
|
||||||
|
smem_row =
|
||||||
|
(smem_logical_row / dim_blocks_in_row) * DIM + (smem_logical_col / DIM);
|
||||||
|
smem_col =
|
||||||
|
(smem_logical_row % dim_blocks_in_row) * DIM + (smem_logical_col % DIM);
|
||||||
|
} else {
|
||||||
|
smem_row = smem_logical_row;
|
||||||
|
smem_col = smem_logical_col;
|
||||||
|
}
|
||||||
|
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
&reinterpret_cast<const volatile float *>(
|
&reinterpret_cast<const volatile float *>(
|
||||||
smem_B)[((local_k_adjusted + 0) * smem_B_cols) +
|
smem_B)[smem_B_cols * smem_row + smem_col]);
|
||||||
(WN * warp_col + TCN * wn_iter) + col]);
|
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
if constexpr (GEMMINI_DMA) {
|
||||||
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
// for GEMMINI_DMA, moving rows for the next 7 elements in the same column
|
||||||
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr));
|
// is the same as moving DIM elements forward in memory because of the
|
||||||
asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr));
|
// block-row-major layout
|
||||||
asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f8, %0(%1)" :: "i"(DIM * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f9, %0(%1)" :: "i"(DIM * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f10, %0(%1)" :: "i"(DIM * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f11, %0(%1)" :: "i"(DIM * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f12, %0(%1)" :: "i"(DIM * 4 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f13, %0(%1)" :: "i"(DIM * 5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f14, %0(%1)" :: "i"(DIM * 6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f15, %0(%1)" :: "i"(DIM * 7 * sizeof(float)), "r"(smem_addr));
|
||||||
|
} else {
|
||||||
|
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f11, %0(%1)" :: "i"(smem_B_cols * 3 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f12, %0(%1)" :: "i"(smem_B_cols * 4 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||||
|
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||||
|
}
|
||||||
|
|
||||||
asm volatile ("wmma_load_b_finish_%=:" :: );
|
asm volatile ("wmma_load_b_finish_%=:" :: );
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user