fp16 no dma kernel
This commit is contained in:
@@ -6,25 +6,53 @@
|
|||||||
#include "include/gemmini.h"
|
#include "include/gemmini.h"
|
||||||
#include "gemmini_mmio.h"
|
#include "gemmini_mmio.h"
|
||||||
|
|
||||||
|
#define NUM_CLUSTERS 1
|
||||||
|
// #define FP32
|
||||||
|
|
||||||
|
#ifdef FP32
|
||||||
|
// fp32
|
||||||
#define TILE_M 64
|
#define TILE_M 64
|
||||||
#define TILE_N 64
|
#define TILE_N 64
|
||||||
#define TILE_K 64
|
#define TILE_K 64
|
||||||
#define TILE_MN 4096
|
#define TILE_MN 4096
|
||||||
#define TILE_MK 4096
|
#define TILE_MK 4096
|
||||||
#define TILE_NK 4096
|
#define TILE_NK 4096
|
||||||
|
|
||||||
#define NUM_CLUSTERS 1
|
|
||||||
#define NUM_THREADS_IN_CLUSTER 256
|
#define NUM_THREADS_IN_CLUSTER 256
|
||||||
|
|
||||||
#define SMEM_ADDR_Q0 ((float * const) 0xff000000)
|
#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000)
|
||||||
#define SMEM_ADDR_Q1 ((float * const) 0xff004000)
|
#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff004000)
|
||||||
#define SMEM_ADDR_Q2 ((float * const) 0xff008000)
|
#define SMEM_ADDR_Q2 ((mem_elem_t * const) 0xff008000)
|
||||||
#define SMEM_ADDR_Q3 ((float * const) 0xff00c000)
|
#define SMEM_ADDR_Q3 ((mem_elem_t * const) 0xff00c000)
|
||||||
#define SPAD_ADDR_Q0 0x0
|
#define SPAD_ADDR_Q0 0x0
|
||||||
#define SPAD_ADDR_Q1 0x200
|
#define SPAD_ADDR_Q1 0x200
|
||||||
#define SPAD_ADDR_Q2 0x400
|
#define SPAD_ADDR_Q2 0x400
|
||||||
#define SPAD_ADDR_Q3 0x600
|
#define SPAD_ADDR_Q3 0x600
|
||||||
#define SPAD_ADDR_Q4 0x800
|
#define SPAD_ADDR_Q4 0x800
|
||||||
|
typedef float smem_elem_t;
|
||||||
|
typedef float mem_elem_t;
|
||||||
|
|
||||||
|
#else
|
||||||
|
// fp16
|
||||||
|
#define TILE_M 128
|
||||||
|
#define TILE_N 64
|
||||||
|
#define TILE_K 128
|
||||||
|
#define TILE_MN 8192
|
||||||
|
#define TILE_MK 16384
|
||||||
|
#define TILE_NK 8192
|
||||||
|
#define NUM_THREADS_IN_CLUSTER 512
|
||||||
|
|
||||||
|
#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000)
|
||||||
|
#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff008000)
|
||||||
|
#define SMEM_ADDR_Q2 ((mem_elem_t * const) 0xff001000)
|
||||||
|
#define SMEM_ADDR_Q3 ((mem_elem_t * const) 0xff018000)
|
||||||
|
#define SPAD_ADDR_Q0 0x0
|
||||||
|
#define SPAD_ADDR_Q1 0x400
|
||||||
|
#define SPAD_ADDR_Q2 0x800
|
||||||
|
#define SPAD_ADDR_Q3 0xc00
|
||||||
|
#define SPAD_ADDR_Q4 0x1000
|
||||||
|
typedef uint16_t smem_elem_t;
|
||||||
|
typedef uint32_t mem_elem_t;
|
||||||
|
#endif
|
||||||
|
|
||||||
#define HARDCODE
|
#define HARDCODE
|
||||||
#define REGBLOCK
|
#define REGBLOCK
|
||||||
@@ -61,9 +89,9 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
|
|||||||
void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||||
const uint32_t threadblock_id,
|
const uint32_t threadblock_id,
|
||||||
const uint32_t tid_in_threadblock) {
|
const uint32_t tid_in_threadblock) {
|
||||||
const float * const A = (const float * const) arg->addr_a;
|
const smem_elem_t * const A = (const smem_elem_t * const) arg->addr_a;
|
||||||
const float * const B = (const float * const) arg->addr_b;
|
const smem_elem_t * const B = (const smem_elem_t * const) arg->addr_b;
|
||||||
float * const C = (float * const) arg->addr_c;
|
smem_elem_t * const C = (smem_elem_t * const) arg->addr_c;
|
||||||
|
|
||||||
if (tid_in_threadblock % NUM_THREADS_IN_CLUSTER == 0) {
|
if (tid_in_threadblock % NUM_THREADS_IN_CLUSTER == 0) {
|
||||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||||
@@ -123,11 +151,11 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
tile_i < num_tile_rows_per_tb * (threadblock_id + 1);
|
tile_i < num_tile_rows_per_tb * (threadblock_id + 1);
|
||||||
tile_i += 1) {
|
tile_i += 1) {
|
||||||
for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
|
for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) {
|
||||||
float * const smem_c_tile_start = SMEM_ADDR_Q1;
|
mem_elem_t * const smem_c_tile_start = SMEM_ADDR_Q1;
|
||||||
#ifdef OFFLOAD_ACCUMULATE
|
#ifdef OFFLOAD_ACCUMULATE
|
||||||
float * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID();
|
mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID();
|
||||||
#else
|
#else
|
||||||
float * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid;
|
mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
|
for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) {
|
||||||
@@ -140,19 +168,19 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
constexpr uint32_t every_iter = j1_stride;
|
constexpr uint32_t every_iter = j1_stride;
|
||||||
const uint32_t every_2iters_a = i1_stride * dim_k;
|
const uint32_t every_2iters_a = i1_stride * (dim_k * sizeof(smem_elem_t) / 4);
|
||||||
const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0;
|
const uint32_t runtime_const_a = i0 * (dim_k * sizeof(smem_elem_t) / 4) + j1_idx + j0;
|
||||||
const uint32_t every_2iters_b = i1_stride * dim_n;
|
const uint32_t every_2iters_b = i1_stride * dim_n;
|
||||||
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0;
|
||||||
|
|
||||||
const float * const dram_a_tile_start = A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a;
|
const mem_elem_t * const dram_a_tile_start = (const mem_elem_t * const) (A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a);
|
||||||
const float * const dram_b_tile_start = B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b;
|
const mem_elem_t * const dram_b_tile_start = (const mem_elem_t * const) (B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b);
|
||||||
#ifdef DBUF
|
#ifdef DBUF
|
||||||
float * const smem_a_tile_start = ((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID();
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID());
|
||||||
float * const smem_b_tile_start = ((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID();
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID());
|
||||||
#else
|
#else
|
||||||
float * const smem_a_tile_start = SMEM_ADDR_Q0 + HW_TID();
|
mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID());
|
||||||
float * const smem_b_tile_start = SMEM_ADDR_Q3 + HW_TID();
|
mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -191,10 +219,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
smem_b_tile_start[7 * num_threads_in_cluster + hw_tid] = \
|
smem_b_tile_start[7 * num_threads_in_cluster + hw_tid] = \
|
||||||
dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
|
dram_b_tile_start[every_iter * 1 + every_2iters_b * 3];
|
||||||
#else
|
#else
|
||||||
float v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
|
mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0];
|
||||||
float v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
|
mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0];
|
||||||
float v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1];
|
mem_elem_t v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1];
|
||||||
float v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1];
|
mem_elem_t v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1];
|
||||||
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
|
smem_a_tile_start[0 * num_threads_in_cluster] = v0;
|
||||||
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
|
smem_a_tile_start[1 * num_threads_in_cluster] = v1;
|
||||||
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
|
smem_a_tile_start[2 * num_threads_in_cluster] = v2;
|
||||||
@@ -236,14 +264,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
|
smem_a_tile_start[10 * num_threads_in_cluster] = v2;
|
||||||
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
|
smem_a_tile_start[11 * num_threads_in_cluster] = v3;
|
||||||
|
|
||||||
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
|
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4];
|
||||||
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
|
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4];
|
||||||
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
|
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5];
|
||||||
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
|
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5];
|
||||||
smem_b_tile_start[8 * num_threads_in_cluster] = v0;
|
// smem_b_tile_start[8 * num_threads_in_cluster] = v0;
|
||||||
smem_b_tile_start[9 * num_threads_in_cluster] = v1;
|
// smem_b_tile_start[9 * num_threads_in_cluster] = v1;
|
||||||
smem_b_tile_start[10 * num_threads_in_cluster] = v2;
|
// smem_b_tile_start[10 * num_threads_in_cluster] = v2;
|
||||||
smem_b_tile_start[11 * num_threads_in_cluster] = v3;
|
// smem_b_tile_start[11 * num_threads_in_cluster] = v3;
|
||||||
|
|
||||||
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
|
v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6];
|
||||||
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
|
v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6];
|
||||||
@@ -254,14 +282,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
|
smem_a_tile_start[14 * num_threads_in_cluster] = v2;
|
||||||
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
|
smem_a_tile_start[15 * num_threads_in_cluster] = v3;
|
||||||
|
|
||||||
v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
|
// v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6];
|
||||||
v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
|
// v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6];
|
||||||
v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
|
// v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7];
|
||||||
v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
|
// v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7];
|
||||||
smem_b_tile_start[12 * num_threads_in_cluster] = v0;
|
// smem_b_tile_start[12 * num_threads_in_cluster] = v0;
|
||||||
smem_b_tile_start[13 * num_threads_in_cluster] = v1;
|
// smem_b_tile_start[13 * num_threads_in_cluster] = v1;
|
||||||
smem_b_tile_start[14 * num_threads_in_cluster] = v2;
|
// smem_b_tile_start[14 * num_threads_in_cluster] = v2;
|
||||||
smem_b_tile_start[15 * num_threads_in_cluster] = v3;
|
// smem_b_tile_start[15 * num_threads_in_cluster] = v3;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
@@ -440,8 +468,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
#ifdef CISC
|
#ifdef CISC
|
||||||
GEMMINI_CISC_CMD_I(9);
|
GEMMINI_CISC_CMD_I(9);
|
||||||
#else
|
#else
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_M / DIM) << 32) |
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_K / DIM) << 32) |
|
||||||
(((uint64_t) TILE_K / DIM) << 16) | ((uint64_t) TILE_N / DIM), k_LOOP_WS_CONFIG_BOUNDS)
|
(((uint64_t) TILE_N / DIM) << 16) | ((uint64_t) TILE_M / DIM), k_LOOP_WS_CONFIG_BOUNDS)
|
||||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS)
|
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS)
|
||||||
#endif
|
#endif
|
||||||
gemmini_fence();
|
gemmini_fence();
|
||||||
@@ -458,13 +486,13 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
constexpr uint32_t every_iter = j1_stride;
|
constexpr uint32_t every_iter = j1_stride;
|
||||||
const uint32_t every_2iters = i1_stride * dim_n;
|
const uint32_t every_2iters = i1_stride * dim_n;
|
||||||
const uint32_t runtime_const = i0 * dim_n + j1_idx + j0;
|
const uint32_t runtime_const = i0 * dim_n + j1_idx + j0;
|
||||||
float * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const;
|
mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const);
|
||||||
|
|
||||||
#ifdef REGBLOCK
|
#ifdef REGBLOCK
|
||||||
float v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
|
mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster];
|
||||||
float v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
|
mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster];
|
||||||
float v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
|
mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster];
|
||||||
float v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
|
mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster];
|
||||||
#ifdef ACTIVATE
|
#ifdef ACTIVATE
|
||||||
uint32_t swish_start, swish_end;
|
uint32_t swish_start, swish_end;
|
||||||
rd_cycles_force(swish_start);
|
rd_cycles_force(swish_start);
|
||||||
@@ -498,23 +526,23 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2;
|
dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2;
|
||||||
dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3;
|
dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3;
|
||||||
|
|
||||||
v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
|
// v0 = smem_acc_tile_start[8 * num_threads_in_cluster];
|
||||||
v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
|
// v1 = smem_acc_tile_start[9 * num_threads_in_cluster];
|
||||||
v2 = smem_acc_tile_start[10 * num_threads_in_cluster];
|
// v2 = smem_acc_tile_start[10 * num_threads_in_cluster];
|
||||||
v3 = smem_acc_tile_start[11 * num_threads_in_cluster];
|
// v3 = smem_acc_tile_start[11 * num_threads_in_cluster];
|
||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0;
|
// dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0;
|
||||||
dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1;
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1;
|
||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2;
|
// dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2;
|
||||||
dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3;
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3;
|
||||||
|
|
||||||
v0 = smem_acc_tile_start[12 * num_threads_in_cluster];
|
// v0 = smem_acc_tile_start[12 * num_threads_in_cluster];
|
||||||
v1 = smem_acc_tile_start[13 * num_threads_in_cluster];
|
// v1 = smem_acc_tile_start[13 * num_threads_in_cluster];
|
||||||
v2 = smem_acc_tile_start[14 * num_threads_in_cluster];
|
// v2 = smem_acc_tile_start[14 * num_threads_in_cluster];
|
||||||
v3 = smem_acc_tile_start[15 * num_threads_in_cluster];
|
// v3 = smem_acc_tile_start[15 * num_threads_in_cluster];
|
||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0;
|
// dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0;
|
||||||
dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1;
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1;
|
||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2;
|
// dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2;
|
||||||
dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3;
|
// dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3;
|
||||||
#else
|
#else
|
||||||
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \
|
dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \
|
||||||
smem_acc_tile_start[0 * num_threads_in_cluster];
|
smem_acc_tile_start[0 * num_threads_in_cluster];
|
||||||
|
|||||||
Reference in New Issue
Block a user