sgemm_impl: Add skeleton wgmma routine for single_tile
This commit is contained in:
@@ -233,6 +233,18 @@ inline void vx_wmma(const int dest_reg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline void vx_wgmma() {
|
||||||
|
// .insn r opcode6, func3, func7, rd, rs1, rs2
|
||||||
|
// https://www.rowleydownload.co.uk/arm/documentation/gnu/as/RISC_002dV_002dFormats.html#RISC_002dV_002dFormats
|
||||||
|
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
|
|
||||||
|
inline void vx_wgmma_wait() {
|
||||||
|
// .insn r opcode6, func3, func7, rd, rs1, rs2
|
||||||
|
// func3 == 1 encodes wait
|
||||||
|
asm volatile (".insn r %0, 1, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||||
|
}
|
||||||
|
|
||||||
// Remap logical row/col coordinate of a matrix element to a memory index that
|
// Remap logical row/col coordinate of a matrix element to a memory index that
|
||||||
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
// follows the 2-level block-row-major layout that Gemmini DMA uses
|
||||||
template <bool use_dma, uint32_t dim_col>
|
template <bool use_dma, uint32_t dim_col>
|
||||||
@@ -779,15 +791,15 @@ template <typename T,
|
|||||||
MemLayout layout_a, // memory layout of `local_a`
|
MemLayout layout_a, // memory layout of `local_a`
|
||||||
MemLayout layout_b, // memory layout of `local_b`
|
MemLayout layout_b, // memory layout of `local_b`
|
||||||
uint32_t tile_dim_m, uint32_t tile_dim_n, uint32_t tile_dim_k,
|
uint32_t tile_dim_m, uint32_t tile_dim_n, uint32_t tile_dim_k,
|
||||||
uint32_t leading_dim_a, // if zero, assumes packed layout, i.e. row
|
uint32_t leading_dim_a, // if zero, assumes packed layout, i.e. row
|
||||||
// stride == col.
|
// stride == col.
|
||||||
uint32_t leading_dim_b, // if zero, assumes packed layout, i.e. row
|
uint32_t leading_dim_b, // if zero, assumes packed layout, i.e. row
|
||||||
// stride == col.
|
// stride == col.
|
||||||
bool load_accum = false, // if true, load the accumulation registers
|
bool load_accum = false, // if true, load the accumulation registers
|
||||||
// with `local_c`. used for the (C + A*B)
|
// with `local_c`. used for the (C + A*B)
|
||||||
// operation
|
// operation
|
||||||
bool write_to_mem = false // if true, write the single result tile to
|
bool write_to_mem = false // if true, write the single result tile to
|
||||||
// the memory at a given address
|
// the memory at a given address
|
||||||
>
|
>
|
||||||
__attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
__attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
||||||
const T *local_a, const T *local_b, const T *local_c, T *result_addr,
|
const T *local_a, const T *local_b, const T *local_c, T *result_addr,
|
||||||
@@ -804,44 +816,55 @@ __attribute__((always_inline)) inline void thread_block_gemm_single_tile(
|
|||||||
const uint32_t warps_per_threadblock_per_core =
|
const uint32_t warps_per_threadblock_per_core =
|
||||||
NUM_WARPS / threadblocks_per_cluster;
|
NUM_WARPS / threadblocks_per_cluster;
|
||||||
|
|
||||||
// TODO: it would be useful if this bit is split out into a function, so that
|
if constexpr (TENSOR_HOPPER) {
|
||||||
// preloading accumulation tile can be used for full GEMMs at the start of
|
#pragma GCC unroll 1
|
||||||
// the K-loop.
|
for (int i = 0; i < BK_LOOP; i++) {
|
||||||
if constexpr (load_accum) {
|
#pragma GCC unroll 4
|
||||||
#pragma GCC unroll
|
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
// FIXME: use local_a and local_b here
|
||||||
#pragma GCC unroll
|
vx_wgmma();
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
}
|
||||||
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
}
|
||||||
tile_dim_n, local_c);
|
} else {
|
||||||
|
// TODO: it would be useful if this bit is split out into a function, so
|
||||||
|
// that preloading accumulation tile can be used for full GEMMs at the start
|
||||||
|
// of the K-loop.
|
||||||
|
if constexpr (load_accum) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
|
#pragma GCC unroll
|
||||||
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
|
wmma_load_accum(tid_in_warp, warp_col, warp_row, wn_iter, wm_iter,
|
||||||
|
tile_dim_n, local_c);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
#pragma GCC unroll 1
|
#pragma GCC unroll 1
|
||||||
for (int i = 0; i < BK_LOOP; i++) {
|
for (int i = 0; i < BK_LOOP; i++) {
|
||||||
#pragma GCC unroll 4
|
#pragma GCC unroll 4
|
||||||
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
for (uint32_t local_k = 0; local_k < tile_dim_k; local_k += TCK) {
|
||||||
#pragma GCC unroll 2
|
#pragma GCC unroll 2
|
||||||
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
for (int wn_iter = 0; wn_iter < WNITER; wn_iter++) {
|
||||||
// SMEM -> RF
|
|
||||||
static_assert(leading_dim_b == 0,
|
|
||||||
"leading_dim for wmma_load_b is not implemented yet");
|
|
||||||
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
|
|
||||||
tile_dim_k /*leading_dim_b is TODO */>(
|
|
||||||
local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
|
||||||
#pragma GCC unroll 2
|
|
||||||
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
|
||||||
// SMEM -> RF
|
// SMEM -> RF
|
||||||
if constexpr (leading_dim_a == 0) {
|
static_assert(leading_dim_b == 0,
|
||||||
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
|
"leading_dim for wmma_load_b is not implemented yet");
|
||||||
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
wmma_load_b<T, layout_b, tile_dim_m, tile_dim_n,
|
||||||
} else {
|
tile_dim_k /*leading_dim_b is TODO */>(
|
||||||
wmma_load_a<T, layout_a, leading_dim_a>(local_a, local_k, warp_row,
|
local_b, local_k, warp_col, wn_iter, tid_in_warp);
|
||||||
wm_iter, tid_in_warp);
|
#pragma GCC unroll 2
|
||||||
|
for (int wm_iter = 0; wm_iter < WMITER; wm_iter++) {
|
||||||
|
// SMEM -> RF
|
||||||
|
if constexpr (leading_dim_a == 0) {
|
||||||
|
wmma_load_a<T, layout_a, tile_dim_m, tile_dim_n, tile_dim_k>(
|
||||||
|
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||||
|
} else {
|
||||||
|
wmma_load_a<T, layout_a, leading_dim_a>(
|
||||||
|
local_a, local_k, warp_row, wm_iter, tid_in_warp);
|
||||||
|
}
|
||||||
|
// perform mma
|
||||||
|
vx_wmma(wm_iter);
|
||||||
}
|
}
|
||||||
// perform mma
|
|
||||||
vx_wmma(wm_iter);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user