sgemm_tcore: Blocksize 64; Fix kernel launch on larger dim
& fix addrgen assembly too large offset error
This commit is contained in:
@@ -7,6 +7,36 @@
|
||||
#include "include/gemmini.h"
|
||||
#include "gemmini_mmio.h"
|
||||
|
||||
#define GEMMINI_DMA 1
|
||||
#if SMEM_SIZE == 0x4000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
|
||||
#define SMEM_ADDR_Q3 ((float * const) 0xff003000)
|
||||
#define SPAD_ADDR_Q0 0x0
|
||||
#define SPAD_ADDR_Q1 0x80
|
||||
#define SPAD_ADDR_Q2 0x100
|
||||
#define SPAD_ADDR_Q3 0x180
|
||||
#define BOUND_INST 0x400040004ULL
|
||||
#elif SMEM_SIZE == 0x10000
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
#define SMEM_ADDR_Q1 ((float * const) 0xff004000)
|
||||
#define SMEM_ADDR_Q2 ((float * const) 0xff008000)
|
||||
#define SMEM_ADDR_Q3 ((float * const) 0xff00c000)
|
||||
#define SPAD_ADDR_Q0 0x0
|
||||
#define SPAD_ADDR_Q1 0x200
|
||||
#define SPAD_ADDR_Q2 0x400
|
||||
#define SPAD_ADDR_Q3 0x600
|
||||
#define BOUND_INST 0x800080008ULL
|
||||
#else
|
||||
#error Unsupported smem size
|
||||
#endif
|
||||
|
||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||
#error "threadblock size too big for cluster"
|
||||
#endif
|
||||
|
||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const float *A, const float *B,
|
||||
volatile float *local_a, volatile float *local_b,
|
||||
@@ -204,14 +234,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 4;
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 4;
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,8 +253,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t threadblock_dim_y,
|
||||
/*const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y,*/
|
||||
const uint32_t num_threadblocks,
|
||||
const uint32_t threadblock_id,
|
||||
const uint32_t threadblocks_per_cluster,
|
||||
const uint32_t threadblock_id_in_cluster,
|
||||
float *sharedmem_per_threadblock) {
|
||||
const float *A = (const float *)arg->addr_a;
|
||||
@@ -276,8 +307,8 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
#endif
|
||||
|
||||
// divide rows (M) by the number of threadblocks
|
||||
const uint32_t dim_m_range = (dim_m / num_threadblocks);
|
||||
const uint32_t dim_m_start = dim_m_range * threadblock_id;
|
||||
const uint32_t dim_m_range = (dim_m / threadblocks_per_cluster);
|
||||
const uint32_t dim_m_start = dim_m_range * threadblock_id_in_cluster;
|
||||
const uint32_t block_m_start = dim_m_start / BM;
|
||||
const uint32_t block_m_end = (dim_m_start + dim_m_range) / BM;
|
||||
|
||||
@@ -303,9 +334,10 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
||||
gemmini_fence();
|
||||
|
||||
// GEMMINI_CISC_CMD_I(12);
|
||||
// gemmini_fence();
|
||||
GEMMINI_CISC_CMD_I(12);
|
||||
gemmini_fence();
|
||||
|
||||
#if 0
|
||||
// sp_tiled_matmul_full_spad_ws includes CONFIG_BOUNDS
|
||||
// FIXME: block_k is 0 for two times
|
||||
sp_tiled_matmul_full_spad_ws(
|
||||
@@ -321,6 +353,7 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
/*a_transpose=*/0, /*b_transpose=*/0, /*full_C=*/0, /*low_D=*/0,
|
||||
/*acc=*/0, /*act=*/NO_ACTIVATION, /*skips=*/skips)
|
||||
gemmini_fence();
|
||||
#endif
|
||||
}
|
||||
|
||||
threadblock_barrier(threadblock_id_in_cluster, threadblock_dim_y);
|
||||
@@ -340,23 +373,22 @@ inline void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
// FIXME: block_k is wrong
|
||||
ROCC_INSTRUCTION_RS1_RS2(
|
||||
XCUSTOM_ACC,
|
||||
(uint64_t)(A + block_m * BM * dim_k + block_k * BK),
|
||||
(uint64_t)(B + block_k * BK * dim_n + block_n * BN),
|
||||
(uint64_t)(A + block_m * BM * dim_k + (block_k + 1/*runahead*/) * BK),
|
||||
(uint64_t)(B + (block_k + 1/*runahead*/) * BK * dim_n + block_n * BN),
|
||||
k_LOOP_WS_CONFIG_ADDRS_AB)
|
||||
// GEMMINI_CISC(8) does k_LOOP_WS_CONFIG_STRIDES_AB
|
||||
GEMMINI_CISC_CMD_R((dim_n << 16) | (dim_k << 8) | 8);
|
||||
// gemmini_fence();
|
||||
|
||||
// TODO: this is probably slow
|
||||
// if (block_k & 1) {
|
||||
// GEMMINI_CISC_CMD_I(12);
|
||||
// } else { // block_k == 0 is here
|
||||
// GEMMINI_CISC_CMD_I(13);
|
||||
// }
|
||||
// TODO: branch is probably slow
|
||||
if (block_k & 1) {
|
||||
GEMMINI_CISC_CMD_I(12);
|
||||
} else { // block_k == 0 is here
|
||||
GEMMINI_CISC_CMD_I(13);
|
||||
}
|
||||
|
||||
// configure loop iteration bounds
|
||||
// FIXME: shouldn't be necessary
|
||||
// #define BOUND_INST 0x400040004ULL
|
||||
// ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST,
|
||||
// k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC,
|
||||
// SPAD_ADDR_Q0, SPAD_ADDR_Q1, k_LOOP_WS_CONFIG_SPAD_AB)
|
||||
@@ -483,12 +515,11 @@ void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
float *sharedmem_per_threadblock =
|
||||
(float *)DEV_SMEM_START_ADDR + (2 * BM * BK) * threadblock_id_in_cluster;
|
||||
|
||||
const int warp_id = vx_warp_id();
|
||||
thread_block_gemm(arg, tid_in_threadblock, threads_per_threadblock,
|
||||
threadblock_dim_y,
|
||||
/*threadblock_id_x, threadblock_id_y,*/
|
||||
num_threadblocks,
|
||||
threadblock_id,
|
||||
threadblocks_per_cluster,
|
||||
// threadblock_id,
|
||||
threadblock_id_in_cluster,
|
||||
sharedmem_per_threadblock);
|
||||
}
|
||||
|
||||
@@ -11,6 +11,11 @@
|
||||
#undef ELEM_PER_THREAD
|
||||
#define ELEM_PER_THREAD (WMITER * WNITER * ((TCM * TCN) / NUM_LANES) / (DOUBLE_BUFFER ? 2 : 1))
|
||||
|
||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||
#error "threadblock size too big for cluster"
|
||||
#endif
|
||||
|
||||
inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
const uint32_t k, const float *A, const float *B,
|
||||
volatile float *local_a, volatile float *local_b,
|
||||
@@ -85,11 +90,12 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 4 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 5 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 6 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 7 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += BK * row_stride_a * 8;
|
||||
local_a_tmp += BK * row_stride_a * 4;
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BK * row_stride_a * 0 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BK * row_stride_a * 1 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BK * row_stride_a * 2 * sizeof(float)), "r"(local_a_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BK * row_stride_a * 3 * sizeof(float)), "r"(local_a_tmp));
|
||||
local_a_tmp += BK * row_stride_a * 4;
|
||||
}
|
||||
} else {
|
||||
if constexpr (!GMEM_COALESCED_A) {
|
||||
@@ -245,13 +251,16 @@ inline void global_dmem_load(const uint32_t dim_n, const uint32_t dim_k,
|
||||
|
||||
asm volatile ("fsw ft0, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft1, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 2 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 3 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 4 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 5 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 6 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 7 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 8;
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft2, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft3, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft4, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft5, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
asm volatile ("fsw ft6, %0(%1)" :: "i"(BN * row_stride_b * 0 * sizeof(float)), "r"(local_b_tmp));
|
||||
asm volatile ("fsw ft7, %0(%1)" :: "i"(BN * row_stride_b * 1 * sizeof(float)), "r"(local_b_tmp));
|
||||
local_b_tmp += BN * row_stride_b * 2;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -20,9 +20,9 @@
|
||||
// (BM*BN) / (TM*TN) == threadblock size >= NT * CORES_PER_CLUSTER
|
||||
// * Combining BM * BK >= (BM*BN) / (TM*TN) == threadblock yields
|
||||
// BM <= BK*TM*TN
|
||||
#define BM 32
|
||||
#define BN 32
|
||||
#define BK 32
|
||||
#define BM 64
|
||||
#define BN 64
|
||||
#define BK 64
|
||||
#define WM 16
|
||||
#define WN 8
|
||||
#define TCM 8
|
||||
@@ -42,29 +42,12 @@
|
||||
// For correctness, only one of either should be 1. To model the case where
|
||||
// the entire A matrix is already stored transposed in GMEM ("TN" kernel), set
|
||||
// both to 0.
|
||||
#define TRANSPOSE_AT_PRODUCE 0
|
||||
#define TRANSPOSE_AT_PRODUCE 1
|
||||
#define TRANSPOSE_AT_CONSUME 0
|
||||
// GMEM_COALESCED sets bank conflict-free accesses for
|
||||
// 1: GMEM loads of A matrix
|
||||
// 0: SMEM stores of A matrix
|
||||
#define GMEM_COALESCED_A 1
|
||||
#define GEMMINI_DMA 0
|
||||
#if SMEM_SIZE != 0x4000
|
||||
#error Currently only supports 16K spad
|
||||
#endif
|
||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
||||
#define SMEM_ADDR_Q1 ((float * const) 0xff001000)
|
||||
#define SMEM_ADDR_Q2 ((float * const) 0xff002000)
|
||||
#define SMEM_ADDR_Q3 ((float * const) 0xff003000)
|
||||
#define SPAD_ADDR_Q0 0x0
|
||||
#define SPAD_ADDR_Q1 0x80
|
||||
#define SPAD_ADDR_Q2 0x100
|
||||
#define SPAD_ADDR_Q3 0x180
|
||||
|
||||
// FIXME: NUM_THREADS and NUM_WARPS hardcoded
|
||||
#if ((BM * BN / ELEM_PER_THREAD) > (CORES_PER_CLUSTER * 8 * 8))
|
||||
#error "threadblock size too big for cluster"
|
||||
#endif
|
||||
|
||||
inline constexpr void map_operand_32lanes(const int tid, int &row, int &col) {
|
||||
const int tg = tid / 4;
|
||||
|
||||
Reference in New Issue
Block a user