update gemmini dma kernel
This commit is contained in:
@@ -49,6 +49,8 @@
|
|||||||
#define rd_cycles_force(x) asm volatile ("csrr %0, mcycle" : "=r" (x))
|
#define rd_cycles_force(x) asm volatile ("csrr %0, mcycle" : "=r" (x))
|
||||||
#define rd_cycles(x) rd_cycles_force(x)
|
#define rd_cycles(x) rd_cycles_force(x)
|
||||||
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
|
#define HW_TID() ({uint32_t gtid; asm volatile ("csrr %0, mhartid" : "=r" (gtid)); gtid;})
|
||||||
|
#define MARK_BEG() asm volatile ("slti x0, x1, -1047")
|
||||||
|
#define MARK_END() asm volatile ("slti x0, x1, -499")
|
||||||
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
||||||
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
||||||
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
||||||
@@ -98,6 +100,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
|
|
||||||
uint32_t marker0, marker1;
|
uint32_t marker0, marker1;
|
||||||
rd_cycles_force(marker0);
|
rd_cycles_force(marker0);
|
||||||
|
MARK_BEG();
|
||||||
|
|
||||||
const uint32_t dim_m = arg->dim_m;
|
const uint32_t dim_m = arg->dim_m;
|
||||||
const uint32_t dim_n = arg->dim_n;
|
const uint32_t dim_n = arg->dim_n;
|
||||||
@@ -109,15 +112,6 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
|
|
||||||
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
|
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
|
||||||
|
|
||||||
if (HW_TID() == 0) gemmini_fence();
|
|
||||||
threadblock_barrier(3, NUM_WARPS);
|
|
||||||
if (HW_TID() == 0) gemmini_fence();
|
|
||||||
threadblock_barrier(3, NUM_WARPS);
|
|
||||||
if (HW_TID() == 0) gemmini_fence();
|
|
||||||
threadblock_barrier(3, NUM_WARPS);
|
|
||||||
if (HW_TID() == 0) gemmini_fence();
|
|
||||||
threadblock_barrier(3, NUM_WARPS);
|
|
||||||
|
|
||||||
if (HW_TID() == 0) {
|
if (HW_TID() == 0) {
|
||||||
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0);
|
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0);
|
||||||
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1);
|
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1);
|
||||||
@@ -179,6 +173,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
|||||||
// last thread block complete
|
// last thread block complete
|
||||||
if (threadblock_id == NUM_CLUSTERS - 1) {
|
if (threadblock_id == NUM_CLUSTERS - 1) {
|
||||||
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
threadblock_barrier(/*barrier_id=*/0, /*count=*/NUM_WARPS);
|
||||||
|
MARK_END();
|
||||||
rd_cycles_force(marker1);
|
rd_cycles_force(marker1);
|
||||||
if (HW_TID() == 0) {
|
if (HW_TID() == 0) {
|
||||||
#ifdef POWER
|
#ifdef POWER
|
||||||
|
|||||||
Reference in New Issue
Block a user