Merge branch 'kernels-flash' into new-cisc

This commit is contained in:
Hansung Kim
2024-11-09 14:42:46 -08:00
6 changed files with 170 additions and 50 deletions

View File

@@ -9,10 +9,10 @@
// #define SMEM_SIZE 0x4000
// 64KB
// #define SMEM_SIZE 0x10000
// 128KB
// 128KB (FP16 GEMM)
// #define SMEM_SIZE 0x20000
// 256KB
#define SMEM_SIZE 0x20000
// 256KB (FlashAttention)
#define SMEM_SIZE 0x40000
#define SMEM_MASK (SMEM_SIZE - 1)
#define SMEM_ADDR_END (SMEM_BASE + SMEM_SIZE)

View File

@@ -4,6 +4,9 @@
#include <vx_spawn.h>
#include <float.h>
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
#define MARK_END() asm volatile ("slti x0, x1, -499")
#define B_ROW 64
#define B_COL 64
#define HEADDIM 64
@@ -11,8 +14,10 @@
#define ROW_REMAINDER_LOGIC
constexpr uint32_t ROWMAX_SETS = 3;
constexpr bool WARP_SPECIALIZED = true;
constexpr bool TENSOR_CORE = true;
// constexpr bool WARP_SPECIALIZED = true;
// constexpr bool TENSOR_CORE = true;
constexpr bool WARP_SPECIALIZED = false;
constexpr bool TENSOR_CORE = false;
// temporary safety stop for wrong configs
static_assert(NUM_CORES == 4);

View File

@@ -219,6 +219,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
loop_matmul_skips(/*skip_lda=*/0, /*skip_ldb=*/0, /*skip_ldd=*/1,
/*skip_ex=*/1, /*skip_stc=*/1);
MARK_BEG();
if constexpr (GEMMINI_DMA) {
if (tid_in_warpgroup == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
@@ -259,7 +261,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
(uint64_t)(gmem_K_tile),
k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA
GEMMINI_CISC_CMD_R((dim_seqlen << 16) | (HEADDIM << 8) |
GEMMINI_CISC_CMD_R((dim_seqlen << 20) | (HEADDIM << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
@@ -549,7 +551,7 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
k_LOOP_WS_CONFIG_ADDRS_AB)
// configure address strides for the DMA
// FIXME: unnecessary?
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 16) | (dim_seqlen /*KT*/ << 8) |
GEMMINI_CISC_CMD_R((HEADDIM /*V*/ << 20) | (dim_seqlen /*KT*/ << 8) |
8 /*k_LOOP_WS_CONFIG_STRIDES_AB*/);
gemmini_fence();
@@ -813,8 +815,6 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
warps_per_warpgroup_per_core);
}
}
#if 0
#endif
}
asm volatile ("tile_loop_finish_%=:" :: );
@@ -824,6 +824,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
if (warpgroup_id == 0) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
}
MARK_END();
}
int main() {

View File

@@ -212,6 +212,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
loop_matmul_skips(/*skip_lda=*/1, /*skip_ldb=*/1, /*skip_ldd=*/0,
/*skip_ex=*/0, /*skip_stc=*/1);
MARK_BEG();
if (tid_in_warpgroup == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
@@ -336,7 +338,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// "inner loop" along the columns of K^T
const uint32_t k_tiles = (dim_seqlen / B_COL);
for (uint32_t tile_k = 0;
tile_k < (4 /*for perf measurement*/ * k_tiles) + 2 /*pipeline latency*/;
tile_k < (4 /*for perf measurement*/ *
// virgo kernel is fully pipelined around (2 GEMMs | softmax);
// requires two loop iterations to finish one tile compute
(2 * k_tiles)) +
2 /*pipeline latency*/;
tile_k++) {
if constexpr (DEBUG || true) {
threadblock_barrier(global_barrier_id, warps_per_threadblock_per_core);
@@ -677,6 +683,8 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
}
asm volatile ("tile_loop_finish_%=:" :: );
MARK_END();
}
int main() {

View File

@@ -6,7 +6,7 @@
#include "include/gemmini.h"
#include "gemmini_mmio.h"
#define FP_SIZE 16
#define FP_SIZE 32
// "fake" fp16 type that only has the correct data width.
using float16_t = uint16_t;
@@ -19,7 +19,7 @@ using float_type = float16_t;
// Generate kernel for the Hopper-style SMEM-decoupled tensor core. This uses
// asynchronous HGMMA and HGMMA_WAIT instructions.
#define TENSOR_HOPPER 1
#define TENSOR_HOPPER 0
// Constraints on parameters:
// * Memory:
@@ -104,6 +104,12 @@ static_assert(WMITER * WNITER * TCM * TCN * NUM_WARPS * CORES_PER_CLUSTER ==
#define TRANSPOSE_AT_PRODUCE 0
#define TRANSPOSE_AT_CONSUME 0
// if 1, wmma_store() will not respect the register <-> matrix fragment mapping
// scheme and instead do a fast coalesced GMEM writes for move out. This
// doesn't necessarily mean breaking correctness; it means that the final
// result matrix will be stored in a swizzled form in the global memory.
#define WMMA_STORE_FAST 1
#define GEMMINI_DMA 1
#define GEMMINI_DMA_FAST 1
#if SMEM_SIZE == 0x4000
@@ -213,10 +219,9 @@ inline constexpr void map_c_8lanes(const int tid, int &row, int &col) {
col += ((tid % 4) / 2) * 2;
}
inline constexpr void map_c_8lanes_hopper(const int tid, int &row, int &col) {
inline constexpr void map_c_8lanes_coalesced(const int tid, int &row, int &col) {
const int tg = tid / 2;
// FIXME wrong!!!
row = 0;
col = tid;
}
@@ -225,8 +230,8 @@ inline constexpr void map_c(const int tid, int &row, int &col) {
if constexpr (NUM_THREADS == 32) {
map_c_32lanes(tid, row, col);
} else if constexpr (NUM_THREADS == 8) {
if constexpr (TENSOR_HOPPER) {
map_c_8lanes_hopper(tid, row, col);
if constexpr (TENSOR_HOPPER || WMMA_STORE_FAST) {
map_c_8lanes_coalesced(tid, row, col);
} else {
map_c_8lanes(tid, row, col);
}
@@ -664,26 +669,48 @@ wmma_store(const int thread_in_warp, const int warp_col, const int warp_row,
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
if constexpr (!WMMA_STORE_FAST) {
asm volatile("fsw f16, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f19, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f23, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
} else {
asm volatile("fsw f16, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f17, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f18, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f19, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f20, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f21, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f22, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f23, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
}
} else {
volatile uint8_t *addr = reinterpret_cast<volatile uint8_t *>(
&write_addr[dim_n * (local_row + 0) + (local_col + 0)]);
volatile uint8_t *addr_tworow = addr + (2 * dim_n) * sizeof(float);
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f27, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f31, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr_tworow));
if constexpr (!WMMA_STORE_FAST) {
asm volatile("fsw f24, %0(%1)" ::"i"(0 * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(2 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f27, %0(%1)" ::"i"(3 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(6 * sizeof(float)), "r"(addr_tworow));
asm volatile("fsw f31, %0(%1)" ::"i"(7 * sizeof(float)), "r"(addr_tworow));
} else {
asm volatile("fsw f24, %0(%1)" ::"i"(0 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f25, %0(%1)" ::"i"(1 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f26, %0(%1)" ::"i"(2 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f27, %0(%1)" ::"i"(3 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f28, %0(%1)" ::"i"(4 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f29, %0(%1)" ::"i"(5 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f30, %0(%1)" ::"i"(6 * WN * sizeof(float)), "r"(addr));
asm volatile("fsw f31, %0(%1)" ::"i"(7 * WN * sizeof(float)), "r"(addr));
}
}
asm volatile ("wmma_store_finish_%=:" :: );
@@ -1150,19 +1177,20 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
if constexpr (GEMMINI_DMA) {
// pipeline initiation
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/0 * BK),
(uint64_t)(B + /*block_k:*/0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
if (block_m == 0 && block_n == 0) {
if (tid_in_threadblock == 0) {
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + /*block_k:*/ 0 * BK),
(uint64_t)(B + /*block_k:*/ 0 * BK * dim_n + block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
GEMMINI_CISC_CMD_I(10);
gemmini_fence();
#if 0
// sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS
@@ -1181,10 +1209,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
gemmini_fence();
#endif
}
}
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
threadblock_barrier(threadblock_id_in_cluster,
warps_per_threadblock_per_core);
}
}
#pragma GCC unroll 1
@@ -1197,12 +1226,28 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// this is either done using DMA or SIMT cores depending on GEMMINI_DMA
#if (GEMMINI_DMA == 1)
if ((tid_in_threadblock == 0) && ((block_k * BK) != (dim_k - BK))) {
if (tid_in_threadblock == 0) {
asm volatile("next_index_start_%=:" ::);
const uint32_t next_block_k =
((block_k + 1) * BK == dim_k) ? 0 : block_k + 1;
const uint32_t next_block_n =
(next_block_k == 0)
? (((block_n + 1) * BN == dim_n) ? 0 : block_n + 1)
: block_n;
const uint32_t next_block_m =
(next_block_n == 0)
? (((block_m + 1) == block_m_end) ? block_m_start /*unused*/
: block_m + 1)
: block_m;
asm volatile("next_index_end_%=:" ::);
// configure dma gmem address to load from
ROCC_INSTRUCTION_RS1_RS2(
XCUSTOM_ACC,
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
(uint64_t)(A + next_block_m * BM * dim_k + next_block_k * BK),
(uint64_t)(B + next_block_k * BK * dim_n + next_block_n * BN),
k_LOOP_WS_CONFIG_ADDRS_AB)
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
GEMMINI_CISC_CMD_R((dim_n << 20) | (dim_k << 8) | 8);
@@ -1210,6 +1255,11 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
// block_k is even: opcode 11 (write to local_a_buf)
// block_k is odd: opcode 10 (write to local_a)
//
// FIXME: This depends on (dim_k / BK) being an even number, since
// the last iteration of the k-loop is prefetching for the first
// iteration of the n-loop. The ping-poing indexing has to match for
// the two loop end to connect.
const uint32_t opcode = 11 - (block_k & 1);
GEMMINI_CISC_CMD_I(opcode);
// // TODO: branch is probably slow
@@ -1349,6 +1399,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
if constexpr (write_to_gmem) {
asm volatile("move_out_start_%=:" ::);
if constexpr (TENSOR_HOPPER) {
// wait until all results are accumulated into the RF
vx_wgmma_wait();
@@ -1367,6 +1419,8 @@ inline void thread_block_gemm(const T *A, const T *B, float *C,
}
}
}
asm volatile("move_out_end_%=:" ::);
}
}
asm volatile("loop_mn_end_%=:" ::);

View File

@@ -0,0 +1,51 @@
#!/bin/sh
#
# Updates symlink to args.bin, input.a.bin, input.b.bin to point to the right
# binary according to the dimension size given as the argument.
if [ "$#" != "2" ]; then
echo "usage: $0 DIMENSION 1|0"
echo "second argument indicates using DMA or not."
exit 1
fi
dim="$1"
dma="$2"
if [ "$2" == "1" ]; then
layout_a="row.swizzle_fp16"
layout_b="row"
else
layout_a="col.swizzle_fp16"
layout_b="row.swizzle_fp16"
fi
check_exists() {
if ! [ -f "$1" ]; then
echo "error: looked for file $1 that does not exist."
exit 1
fi
}
args="args.m$1n$1k$1.bin"
input_a="input.a.rand01.fp16.m$1n$1k$1.$layout_a.bin"
input_b="input.b.rand01.fp16.m$1n$1k$1.$layout_b.bin"
check_exists "$args"
check_exists "$input_a"
check_exists "$input_b"
echo "will symlink:"
echo "args.bin -> $args"
echo "input.a.bin -> $input_a"
echo "input.b.bin -> $input_b"
echo "continue? (Y/N)"
read -r -s -n 1 answer
if [ "$answer" != "Y" ] && [ "$answer" != "y" ]; then
echo "exiting..."
exit 1
fi
ln -sf -v "$args" "args.bin"
ln -sf -v "$input_a" "input.a.bin"
ln -sf -v "$input_b" "input.b.bin"
echo "done."