sgemm_tcore: Support data move for fp16-packed elements

Since the core does not support memory accesses to non-word-aligned
addresses, pack fp16 elements in pairs into fp32 values, and do the regular
tile movement with the column dimensions halved when the element type is fp16.
Performance seems to stay the same for fp32 256x256.
This commit is contained in:
Hansung Kim
2024-07-30 18:07:34 -07:00
parent 7f26548724
commit 88cddc2b66
2 changed files with 155 additions and 132 deletions

View File

@@ -37,10 +37,6 @@
#error "threadblock size too big for cluster"
#endif
// "fake" fp16 type that only has the correct word size. Proper conversion to
// fp32 needs to be done in a custom function.
using float16_t = uint16_t;
template <typename T>
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const uint32_t k, const T *A, const T *B,
@@ -48,13 +44,27 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
const uint32_t tid_in_threadblock,
const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y) {
const uint32_t local_a_row = tid_in_threadblock / BK;
const uint32_t local_a_col = tid_in_threadblock % BK;
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving an fp32 matrix whose column dimensions (dim_k/BK/k for A, dim_n/BN
// for B) are compressed by a factor of two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BK_adjusted = BK / packed_factor;
const uint32_t dim_k_adjusted = dim_k / packed_factor;
constexpr uint32_t BN_adjusted = BN / packed_factor;
const uint32_t dim_n_adjusted = dim_n / packed_factor;
const uint32_t k_adjusted = k / packed_factor;
const uint32_t local_a_row = tid_in_threadblock / BK_adjusted;
const uint32_t local_a_col = tid_in_threadblock % BK_adjusted;
const uint32_t local_as_row = tid_in_threadblock / BM;
const uint32_t local_as_col = tid_in_threadblock % BM;
const uint32_t local_b_row = tid_in_threadblock / BN;
const uint32_t local_b_col = tid_in_threadblock % BN;
const uint32_t local_b_row = tid_in_threadblock / BN_adjusted;
const uint32_t local_b_col = tid_in_threadblock % BN_adjusted;
// FIXME: need fix for fp16?
constexpr uint32_t threads_in_threadblock = (BM * BN) / ELEM_PER_THREAD;
// Data move from GMEM to SMEM
@@ -63,53 +73,59 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
// neighboring threads to ensure GMEM coalescing.
//
// TODO: Sharedmem swizzling is important here
// move A
if constexpr (!TRANSPOSE_AT_PRODUCE) {
// No transpose at GMEM->SMEM movement
// FIXME: !TRANSPOSE_AS code is old
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
// number of rows a full TB can read at a time
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
const T *global_a = A + dim_k * global_a_row + (k + local_a_col);
volatile T *local_a_tmp = local_a + BK * local_a_row + local_a_col;
// this is equivalent to threadblock_dim_y (assuming threadblock_dim_x ==
// BK)
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted;
const float *global_a = reinterpret_cast<float *>(A) +
dim_k_adjusted * global_a_row +
(k_adjusted + local_a_col);
volatile float *local_a_tmp = reinterpret_cast<float *>(local_a) +
BK_adjusted * local_a_row + local_a_col;
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a) {
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
// local_a[BK * (local_a_row + local_row_offset) + local_a_col] =
// A[global_a_offset];
*local_a_tmp = *global_a;
global_a += dim_k * row_stride_a;
local_a_tmp += BK * row_stride_a;
// move to the next "row-chunk", when threadblock is smaller than BM*BK
global_a += dim_k_adjusted * row_stride_a;
local_a_tmp += BK_adjusted * row_stride_a;
}
} else {
if constexpr (!GMEM_COALESCED_A) {
// !GMEM_COALESCED_A: threads do uncoalesced read from neighboring row in
// GMEM, writes to neighboring cols in SMEM
constexpr uint32_t row_stride_as = threads_in_threadblock / BM;
const uint32_t global_a_row = BM * threadblock_id_y + local_as_col;
const T *global_a = A + dim_k * global_a_row + (k + local_as_row);
// FIXME experimenting with global coalescing
// const uint32_t global_a_row = BM * threadblock_id_y + local_as_row;
// const T *global_a = A + dim_k * global_a_row + (k + local_as_col);
volatile T *local_a_tmp = local_a + BM * local_as_row + local_as_col;
const float *global_a =
reinterpret_cast<float *>(A) + dim_k_adjusted * global_a_row + (k_adjusted + local_as_row);
volatile float *local_a_tmp =
reinterpret_cast<float *>(local_a) + BM * local_as_row + local_as_col;
static_assert(
row_stride_as * 8 <= BK,
row_stride_as * 8 <= BK_adjusted,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK % (row_stride_as * 8)) == 0,
(BK_adjusted % (row_stride_as * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 1
for (uint32_t local_row_offset = 0; local_row_offset < BK;
for (uint32_t local_row_offset = 0; local_row_offset < BK_adjusted;
local_row_offset += row_stride_as * 8) {
// @perf: bank conflicts here
// const uint32_t global_a_offset =
// dim_k * (global_a_row) + (k + local_as_row + local_row_offset);
// dim_k_adjusted * (global_a_row) + (k + local_as_row + local_row_offset);
// FIXME experimenting with global coalescing
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_as_col);
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_as_col);
// local_a[BM * (local_as_row + local_row_offset) + local_as_col] =
// A[global_a_offset];
@@ -146,11 +162,15 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
local_a_tmp += BM * row_stride_as * 8;
}
} else {
constexpr uint32_t row_stride_a = threads_in_threadblock / BK;
constexpr uint32_t row_stride_a = threads_in_threadblock / BK_adjusted;
const uint32_t global_a_row = BM * threadblock_id_y + local_a_row;
const T *global_a = A + dim_k * global_a_row + (k + local_a_col);
const float *global_a = reinterpret_cast<const float *>(A) +
dim_k_adjusted * global_a_row +
(k_adjusted + local_a_col);
// NOTE that SMEM writes are transposed
volatile T *local_a_tmp = local_a + BM * local_a_col + local_a_row;
volatile float *local_a_tmp =
reinterpret_cast<volatile float *>(local_a) + BM * local_a_col +
local_a_row;
static_assert(
row_stride_a * 8 <= BM,
@@ -163,27 +183,27 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
for (uint32_t local_row_offset = 0; local_row_offset < BM;
local_row_offset += row_stride_a * 8) {
// const uint32_t global_a_offset =
// dim_k * (global_a_row + local_row_offset) + (k + local_a_col);
// dim_k_adjusted * (global_a_row + local_row_offset) + (k + local_a_col);
// NOTE that SMEM writes are transposed
// local_a[BM * (local_a_col) + local_a_row + local_row_offset] =
// A[global_a_offset];
asm volatile ("flw ft0, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft1, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft2, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft3, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft4, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft5, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft6, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
asm volatile ("flw ft7, (%0)" :: "r"(global_a));
global_a += dim_k * row_stride_a;
global_a += dim_k_adjusted * row_stride_a;
// stride along columns
asm volatile ("fsw ft0, %0(%1)" :: "i"(row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
@@ -197,62 +217,63 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
local_a_tmp += row_stride_a * 8;
}
}
}
} // end move A
constexpr uint32_t row_stride_b = threads_in_threadblock / BN;
const uint32_t global_b_col = BN * threadblock_id_x + local_b_col;
const T *global_b = B + dim_n * (k + local_b_row) + global_b_col;
volatile T *local_b_tmp = local_b + BN * local_b_row + local_b_col;
// move B
constexpr uint32_t row_stride_b = threads_in_threadblock / BN_adjusted;
const uint32_t global_b_col = BN_adjusted * threadblock_id_x + local_b_col;
// NOTE: not k_adjusted here; k is along the row dimension which is not
// compressed for fp16
const float *global_b = reinterpret_cast<const float *>(B) +
dim_n_adjusted * (k + local_b_row) + global_b_col;
volatile float *local_b_tmp = reinterpret_cast<volatile float *>(local_b) +
BN_adjusted * local_b_row + local_b_col;
static_assert(
row_stride_b * 8 <= BK,
row_stride_b * 8 <= BK_adjusted,
"manual loop unrolling condition not met; consider increasing BK");
static_assert(
(BK % (row_stride_b * 8)) == 0,
(BK_adjusted % (row_stride_b * 8)) == 0,
"manual loop unrolling condition not met; BK should be power-of-two");
#pragma GCC unroll 1
for (uint32_t load_offset = 0; load_offset < BK;
load_offset += row_stride_b * 8) {
// const uint32_t global_b_offset =
// dim_n * (k + local_b_row + load_offset) + global_b_col;
// local_b[BN * (local_b_row + load_offset) + local_b_col] =
// B[global_b_offset];
// equivalent code:
//
// *local_b_tmp = *global_b;
// global_b += dim_n * row_stride_b;
// local_b_tmp += BN * row_stride_b;
asm volatile ("flw ft0, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft1, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft2, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft3, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft4, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft5, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft6, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("flw ft7, (%0)" :: "r"(global_b));
global_b += dim_n * row_stride_b;
global_b += dim_n_adjusted * row_stride_b;
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN * row_stride_b * 2;
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN_adjusted * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
local_b_tmp += BN_adjusted * row_stride_b * 2;
}
}
@@ -440,8 +461,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#endif
}
#else
global_dmem_load(dim_n, dim_k, block_k * BK, A, B, local_a, local_b,
tid_in_threadblock, block_n, block_m);
global_dmem_load<T>(dim_n, dim_k, block_k * BK, A, B, local_a, local_b,
tid_in_threadblock, block_n, block_m);
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
#endif
@@ -466,6 +487,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
local_a_consume = local_a + (block_k & 1) * (local_a_elems);
local_b_consume = local_b + (block_k & 1) * (local_b_elems);
} else {
// no double-buffering without DMA
local_a_consume = local_a;
local_b_consume = local_b;
}
@@ -477,12 +499,13 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
// SMEM -> RF
vx_wmma_load_b(local_b_consume, local_k, warp_col, wn_iter, tid_in_warp);
vx_wmma_load_b<T>(local_b_consume, local_k, warp_col, wn_iter,
tid_in_warp);
#pragma GCC unroll 2
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
// SMEM -> RF
vx_wmma_load_a(local_a_consume, local_k, warp_row, wm_iter,
tid_in_warp);
vx_wmma_load_a<T>(local_a_consume, local_k, warp_row, wm_iter,
tid_in_warp);
// perform mma
vx_wmma(wm_iter);
}
@@ -506,8 +529,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
#pragma GCC unroll 2
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
write_results<float>(tid_in_warp, warp_col, warp_row, wn_iter,
wm_iter, dim_n, C, block_n, block_m);
write_results(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
dim_n, C, block_n, block_m);
}
}
}

View File

@@ -35,20 +35,27 @@
#define BK_LOOP 1
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
// (consume). This is because the tensor core expects the A tile to be stored
// in column-major order in SMEM, whereas it is stored row-major in GMEM.
// in column-major order in SMEM, whereas it will ultimately be stored in
// row-major order in the RF.
//
// For correctness, only one of either should be 1. To model the case where
// the A matrix is already stored transposed in GMEM ("TN" kernel), set
// both to 0.
//
// For reference, PRODUCE 1 CONSUME 0 generates the performant NN kernel.
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
// generates the NN kernel where both A and B are stored row-major in GMEM.
// To model the case where the A matrix is already stored transposed in GMEM
// ("TN" kernel), set both to 0.
#define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_CONSUME 0
// GMEM_COALESCED sets bank conflict-free accesses for
// 1: GMEM loads of A matrix
// 0: SMEM stores of A matrix
// GMEM_COALESCED_A: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
// 1: GMEM loads of A matrix, or
// 0: SMEM stores of A matrix.
//
// Usually, GMEM_COALESCED_A == 1 yields better performance since the memory
// behavior of GMEM is more sensitive to bank conflicts.
#define GMEM_COALESCED_A 1
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
@@ -153,14 +160,23 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
const int tid = thread_in_warp;
const int tg = tid / 4;
// TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b
// @perf: this is duplicately computed in vx_wmma_load_a and vx_wmma_load_b
int row = 0;
int col = 0;
map_operand(tid, row, col);
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving an fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
// by a factor of two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BK_adjusted = BK / packed_factor;
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK;
constexpr int smem_AS_rows = BK;
constexpr int smem_A_cols = BK_adjusted;
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
if constexpr (TRANSPOSE_AT_CONSUME) {
@@ -170,11 +186,11 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
// f0-f7 store a single row of A
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k]);
// NOTE: stride is fixed to word size, i.e. sizeof(float) = 4,
// regardless of fp16 or fp32. Since the Vortex core does not support fp16,
// load things at word granularity and reinterpret bits inside the tensor
// core.
&reinterpret_cast<const volatile float *>(
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
local_k]);
// step to the next column
// threads read from different rows; bank conflicts
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
@@ -183,21 +199,17 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
// asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
// asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
// asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)]));
// asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)]));
// asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)]));
// asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
// asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
} else {
// read smem A tile as-is; bank-conflict-free AS load
// smem A tile is stored column-major
// f0-f7 store a single row of A
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_A[((local_k + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
&reinterpret_cast<const volatile float *>(
smem_A)[((local_k + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
// step to the next row
// threads read from different columns; no bank conflicts
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
@@ -206,15 +218,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
}
}
@@ -230,14 +233,21 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
int col = 0;
map_operand(tid, row, col);
// see comment in vx_wmma_load_a
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BN_adjusted = BN / packed_factor;
constexpr int smem_B_rows = BK;
constexpr int smem_B_cols = BN;
constexpr int smem_B_cols = BN_adjusted;
// f8-f15 stores a single column of B
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) +
col]);
&reinterpret_cast<const volatile float *>(
smem_B)[((local_k + 0) * smem_B_cols) +
(WN * warp_col + TCN * wn_iter) + col]);
// step to the next row
// threads read from different columns; no bank conflicts
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr));
@@ -246,15 +256,6 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
}
inline void initialize_C(const int dest_reg) {
@@ -280,11 +281,10 @@ inline void initialize_C(const int dest_reg) {
}
}
template <typename T>
inline void write_results(const int thread_in_warp, const int warp_col,
const int warp_row, const int wn_iter,
const int wm_iter, const int dim_n,
T *C, const int threadblock_id_x,
float *C, const int threadblock_id_x,
const int threadblock_id_y) {
int tid = thread_in_warp;
@@ -296,14 +296,14 @@ inline void write_results(const int thread_in_warp, const int warp_col,
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
T *global_offset_C =
float *global_offset_C =
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
// @perf: this likely causes a lot of gmem bank conflicts
if (wm_iter == 0) {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
@@ -323,7 +323,7 @@ inline void write_results(const int thread_in_warp, const int warp_col,
} else {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));