sgemm_tcore: Support data move for fp16-packed elements

Since the core does not support memory accesses to non-word-aligned
addresses, pack fp16 elements in pairs into fp32 values, and do regular
tile movement with the column dimensions halved when the element type is fp16.
Perf seems to stay the same for fp32 256x256.
This commit is contained in:
Hansung Kim
2024-07-30 18:07:34 -07:00
parent 7f26548724
commit 88cddc2b66
2 changed files with 155 additions and 132 deletions

View File

@@ -35,20 +35,27 @@
#define BK_LOOP 1
// Whether to transpose smem A tile at GMEM->SMEM (produce), or SMEM->RF
// (consume). This is because the tensor core expects the A tile to be stored
// in column-major order in SMEM, whereas it is stored row-major in GMEM.
// in column-major order in SMEM, whereas it will ultimately be stored in
// row-major order in the RF.
//
// For correctness, only one of either should be 1. To model the case where
// the A matrix is already stored transposed in GMEM ("TN" kernel), set
// both to 0.
//
// For reference, PRODUCE 1 CONSUME 0 generates the performant NN kernel.
// For correctness, at most one of the two should be 1. E.g., PRODUCE 1
// CONSUME 0 generates the NN kernel where both A and B are stored row-major
// in GMEM. To model the case where the A matrix is already stored transposed
// in GMEM ("TN" kernel), set both to 0.
#define TRANSPOSE_AT_PRODUCE 1
#define TRANSPOSE_AT_CONSUME 0
// GMEM_COALESCED sets bank conflict-free accesses for
// 1: GMEM loads of A matrix
// 0: SMEM stores of A matrix
// GMEM_COALESCED_A: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
// 1: GMEM loads of A matrix, or
// 0: SMEM stores of A matrix.
//
// Usually, GMEM_COALESCED_A == 1 yields better performance since the memory
// behavior of GMEM is more sensitive to bank conflicts.
#define GMEM_COALESCED_A 1
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
const int tg = tid / 4;
@@ -153,14 +160,23 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
const int tid = thread_in_warp;
const int tg = tid / 4;
// TODO: this is duplicately computed between vx_wmma_load_a and vx_wmma_load_b
// @perf: this is computed redundantly in both vx_wmma_load_a and vx_wmma_load_b
int row = 0;
int col = 0;
map_operand(tid, row, col);
// In fp16 mode, bit-pack two fp16 elements into each fp32 element, and do
// data movement at the fp32 granularity. Assuming that the matrix is stored
// row-major in GMEM, the packed fp16 pairs belong to the same row,
// neighboring columns; therefore, it essentially becomes equivalent to
// moving an fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
// by a factor of two.
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BK_adjusted = BK / packed_factor;
constexpr int smem_A_rows = BM;
constexpr int smem_A_cols = BK;
constexpr int smem_AS_rows = BK;
constexpr int smem_A_cols = BK_adjusted;
constexpr int smem_AS_rows = BK_adjusted;
constexpr int smem_AS_cols = BM;
if constexpr (TRANSPOSE_AT_CONSUME) {
@@ -170,11 +186,11 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
// f8-f15 stores a single row of A
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_A[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols + local_k]);
// NOTE: stride is fixed to the word size, i.e. sizeof(float) == 4,
// regardless of fp16 or fp32. Since the Vortex core does not support fp16,
// load things at word granularity and reinterpret the bits inside the tensor
// core.
&reinterpret_cast<const volatile float *>(
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
local_k]);
// step to the next column
// threads read from different rows; bank conflicts
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" ::"i"(1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" ::"i"(2 * sizeof(float)), "r"(smem_addr));
@@ -183,21 +199,17 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" ::"i"(5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f0, %0" ::"m"(smem_A[A_offset + (local_k + 0)]));
// asm volatile("flw f1, %0" ::"m"(smem_A[A_offset + (local_k + 1)]));
// asm volatile("flw f2, %0" ::"m"(smem_A[A_offset + (local_k + 2)]));
// asm volatile("flw f3, %0" ::"m"(smem_A[A_offset + (local_k + 3)]));
// asm volatile("flw f4, %0" ::"m"(smem_A[A_offset + (local_k + 4)]));
// asm volatile("flw f5, %0" ::"m"(smem_A[A_offset + (local_k + 5)]));
// asm volatile("flw f6, %0" ::"m"(smem_A[A_offset + (local_k + 6)]));
// asm volatile("flw f7, %0" ::"m"(smem_A[A_offset + (local_k + 7)]));
} else {
// read smem A tile as-is; bank-conflict-free AS load
// smem A tile is stored column-major
// f8-f15 stores a single row of A
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_A[((local_k + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
&reinterpret_cast<const volatile float *>(
smem_A)[((local_k + 0) * smem_AS_cols) +
(WM * warp_row + TCM * wm_iter) + row]);
// step to the next row
// threads read from different columns; no bank conflicts
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f2, %0(%1)" :: "i"(smem_AS_cols * 2 * sizeof(float)), "r"(smem_addr));
@@ -206,15 +218,6 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
asm volatile("flw f5, %0(%1)" :: "i"(smem_AS_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f0, %0" ::"m"(smem_A[((local_k + 0) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f1, %0" ::"m"(smem_A[((local_k + 1) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f2, %0" ::"m"(smem_A[((local_k + 2) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f3, %0" ::"m"(smem_A[((local_k + 3) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f4, %0" ::"m"(smem_A[((local_k + 4) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f5, %0" ::"m"(smem_A[((local_k + 5) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f6, %0" ::"m"(smem_A[((local_k + 6) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
// asm volatile("flw f7, %0" ::"m"(smem_A[((local_k + 7) * smem_AS_cols) + (WM * warp_row + TCM * wm_iter) + row]));
}
}
@@ -230,14 +233,21 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
int col = 0;
map_operand(tid, row, col);
// see comment in vx_wmma_load_a
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
constexpr uint32_t BN_adjusted = BN / packed_factor;
constexpr int smem_B_rows = BK;
constexpr int smem_B_cols = BN;
constexpr int smem_B_cols = BN_adjusted;
// f8-f15 stores a single column of B
const volatile uint8_t *smem_addr;
smem_addr = reinterpret_cast<const volatile uint8_t *>(
&smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) +
col]);
&reinterpret_cast<const volatile float *>(
smem_B)[((local_k + 0) * smem_B_cols) +
(WN * warp_col + TCN * wn_iter) + col]);
// step to the next row
// threads read from different columns; no bank conflicts
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f10, %0(%1)" :: "i"(smem_B_cols * 2 * sizeof(float)), "r"(smem_addr));
@@ -246,15 +256,6 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
// asm volatile("flw f8, %0" ::"m"(smem_B[((local_k + 0) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f9, %0" ::"m"(smem_B[((local_k + 1) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f10, %0" ::"m"(smem_B[((local_k + 2) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f11, %0" ::"m"(smem_B[((local_k + 3) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f12, %0" ::"m"(smem_B[((local_k + 4) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f13, %0" ::"m"(smem_B[((local_k + 5) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f14, %0" ::"m"(smem_B[((local_k + 6) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
// asm volatile("flw f15, %0" ::"m"(smem_B[((local_k + 7) * smem_B_cols) + (WN * warp_col + TCN * wn_iter) + col]));
}
inline void initialize_C(const int dest_reg) {
@@ -280,11 +281,10 @@ inline void initialize_C(const int dest_reg) {
}
}
template <typename T>
inline void write_results(const int thread_in_warp, const int warp_col,
const int warp_row, const int wn_iter,
const int wm_iter, const int dim_n,
T *C, const int threadblock_id_x,
float *C, const int threadblock_id_x,
const int threadblock_id_y) {
int tid = thread_in_warp;
@@ -296,14 +296,14 @@ inline void write_results(const int thread_in_warp, const int warp_col,
int local_row = (WM * warp_row + TCM * wm_iter) + tid_row;
int local_col = (WN * warp_col + TCN * wn_iter) + tid_col;
T *global_offset_C =
float *global_offset_C =
C + (BM * threadblock_id_y) * dim_n + BN * threadblock_id_x;
// @perf: this likely causes a lot of gmem bank conflicts
if (wm_iter == 0) {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f16, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f17, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f18, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));
@@ -323,7 +323,7 @@ inline void write_results(const int thread_in_warp, const int warp_col,
} else {
volatile uint8_t *gmem_addr = reinterpret_cast<volatile uint8_t *>(
&global_offset_C[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(T);
volatile uint8_t *gmem_addr_tmp = gmem_addr + (2 * dim_n) * sizeof(float);
asm volatile ("fsw f24, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f25, %0(%1)" :: "i"(1 * sizeof(float)), "r"(gmem_addr));
asm volatile ("fsw f26, %0(%1)" :: "i"(0 * sizeof(float)), "r"(gmem_addr_tmp));