sgemm_impl: Remove GMEM_COALESCED_A option
Uncoalesced GMEM accesses is verified to yield slow performance and the relevant code is not used anymore; remove the cruft
This commit is contained in:
@@ -71,14 +71,6 @@ using float_type = float16_t;
|
||||
// set both to 0.
|
||||
#define TRANSPOSE_AT_PRODUCE 1
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
|
||||
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
|
||||
// 1: GMEM loads of A matrix, or
|
||||
// 0: SMEM stores of A matrix.
|
||||
//
|
||||
// Usually, GMEM_COALESCED==1 yields better performance since the memory
|
||||
// behavior of GMEM is more sensitive to bank conflicts.
|
||||
#define GMEM_COALESCED_A 1
|
||||
|
||||
#define GEMMINI_DMA 0
|
||||
#if SMEM_SIZE == 0x4000
|
||||
@@ -403,8 +395,7 @@ template <typename T,
|
||||
MemLayout gmem_layout, // memory layout of the GMEM tile
|
||||
MemLayout smem_layout, // memory layout of the GMEM tile
|
||||
uint32_t tile_dim_mn, // row dimension of the SMEM tile
|
||||
uint32_t tile_dim_k, // column dimension of the SMEM tile
|
||||
bool gmem_contiguous = true
|
||||
uint32_t tile_dim_k // column dimension of the SMEM tile
|
||||
>
|
||||
__attribute__((always_inline)) inline void
|
||||
global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
@@ -450,9 +441,6 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
// FIXME: don't hardcode this here
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
static_assert(gmem_contiguous == true,
|
||||
"currently only supports contiguous accesses in GMEM");
|
||||
|
||||
const uint32_t global_row_mn_major = k_ + local_row_gmem;
|
||||
const uint32_t global_col_mn_major = smem_dim_col * mn_index + local_col_gmem;
|
||||
const uint32_t global_row_k_major = gmem_dim_row * mn_index + local_row_gmem;
|
||||
@@ -505,12 +493,9 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
asm volatile("flw ft7, (%0)" ::"r"(global));
|
||||
global += dim_col_ * row_stride;
|
||||
|
||||
// do we need to do transposed write?
|
||||
// need to branch because address offset constant in the inline assembly
|
||||
// cannot be larger than a certain limit
|
||||
if constexpr (!transposed_write) {
|
||||
static_assert(gmem_layout == MemLayout::MN_major);
|
||||
|
||||
// if not, do the same along-the-column accesses for registers as we did
|
||||
// for gmem
|
||||
asm volatile("fsw ft0, %0(%1)" ::"i"(smem_dim_col * row_stride * 0 *
|
||||
sizeof(float)),
|
||||
"r"(local));
|
||||
@@ -540,11 +525,11 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
"r"(local));
|
||||
local += smem_dim_col * row_stride * 2;
|
||||
} else {
|
||||
// currently, tensor core hardware only supports MN-major SMEM tile
|
||||
// layout for correct results
|
||||
static_assert(gmem_layout == MemLayout::K_major);
|
||||
static_assert(smem_layout == MemLayout::MN_major);
|
||||
|
||||
// if yes, write the registers along the row, doing a transpose
|
||||
// @perf: this will incur bank conflicts in smem
|
||||
asm volatile("fsw ft0, %0(%1)" ::"i"(row_stride * 0 * sizeof(float)),
|
||||
"r"(local));
|
||||
asm volatile("fsw ft1, %0(%1)" ::"i"(row_stride * 1 * sizeof(float)),
|
||||
@@ -568,121 +553,6 @@ global_dmem_load_new(const uint32_t dim_col, const uint32_t mn_index,
|
||||
asm volatile("global_dmem_load_finish_new_%=:" ::);
|
||||
}
|
||||
|
||||
// TODO: reduce args by passing leading A/B dimensions
|
||||
template <typename T>
|
||||
__attribute__((always_inline))
|
||||
inline void global_dmem_load(const uint32_t dim_m, const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const T *A, const T *B,
|
||||
volatile T *local_a, volatile T *local_b,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y) {
|
||||
asm volatile ("global_dmem_load_start_%=:" :: );
|
||||
|
||||
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
|
||||
// data movement at the fp32 granularity. Assuming that the matrix is stored
|
||||
// row-major in GMEM, the packed fp16 pairs belong to the same row,
|
||||
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||
// by a factor of two.
|
||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr uint32_t BK_adjusted = BK / packed_factor;
|
||||
const uint32_t dim_k_adjusted = dim_k / packed_factor;
|
||||
const uint32_t k_adjusted = k / packed_factor;
|
||||
|
||||
const uint32_t local_a_row = tid_in_threadblock / BK_adjusted;
|
||||
const uint32_t local_a_col = tid_in_threadblock % BK_adjusted;
|
||||
const uint32_t local_as_row = tid_in_threadblock / BM;
|
||||
const uint32_t local_as_col = tid_in_threadblock % BM;
|
||||
const uint32_t local_b_row = tid_in_threadblock / BN;
|
||||
const uint32_t local_b_col = tid_in_threadblock % BN;
|
||||
|
||||
// FIXME: need fix for fp16?
|
||||
constexpr uint32_t threads_per_threadblock = (BM * BN) / ELEM_PER_THREAD;
|
||||
|
||||
// Data move from GMEM to SMEM
|
||||
//
|
||||
// Make sure global offset values for A and B are contiguous between
|
||||
// neighboring threads to ensure GMEM coalescing.
|
||||
//
|
||||
// TODO: Sharedmem swizzling is important here
|
||||
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM, BK>(
|
||||
dim_m, threadblock_id_y, k, A, local_a, tid_in_threadblock);
|
||||
} else {
|
||||
if constexpr (!GMEM_COALESCED_A) {
|
||||
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
|
||||
// GMEM, writes to neighboring cols in SMEM
|
||||
constexpr uint32_t row_stride_as = threads_per_threadblock / BM;
|
||||
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
||||
const float *global_a =
|
||||
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
|
||||
volatile float *local_a_tmp =
|
||||
reinterpret_cast<float *>(local_a) + BM * local_as_row + local_as_col;
|
||||
|
||||
static_assert(
|
||||
row_stride_as * 8 <= BK_adjusted,
|
||||
"manual loop unrolling condition not met; consider increasing BK");
|
||||
static_assert(
|
||||
(BK_adjusted % (row_stride_as * 8)) == 0,
|
||||
"manual loop unrolling condition not met; BK should be power-of-two");
|
||||
|
||||
#pragma GCC unroll 1
|
||||
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
|
||||
local_row_offset += row_stride_as * 8) {
|
||||
// const uint32_t global_a_offset =
|
||||
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
|
||||
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
|
||||
// A[global_a_offset];
|
||||
|
||||
// @perf: bank conflicts
|
||||
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||
global_a += row_stride_as;
|
||||
|
||||
// NOTE: stride is fixed to word size , i.e. sizeof(float) = 4,
|
||||
// regardless of fp16 or fp32. Since Vortex core does not support fp16,
|
||||
// load things at word granularity and reinterpret bits inside the
|
||||
// tensor core.
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BM * row_stride_as * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BM * row_stride_as * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BM * row_stride_as * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BM * row_stride_as * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BM * row_stride_as * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BM * row_stride_as * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BM * row_stride_as * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BM * row_stride_as * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += BM * row_stride_as * 8;
|
||||
}
|
||||
} else {
|
||||
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_k, threadblock_id_y, k, A, local_a,
|
||||
tid_in_threadblock);
|
||||
}
|
||||
} // end move A
|
||||
|
||||
// move B
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN, BK>(
|
||||
dim_n, threadblock_id_x, k, B, local_b, tid_in_threadblock);
|
||||
|
||||
asm volatile ("global_dmem_load_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// Do a single tile*tile matrix multiplication using the matrix data stored in
|
||||
// SMEM. Useful in fused kernels where GEMMs are done at a per-tile scope.
|
||||
template <typename T,
|
||||
@@ -933,8 +803,21 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
#endif
|
||||
}
|
||||
#else
|
||||
global_dmem_load<T>(dim_m, dim_n, dim_k, block_k * BK, A, B, local_a,
|
||||
local_b, tid_in_threadblock, block_n, block_m);
|
||||
// move A
|
||||
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_m, block_m, block_k * BK, A, local_a,
|
||||
tid_in_threadblock);
|
||||
} else {
|
||||
global_dmem_load_new<T, MemLayout::K_major, MemLayout::MN_major, BM,
|
||||
BK>(dim_k, block_m, block_k * BK, A, local_a,
|
||||
tid_in_threadblock);
|
||||
}
|
||||
|
||||
// move B
|
||||
global_dmem_load_new<T, MemLayout::MN_major, MemLayout::MN_major, BN,
|
||||
BK>(dim_n, block_n, block_k * BK, B, local_b,
|
||||
tid_in_threadblock);
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster,
|
||||
warps_per_threadblock_per_core);
|
||||
|
||||
Reference in New Issue
Block a user