sgemm_tcore: Separate transpose control on AS read/write
Make separate control flags on transposed AS read/write to make it easy to model bank-conflict-free GMEM _and_ SMEM access.
This commit is contained in:
@@ -28,7 +28,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
// neighboring threads to ensure GMEM coalescing.
|
// neighboring threads to ensure GMEM coalescing.
|
||||||
//
|
//
|
||||||
// TODO: Sharedmem swizzling is important here
|
// TODO: Sharedmem swizzling is important here
|
||||||
if constexpr (!TRANSPOSE_AS) {
|
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||||
// FIXME: !TRANSPOSE_AS code is old
|
// FIXME: !TRANSPOSE_AS code is old
|
||||||
|
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||||
|
|||||||
@@ -33,8 +33,9 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
// neighboring threads to ensure GMEM coalescing.
|
// neighboring threads to ensure GMEM coalescing.
|
||||||
//
|
//
|
||||||
// TODO: Sharedmem swizzling is important here
|
// TODO: Sharedmem swizzling is important here
|
||||||
if constexpr (!TRANSPOSE_AS) {
|
if constexpr (!TRANSPOSE_AT_PRODUCE) {
|
||||||
// FIXME: !TRANSPOSE_AS code is old
|
// if !TRANSPOSE_AT_PRODUCE, we only support coalesced GMEM loads
|
||||||
|
static_assert(TRANSPOSE_AT_PRODUCE || GMEM_COALESCED_A);
|
||||||
|
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
|
||||||
// number of rows a full TB can read at a time
|
// number of rows a full TB can read at a time
|
||||||
@@ -42,26 +43,60 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
const float *global_a = A + dim_k * global_a_row + (k + local_a_col);
|
||||||
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
|
volatile float *local_a_tmp = local_a + BK * local_a_row + local_a_col;
|
||||||
|
|
||||||
|
static_assert(
|
||||||
|
row_stride_a * 8 <= BM,
|
||||||
|
"manual loop unrolling condition not met; consider increasing BM");
|
||||||
|
static_assert(
|
||||||
|
(BM % (row_stride_a * 8)) == 0,
|
||||||
|
"manual loop unrolling condition not met; BM should be power-of-two");
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
for (uint32_t local_row_offset = 0; local_row_offset < BM;
|
||||||
local_row_offset += row_stride_a) {
|
local_row_offset += row_stride_a * 8) {
|
||||||
// const uint32_t global_a_offset =
|
// const uint32_t global_a_offset =
|
||||||
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
|
||||||
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
|
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
|
||||||
// A[global_a_offset];
|
// A[global_a_offset];
|
||||||
*local_a_tmp = *global_a;
|
//
|
||||||
|
// *local_a_tmp = *global_a;
|
||||||
|
// global_a += dim_k * row_stride_a;
|
||||||
|
// local_a_tmp += BK * row_stride_a;
|
||||||
|
|
||||||
|
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
|
||||||
global_a += dim_k * row_stride_a;
|
global_a += dim_k * row_stride_a;
|
||||||
local_a_tmp += BK * row_stride_a;
|
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
|
||||||
|
global_a += dim_k * row_stride_a;
|
||||||
|
|
||||||
|
// stride along columns
|
||||||
|
// bank conflicts
|
||||||
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
local_a_tmp += BK * row_stride_a * 8;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if constexpr (!GMEM_COALESCED_A) {
|
if constexpr (!GMEM_COALESCED_A) {
|
||||||
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM;
|
constexpr uint32_t row_stride_as = threads_in_warpgroup / BM;
|
||||||
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
|
||||||
|
// NOTE that GMEM reads are transposed
|
||||||
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
|
const float *global_a = A + dim_k * global_a_row + (k + local_as_row);
|
||||||
// FIXME experimenting with global coalescing
|
|
||||||
// const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
|
|
||||||
// const float *global_a = A + dim_k * global_a_row + (k + local_as_col);
|
|
||||||
volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col;
|
volatile float *local_a_tmp = local_a + BM * local_as_row + local_as_col;
|
||||||
|
|
||||||
static_assert(
|
static_assert(
|
||||||
@@ -152,6 +187,7 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
|||||||
global_a += dim_k * row_stride_a;
|
global_a += dim_k * row_stride_a;
|
||||||
|
|
||||||
// stride along columns
|
// stride along columns
|
||||||
|
// bank conflicts
|
||||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
asm volatile ("fsw ft1, %0(%1)" :: "i"(row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
asm volatile ("fsw ft2, %0(%1)" :: "i"(row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||||
|
|||||||
@@ -35,7 +35,15 @@
|
|||||||
// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM
|
// number of loop around the inner 0..TCK..BK loop to simulate perfect-DRAM
|
||||||
// scenario
|
// scenario
|
||||||
#define BK_LOOP 1
|
#define BK_LOOP 1
|
||||||
#define TRANSPOSE_AS 1
|
// whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
|
||||||
|
// (consume). This is because the tensor core expects the A tile to be stored
|
||||||
|
// in column-major order in SMEM.
|
||||||
|
//
|
||||||
|
// For correctness, only one of either should be 1. To model the case where
|
||||||
|
// the entire A matrix is already stored transposed in GMEM ("TN" kernel), set
|
||||||
|
// both to 0.
|
||||||
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
// GMEM_COALESCED sets bank conflict-free accesses for
|
// GMEM_COALESCED sets bank conflict-free accesses for
|
||||||
// 1: GMEM loads of A matrix
|
// 1: GMEM loads of A matrix
|
||||||
// 0: SMEM stores of A matrix
|
// 0: SMEM stores of A matrix
|
||||||
@@ -171,7 +179,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
|||||||
constexpr int smem_AS_rows = BK;
|
constexpr int smem_AS_rows = BK;
|
||||||
constexpr int smem_AS_cols = BM;
|
constexpr int smem_AS_cols = BM;
|
||||||
|
|
||||||
if constexpr (!TRANSPOSE_AS) {
|
if constexpr (TRANSPOSE_AT_CONSUME) {
|
||||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||||
|
|
||||||
// @perf: bank conflicts
|
// @perf: bank conflicts
|
||||||
@@ -195,7 +203,7 @@ inline void vx_wmma_load_a(volatile float *smem_A, const int local_k,
|
|||||||
// asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
|
// asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
|
||||||
// asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
|
// asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
|
||||||
} else {
|
} else {
|
||||||
// transposed A
|
// read smem A tile as-is; bank-conflict-free AS load
|
||||||
// f8-f15 stores a single row of A
|
// f8-f15 stores a single row of A
|
||||||
volatile float *smem_addr;
|
volatile float *smem_addr;
|
||||||
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
smem_addr = &smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row];
|
||||||
|
|||||||
Reference in New Issue
Block a user