sgemm_tcore: Fix addr gen for GMEM->SMEM for M-major A
This fixes correctness for TRANSPOSE_AT_PRODUCE/COLUMN=0/0, provided the matrices are already stored in the correct layout in GMEM.
This commit is contained in:
@@ -158,6 +158,8 @@ template <typename T>
|
||||
inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
const int warp_row, const int wm_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("vx_wmma_load_a_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -174,17 +176,13 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
// by a factor of two.
|
||||
constexpr int packed_factor = (std::is_same_v<T, float16_t> ? 2 : 1);
|
||||
constexpr int BK_adjusted = BK / packed_factor;
|
||||
constexpr int BM_adjusted = BM / packed_factor;
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
constexpr int smem_A_rows = BM;
|
||||
constexpr int smem_A_cols = BK_adjusted;
|
||||
constexpr int smem_AS_rows = BK_adjusted;
|
||||
constexpr int smem_AS_cols = BM;
|
||||
// constexpr int smem_AS_rows = BK;
|
||||
// constexpr int smem_AS_cols = BM_adjusted;
|
||||
|
||||
if constexpr (TRANSPOSE_AT_CONSUME) {
|
||||
// A is stored K-major in smem
|
||||
constexpr int smem_A_rows = BM;
|
||||
constexpr int smem_A_cols = BK_adjusted;
|
||||
|
||||
// int A_offset = (WM * warp_row + TCM * wm_iter + row) * smem_A_cols;
|
||||
|
||||
// @perf: bank conflicts
|
||||
@@ -205,15 +203,16 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f6, %0(%1)" ::"i"(6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" ::"i"(7 * sizeof(float)), "r"(smem_addr));
|
||||
} else {
|
||||
// read smem A tile as-is; bank-conflict-free AS load
|
||||
// smem A tile is stored column-major
|
||||
// f0-f7 store a single row of A
|
||||
// A is stored M-major in smem
|
||||
constexpr int smem_AS_rows = BK_adjusted;
|
||||
constexpr int smem_AS_cols = BM;
|
||||
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_A)[((local_k_adjusted + 0) * smem_AS_cols) +
|
||||
(WM * warp_row + TCM * wm_iter) + row]);
|
||||
// step to the next row
|
||||
// f0-f7 store a single row of A
|
||||
// threads read from different columns; no bank conflicts
|
||||
asm volatile("flw f0, %0(%1)" :: "i"(smem_AS_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f1, %0(%1)" :: "i"(smem_AS_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -224,6 +223,8 @@ inline void vx_wmma_load_a(volatile const T *smem_A, const int local_k,
|
||||
asm volatile("flw f6, %0(%1)" :: "i"(smem_AS_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f7, %0(%1)" :: "i"(smem_AS_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
}
|
||||
|
||||
asm volatile ("vx_wmma_load_a_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// `local_k` is assumed to be a multiple of TCK
|
||||
@@ -231,6 +232,8 @@ template <typename T>
|
||||
inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
const int warp_col, const int wn_iter,
|
||||
const int thread_in_warp) {
|
||||
asm volatile ("vx_wmma_load_b_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
const int tg = tid / 4;
|
||||
|
||||
@@ -244,18 +247,16 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
constexpr int BN_adjusted = BN / packed_factor;
|
||||
const int local_k_adjusted = local_k / packed_factor;
|
||||
|
||||
// constexpr int smem_B_rows = BK;
|
||||
// constexpr int smem_B_cols = BN_adjusted;
|
||||
// B is stored N-major in smem
|
||||
constexpr int smem_B_rows = BK_adjusted;
|
||||
constexpr int smem_B_cols = BN;
|
||||
|
||||
// f8-f15 stores a single column of B
|
||||
const volatile uint8_t *smem_addr;
|
||||
smem_addr = reinterpret_cast<const volatile uint8_t *>(
|
||||
&reinterpret_cast<const volatile float *>(
|
||||
smem_B)[((local_k_adjusted + 0) * smem_B_cols) +
|
||||
(WN * warp_col + TCN * wn_iter) + col]);
|
||||
// step to the next row
|
||||
// f8-f15 stores a single column of B
|
||||
// threads read from different columns; no bank conflicts
|
||||
asm volatile("flw f8, %0(%1)" :: "i"(smem_B_cols * 0 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f9, %0(%1)" :: "i"(smem_B_cols * 1 * sizeof(float)), "r"(smem_addr));
|
||||
@@ -265,6 +266,8 @@ inline void vx_wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
asm volatile("flw f13, %0(%1)" :: "i"(smem_B_cols * 5 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f14, %0(%1)" :: "i"(smem_B_cols * 6 * sizeof(float)), "r"(smem_addr));
|
||||
asm volatile("flw f15, %0(%1)" :: "i"(smem_B_cols * 7 * sizeof(float)), "r"(smem_addr));
|
||||
|
||||
asm volatile ("vx_wmma_load_b_finish_%=:" :: );
|
||||
}
|
||||
|
||||
inline void initialize_C(const int dest_reg) {
|
||||
@@ -295,6 +298,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
||||
const int wm_iter, const int dim_n,
|
||||
float *C, const int threadblock_id_x,
|
||||
const int threadblock_id_y) {
|
||||
asm volatile ("write_results_start_%=:" :: );
|
||||
|
||||
int tid = thread_in_warp;
|
||||
|
||||
// these are [0, TCM/TCN)
|
||||
@@ -342,6 +347,8 @@ inline void write_results(const int thread_in_warp, const int warp_col,
|
||||
asm volatile ("fsw f30, %0(%1)" :: "i"(4 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||
asm volatile ("fsw f31, %0(%1)" :: "i"(5 * sizeof(float)), "r"(gmem_addr_tmp));
|
||||
}
|
||||
|
||||
asm volatile ("write_results_finish_%=:" :: );
|
||||
}
|
||||
|
||||
inline void threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||
|
||||
Reference in New Issue
Block a user