sgemm_impl: Do proper addr gen and store for wgmma
This commit is contained in:
@@ -213,11 +213,23 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
}
|
||||
|
||||
// Maps a thread index to the (row, col) coordinate it owns inside the C tile
// for the 8-lane Hopper-style (wgmma) configuration.
//
//   tid: thread index within the warp.
//   row: output, always set to 0.
//   col: output, set to `tid`.
//
// FIXME wrong!!! -- the original author flagged this mapping as a placeholder;
// it does not yet describe the real register-to-element layout.
inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) {
  row = 0;
  col = tid;
}
|
||||
|
||||
inline constexpr void map_c(const int tid, int &row, int &col) {
|
||||
if constexpr (NUM_THREADS == 32) {
|
||||
map_c_32lanes(tid, row, col);
|
||||
} else if constexpr (NUM_THREADS == 8) {
|
||||
map_c_8lanes(tid, row, col);
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
map_c_8lanes_hopper(tid, row, col);
|
||||
} else {
|
||||
map_c_8lanes(tid, row, col);
|
||||
}
|
||||
} else {
|
||||
// FIXME: not allowed
|
||||
}
|
||||
@@ -233,10 +245,11 @@ inline void vx_wmma(const int dest_reg) {
|
||||
}
|
||||
}
|
||||
|
||||
inline void vx_wgmma() {
|
||||
inline void vx_wgmma(const uint32_t addr_a, const uint32_t addr_b) {
|
||||
// .insn r opcode6, func3, func7, rd, rs1, rs2
|
||||
// https://www.rowleydownload.co.uk/arm/documentation/gnu/as/RISC_002dV_002dFormats.html#RISC_002dV_002dFormats
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
asm volatile(".insn r %0, 0, 0, x0, %1, %2" ::"i"(RISCV_CUSTOM3), "r"(addr_a),
|
||||
"r"(addr_b));
|
||||
}
|
||||
|
||||
inline void vx_wgmma_wait() {
|
||||
@@ -503,6 +516,25 @@ inline void wmma_load_b(const volatile T *smem_B, const int local_k,
|
||||
asm volatile ("wmma_load_b_finish_%=:" :: );
|
||||
}
|
||||
|
||||
template <typename T, MemLayout layout_a, MemLayout layout_b,
|
||||
uint32_t leading_dim_a, uint32_t leading_dim_b, uint32_t tile_dim_k>
|
||||
inline void wgmma(volatile const T *smem_A, volatile const T *smem_B,
|
||||
const int local_k, const int warp_row, const int warp_col) {
|
||||
asm volatile("wgmma_start_%=:" ::);
|
||||
|
||||
const volatile uint8_t *addr_a =
|
||||
generate_smem_addr_a<T, layout_a, leading_dim_a>(smem_A, local_k,
|
||||
warp_row, 0, 0);
|
||||
const volatile uint8_t *addr_b =
|
||||
generate_smem_addr_b<T, layout_b, leading_dim_b, tile_dim_k>(
|
||||
smem_B, local_k, warp_col, 0, 0);
|
||||
|
||||
vx_wgmma(reinterpret_cast<uint32_t>(addr_a),
|
||||
reinterpret_cast<uint32_t>(addr_b));
|
||||
|
||||
asm volatile("wgmma_finish_%=:" ::);
|
||||
}
|
||||
|
||||
// Initialize the accumulator registers to zero before starting FMA operations
|
||||
// with the tensor cores.
|
||||
template <int accum_reg_set> inline void initialize_accum_regs() {
|
||||
@@ -527,6 +559,41 @@ template <int accum_reg_set> inline void initialize_accum_regs() {
|
||||
}
|
||||
}
|
||||
|
||||
// Clears every floating-point register f0..f31 by moving the contents of the
// integer zero register x0 into each one, in ascending register order.
//
// The registers are written behind the compiler's back: the asm statement
// declares no outputs or clobbers.
inline void initialize_all_regs() {
  asm volatile("fmv.w.x f0, x0\n\t"
               "fmv.w.x f1, x0\n\t"
               "fmv.w.x f2, x0\n\t"
               "fmv.w.x f3, x0\n\t"
               "fmv.w.x f4, x0\n\t"
               "fmv.w.x f5, x0\n\t"
               "fmv.w.x f6, x0\n\t"
               "fmv.w.x f7, x0\n\t"
               "fmv.w.x f8, x0\n\t"
               "fmv.w.x f9, x0\n\t"
               "fmv.w.x f10, x0\n\t"
               "fmv.w.x f11, x0\n\t"
               "fmv.w.x f12, x0\n\t"
               "fmv.w.x f13, x0\n\t"
               "fmv.w.x f14, x0\n\t"
               "fmv.w.x f15, x0\n\t"
               "fmv.w.x f16, x0\n\t"
               "fmv.w.x f17, x0\n\t"
               "fmv.w.x f18, x0\n\t"
               "fmv.w.x f19, x0\n\t"
               "fmv.w.x f20, x0\n\t"
               "fmv.w.x f21, x0\n\t"
               "fmv.w.x f22, x0\n\t"
               "fmv.w.x f23, x0\n\t"
               "fmv.w.x f24, x0\n\t"
               "fmv.w.x f25, x0\n\t"
               "fmv.w.x f26, x0\n\t"
               "fmv.w.x f27, x0\n\t"
               "fmv.w.x f28, x0\n\t"
               "fmv.w.x f29, x0\n\t"
               "fmv.w.x f30, x0\n\t"
               "fmv.w.x f31, x0");
}
|
||||
|
||||
// `C` is expected to be in N-major layout.
|
||||
__attribute__((always_inline)) inline void
|
||||
wmma_load_accum(const int thread_in_warp, const int warp_col,
|
||||
@@ -622,6 +689,66 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||
asm volatile ("wmma_store_finish_%=:" :: );
|
||||
}
|
||||
|
||||
// Write out the matrix data stored in RF to memory
|
||||
__attribute__((always_inline)) inline void
|
||||
wgmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
|
||||
const int dim_n, float *write_addr) {
|
||||
asm volatile ("wgmma_store_start_%=:" :: );
|
||||
|
||||
const int tid = thread_in_warp;
|
||||
|
||||
// these are [0, TCM/TCN)
|
||||
int tid_row = 0;
|
||||
int tid_col = 0;
|
||||
map_c(tid, tid_row, tid_col);
|
||||
|
||||
// FIXME: WM and WN might be swapped here
|
||||
int local_row = WM * warp_row + tid_row;
|
||||
int local_col = WN * warp_col + tid_col;
|
||||
|
||||
// FIXME: this is storing in M-major format
|
||||
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
|
||||
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
|
||||
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
|
||||
asm volatile("fsw f0, %0(%1)" ::"i"(0 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f1, %0(%1)" ::"i"(1 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f2, %0(%1)" ::"i"(2 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f3, %0(%1)" ::"i"(3 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f4, %0(%1)" ::"i"(4 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f5, %0(%1)" ::"i"(5 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f6, %0(%1)" ::"i"(6 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f7, %0(%1)" ::"i"(7 * WM * sizeof(float)), "r"(addr));
|
||||
|
||||
asm volatile("fsw f8, %0(%1)" ::"i"( 8 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f9, %0(%1)" ::"i"( 9 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f10, %0(%1)" ::"i"(10 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f11, %0(%1)" ::"i"(11 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f12, %0(%1)" ::"i"(12 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f13, %0(%1)" ::"i"(13 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f14, %0(%1)" ::"i"(14 * WM * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f15, %0(%1)" ::"i"(15 * WM * sizeof(float)), "r"(addr));
|
||||
|
||||
asm volatile("fsw f16, %0(%1)" ::"i"((0 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f17, %0(%1)" ::"i"((1 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f18, %0(%1)" ::"i"((2 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f19, %0(%1)" ::"i"((3 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f20, %0(%1)" ::"i"((4 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f21, %0(%1)" ::"i"((5 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f22, %0(%1)" ::"i"((6 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f23, %0(%1)" ::"i"((7 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
|
||||
asm volatile("fsw f24, %0(%1)" ::"i"(( 8 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f25, %0(%1)" ::"i"(( 9 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f26, %0(%1)" ::"i"((10 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f27, %0(%1)" ::"i"((11 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f28, %0(%1)" ::"i"((12 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f29, %0(%1)" ::"i"((13 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f30, %0(%1)" ::"i"((14 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
asm volatile("fsw f31, %0(%1)" ::"i"((15 * WM + 8) * sizeof(float)), "r"(addr));
|
||||
|
||||
asm volatile ("wgmma_store_finish_%=:" :: );
|
||||
}
|
||||
|
||||
__attribute__((convergent)) inline void
|
||||
threadblock_barrier(const uint32_t barrier_id, const uint32_t count) {
|
||||
asm volatile("" ::: "memory");
|
||||
@@ -859,13 +986,14 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
const uint32_t warps_per_threadblock_per_core =
|
||||
NUM_WARPS / threadblocks_per_cluster;
|
||||
|
||||
#if 1
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
#pragma GCC unroll 1
|
||||
for (int i = 0; i < BK_LOOP; i++) {
|
||||
#pragma GCC unroll 4
|
||||
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
||||
// FIXME: use local_a and local_b here
|
||||
vx_wgmma();
|
||||
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
||||
wgmma<T, layout_a, layout_b, leading_dim_a, tile_dim_n, tile_dim_k>(
|
||||
local_a, local_b, local_k, warp_row, warp_col);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -912,6 +1040,7 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// if constexpr (GEMMINI_DMA) {
|
||||
// // Call gemmini fence at the end of the loop to overlap dma & wmma.
|
||||
@@ -1011,9 +1140,13 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
for (uint32_t block_n = 0; (block_n * BN) < dim_n; block_n++) {
|
||||
asm volatile ("loop_mn_start_%=:" :: );
|
||||
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
initialize_all_regs();
|
||||
} else {
|
||||
// clear out accumulators
|
||||
initialize_accum_regs<0>();
|
||||
initialize_accum_regs<1>();
|
||||
}
|
||||
|
||||
if constexpr (GEMMINI_DMA) {
|
||||
// pipeline initiation
|
||||
@@ -1216,13 +1349,22 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
|
||||
}
|
||||
|
||||
if constexpr (write_to_gmem) {
|
||||
if constexpr (TENSOR_HOPPER) {
|
||||
// wait until all results are accumulated into the RF
|
||||
vx_wgmma_wait();
|
||||
|
||||
float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n;
|
||||
wgmma_store(tid_in_warp, warp_col, warp_row, dim_n, global_offset_C);
|
||||
} else {
|
||||
#pragma GCC unroll
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||
#pragma GCC unroll
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
float *global_offset_C = C + (BM * block_m) * dim_n + BN * block_n;
|
||||
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter, dim_n,
|
||||
global_offset_C);
|
||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||
float *global_offset_C =
|
||||
C + (BM * block_m) * dim_n + BN * block_n;
|
||||
wmma_store(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||
dim_n, global_offset_C);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user