sgemm_tcore: Separate transpose control on AS read/write

Make separate control flags on transposed AS read/write to make it easy
to model bank-conflict-free GMEM _and_ SMEM access.
This commit is contained in:
Hansung Kim
2024-06-11 21:16:23 -07:00
parent 34eaab4c87
commit 03d1df8f53
3 changed files with 56 additions and 12 deletions

View File

@@ -33,8 +33,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
// neighboring threads to ensure GMEM coalescing.
//
// TODO: Sharedmem swizzling is important here
if constexpr (!TRANSPOSE_AS) {
// FIXME: !TRANSPOSE_AS code is old
if constexpr (!TRANSPOSE_AT_PRODUCE) {
// if !TRANSPOSE_AT_PRODUCE, we only support coalesced GMEM loads
static_assert(TRANSPOSE_AT_PRODUCE || GMEM_COALESCED_A);
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
// number of rows a full TB can read at a time
@@ -42,26 +43,60 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
static_assert(
row_stride_a * 8 <= BM,
"manual loop unrolling condition not met; consider increasing BM");
static_assert(
(BM % (row_stride_a * 8)) == 0,
"manual loop unrolling condition not met; BM should be power-of-two");
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a) {
local_row_offset += row_stride_a * 8) {
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
// A[global_a_offset];
*local_a_tmp = *global_a;
//
// *local_a_tmp = *global_a;
// global_a += dim_k * row_stride_a;
// local_a_tmp += BK * row_stride_a;
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
local_a_tmp += BK * row_stride_a;
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
// stride along columns
// bank conflicts
asm volatile ("fsw ft0, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
local_a_tmp += BK * row_stride_a * 8;
}
} else {
if constexpr (!GMEM_COALESCED_A) {
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM;
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
// NOTE that GMEM reads are transposed
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
// FIXME experimenting with global coalescing
// const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
// const float *global_a = A + dim_k * global_a_row + (k + local_as_col);
volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col;
static_assert(
@@ -152,6 +187,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
global_a += dim_k * row_stride_a;
// stride along columns
// bank conflicts
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));