sgemm_tcore: Fix fp16 addr gen in vx_wmma_load
This commit is contained in:
@@ -37,6 +37,9 @@
|
|||||||
#error "threadblock size too big for cluster"
|
#error "threadblock size too big for cluster"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// using float_type = float;
|
||||||
|
using float_type = float16_t;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||||
const uint32_t k, const T *A, const T *B,
|
const uint32_t k, const T *A, const T *B,
|
||||||
@@ -391,7 +394,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
}
|
}
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (uint32_t block_k = 0; (block_k * BK) < (dim_k); block_k++) {
|
for (uint32_t block_k = 0; (block_k * BK) < dim_k; block_k++) {
|
||||||
|
|
||||||
// producer code: GMEM->SMEM memory movement
|
// producer code: GMEM->SMEM memory movement
|
||||||
// ---------------------------------------------------------------------
|
// ---------------------------------------------------------------------
|
||||||
@@ -572,8 +575,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
|||||||
const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD);
|
const uint32_t problem_size = (dim_m * dim_n) / (ELEM_PER_THREAD);
|
||||||
const uint32_t num_threadblocks = problem_size / threads_per_threadblock;
|
const uint32_t num_threadblocks = problem_size / threads_per_threadblock;
|
||||||
|
|
||||||
using float_type = float16_t;
|
|
||||||
|
|
||||||
// "static" shared memory allocation. This would determine threadblock
|
// "static" shared memory allocation. This would determine threadblock
|
||||||
// occupancy of a single cluster
|
// occupancy of a single cluster
|
||||||
uint8_t *sharedmem_per_threadblock = reinterpret_cast<uint8_t *>(
|
uint8_t *sharedmem_per_threadblock = reinterpret_cast<uint8_t *>(
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
#define WN 8
|
#define WN 8
|
||||||
#define TCM 8
|
#define TCM 8
|
||||||
#define TCN 8
|
#define TCN 8
|
||||||
#define TCK 8
|
#define TCK 16
|
||||||
#define WMITER (WM / TCM)
|
#define WMITER (WM / TCM)
|
||||||
#define WNITER (WN / TCN)
|
#define WNITER (WN / TCN)
|
||||||
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
|
#define ELEM_PER_THREAD (WMITER * WNITER * (TCM * TCN) / NUM_THREADS)
|
||||||
@@ -40,9 +40,9 @@
|
|||||||
//
|
//
|
||||||
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
// For correctness, only one of either should be 1. E.g., PRODUCE 1 CONSUME 0
|
||||||
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
// generates the NN kernel where both A and B are stored row-major in GMEM.
|
||||||
// To model the case where the A matrix is already stored transposed in GMEM
|
// To model the case where the A matrix is already stored column-major in GMEM,
|
||||||
// ("TN" kernel), set both to 0.
|
// set both to 0.
|
||||||
#define TRANSPOSE_AT_PRODUCE 1
|
#define TRANSPOSE_AT_PRODUCE 0
|
||||||
#define TRANSPOSE_AT_CONSUME 0
|
#define TRANSPOSE_AT_CONSUME 0
|
||||||
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
|
// GMEM_COALESCED: When TRANSPOSE_AT_PRODUCE == 1 (i.e. transpose at
|
||||||
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
|
// GMEM->SMEM), determines whether we do bank-conflict-free accesses for
|
||||||
@@ -156,7 +156,8 @@ inline void vx_wmma(const int dest_reg) {
|
|||||||
// `local_k` is assumed to be multiple of TCK
|
// `local_k` is assumed to be multiple of TCK
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||||
const int warp_row, const int wm_iter, const int thread_in_warp) {
|
const int warp_row, const int wm_iter,
|
||||||
|
const int thread_in_warp) {
|
||||||
const int tid = thread_in_warp;
|
const int tid = thread_in_warp;
|
||||||
const int tg = tid / 4;
|
const int tg = tid / 4;
|
||||||
|
|
||||||
@@ -171,13 +172,17 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
// neighboring columns; therefore, it essentially becomes equivalent to
|
// neighboring columns; therefore, it essentially becomes equivalent to
|
||||||
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
// moving a fp32 matrix whose column dimensions (dim_k/BK/k) are compressed
|
||||||
// by a factor of two.
|
// by a factor of two.
|
||||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
constexpr uint32_t BK_adjusted = BK / packed_factor;
|
constexpr int BK_adjusted = BK / packed_factor;
|
||||||
|
constexpr int BM_adjusted = BM / packed_factor;
|
||||||
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
constexpr int smem_A_rows = BM;
|
constexpr int smem_A_rows = BM;
|
||||||
constexpr int smem_A_cols = BK_adjusted;
|
constexpr int smem_A_cols = BK_adjusted;
|
||||||
constexpr int smem_AS_rows = BK_adjusted;
|
constexpr int smem_AS_rows = BK_adjusted;
|
||||||
constexpr int smem_AS_cols = BM;
|
constexpr int smem_AS_cols = BM;
|
||||||
|
// constexpr int smem_AS_rows = BK;
|
||||||
|
// constexpr int smem_AS_cols = BM_adjusted;
|
||||||
|
|
||||||
if constexpr (TRANSPOSE_AT_CONSUME) {
|
if constexpr (TRANSPOSE_AT_CONSUME) {
|
||||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||||
@@ -188,7 +193,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
&reinterpret_cast<const volatile float *>(
|
&reinterpret_cast<const volatile float *>(
|
||||||
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
|
smem_A)[(WM * warp_row + TCM * wm_iter + row) * smem_A_cols +
|
||||||
local_k]);
|
local_k /* FIXME: adjust for fp16? */]);
|
||||||
// step to the next column
|
// step to the next column
|
||||||
// threads read from different rows; bank conflicts
|
// threads read from different rows; bank conflicts
|
||||||
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
asm volatile("flw f0, %0(%1)" ::"i"(0 * sizeof(float)), "r"(smem_addr));
|
||||||
@@ -206,7 +211,7 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
|||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
&reinterpret_cast<const volatile float *>(
|
&reinterpret_cast<const volatile float *>(
|
||||||
smem_A)[((local_k + 0) * smem_AS_cols) +
|
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
|
||||||
(WM * warp_row + TCM * wm_iter) + row]);
|
(WM * warp_row + TCM * wm_iter) + row]);
|
||||||
// step to the next row
|
// step to the next row
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
@@ -234,17 +239,21 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
|||||||
map_operand(tid, row, col);
|
map_operand(tid, row, col);
|
||||||
|
|
||||||
// see comment in vx_wmma_load_a
|
// see comment in vx_wmma_load_a
|
||||||
constexpr uint32_t packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||||
constexpr uint32_t BN_adjusted = BN / packed_factor;
|
constexpr int BK_adjusted = BN / packed_factor;
|
||||||
|
constexpr int BN_adjusted = BN / packed_factor;
|
||||||
|
const int local_k_adjusted = local_k / packed_factor;
|
||||||
|
|
||||||
constexpr int smem_B_rows = BK;
|
// constexpr int smem_B_rows = BK;
|
||||||
constexpr int smem_B_cols = BN_adjusted;
|
// constexpr int smem_B_cols = BN_adjusted;
|
||||||
|
constexpr int smem_B_rows = BK_adjusted;
|
||||||
|
constexpr int smem_B_cols = BN;
|
||||||
|
|
||||||
// f8-f15 stores a single column of B
|
// f8-f15 stores a single column of B
|
||||||
const volatile uint8_t *smem_addr;
|
const volatile uint8_t *smem_addr;
|
||||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||||
&reinterpret_cast<const volatile float *>(
|
&reinterpret_cast<const volatile float *>(
|
||||||
smem_B)[((local_k + 0) * smem_B_cols) +
|
smem_B)[((local_k_adjusted + 0) * smem_B_cols) +
|
||||||
(WN * warp_col + TCN * wn_iter) + col]);
|
(WN * warp_col + TCN * wn_iter) + col]);
|
||||||
// step to the next row
|
// step to the next row
|
||||||
// threads read from different columns; no bank conflicts
|
// threads read from different columns; no bank conflicts
|
||||||
|
|||||||
Reference in New Issue
Block a user