diff --git a/tests/regression/sgemm_gemmini/kernel.cpp b/tests/regression/sgemm_gemmini/kernel.cpp index c609eec1..28f65d7c 100644 --- a/tests/regression/sgemm_gemmini/kernel.cpp +++ b/tests/regression/sgemm_gemmini/kernel.cpp @@ -6,25 +6,53 @@ #include "include/gemmini.h" #include "gemmini_mmio.h" +#define NUM_CLUSTERS 1 +// #define FP32 + +#ifdef FP32 +// fp32 #define TILE_M 64 #define TILE_N 64 #define TILE_K 64 #define TILE_MN 4096 #define TILE_MK 4096 #define TILE_NK 4096 - -#define NUM_CLUSTERS 1 #define NUM_THREADS_IN_CLUSTER 256 -#define SMEM_ADDR_Q0 ((float * const) 0xff000000) -#define SMEM_ADDR_Q1 ((float * const) 0xff004000) -#define SMEM_ADDR_Q2 ((float * const) 0xff008000) -#define SMEM_ADDR_Q3 ((float * const) 0xff00c000) +#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000) +#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff004000) +#define SMEM_ADDR_Q2 ((mem_elem_t * const) 0xff008000) +#define SMEM_ADDR_Q3 ((mem_elem_t * const) 0xff00c000) #define SPAD_ADDR_Q0 0x0 #define SPAD_ADDR_Q1 0x200 #define SPAD_ADDR_Q2 0x400 #define SPAD_ADDR_Q3 0x600 #define SPAD_ADDR_Q4 0x800 +typedef float smem_elem_t; +typedef float mem_elem_t; + +#else +// fp16 +#define TILE_M 128 +#define TILE_N 64 +#define TILE_K 128 +#define TILE_MN 8192 +#define TILE_MK 16384 +#define TILE_NK 8192 +#define NUM_THREADS_IN_CLUSTER 512 + +#define SMEM_ADDR_Q0 ((mem_elem_t * const) 0xff000000) +#define SMEM_ADDR_Q1 ((mem_elem_t * const) 0xff008000) +#define SMEM_ADDR_Q2 ((mem_elem_t * const) 0xff001000) +#define SMEM_ADDR_Q3 ((mem_elem_t * const) 0xff018000) +#define SPAD_ADDR_Q0 0x0 +#define SPAD_ADDR_Q1 0x400 +#define SPAD_ADDR_Q2 0x800 +#define SPAD_ADDR_Q3 0xc00 +#define SPAD_ADDR_Q4 0x1000 +typedef uint16_t smem_elem_t; +typedef uint32_t mem_elem_t; +#endif #define HARDCODE #define REGBLOCK @@ -61,9 +89,9 @@ inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) { void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, const uint32_t threadblock_id, const uint32_t tid_in_threadblock) { - const float * const A = (const float * const) arg->addr_a; - const float * const B = (const float * const) arg->addr_b; - float * const C = (float * const) arg->addr_c; + const smem_elem_t * const A = (const smem_elem_t * const) arg->addr_a; + const smem_elem_t * const B = (const smem_elem_t * const) arg->addr_b; + smem_elem_t * const C = (smem_elem_t * const) arg->addr_c; if (tid_in_threadblock % NUM_THREADS_IN_CLUSTER == 0) { gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0); @@ -123,11 +151,11 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, tile_i < num_tile_rows_per_tb * (threadblock_id + 1); tile_i += 1) { for (int tile_j = 0; tile_j < num_tiles_n; tile_j += 1) { - float * const smem_c_tile_start = SMEM_ADDR_Q1; + mem_elem_t * const smem_c_tile_start = SMEM_ADDR_Q1; #ifdef OFFLOAD_ACCUMULATE - float * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID(); + mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q0 + HW_TID(); #else - float * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid; + mem_elem_t * const smem_acc_tile_start = SMEM_ADDR_Q2 + hw_tid; #endif for (int tile_k = 0; tile_k < num_tiles_k; tile_k += 1) { @@ -140,19 +168,19 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, // #endif constexpr uint32_t every_iter = j1_stride; - const uint32_t every_2iters_a = i1_stride * dim_k; - const uint32_t runtime_const_a = i0 * dim_k + j1_idx + j0; + const uint32_t every_2iters_a = i1_stride * (dim_k * sizeof(smem_elem_t) / 4); + const uint32_t runtime_const_a = i0 * (dim_k * sizeof(smem_elem_t) / 4) + j1_idx + j0; const uint32_t every_2iters_b = i1_stride * dim_n; const uint32_t runtime_const_b = i0 * dim_n + j1_idx + j0; - const float * const dram_a_tile_start = A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a; - const float * const dram_b_tile_start = B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b; + const mem_elem_t * const dram_a_tile_start = (const mem_elem_t * const) (A + tile_i * TILE_M * dim_k + tile_k * TILE_K + runtime_const_a); + const mem_elem_t * const dram_b_tile_start = (const mem_elem_t * const) (B + tile_k * TILE_K * dim_n + tile_j * TILE_N + runtime_const_b); #ifdef DBUF - float * const smem_a_tile_start = ((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID(); - float * const smem_b_tile_start = ((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID(); + mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q1 : SMEM_ADDR_Q0) + HW_TID()); + mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (((tile_k & 1) ? SMEM_ADDR_Q3 : SMEM_ADDR_Q2) + HW_TID()); #else - float * const smem_a_tile_start = SMEM_ADDR_Q0 + HW_TID(); - float * const smem_b_tile_start = SMEM_ADDR_Q3 + HW_TID(); + mem_elem_t * const smem_a_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q0 + HW_TID()); + mem_elem_t * const smem_b_tile_start = (mem_elem_t * const) (SMEM_ADDR_Q3 + HW_TID()); #endif { @@ -191,10 +219,10 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, smem_b_tile_start[7 * num_threads_in_cluster + hw_tid] = \ dram_b_tile_start[every_iter * 1 + every_2iters_b * 3]; #else - float v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0]; - float v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0]; - float v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1]; - float v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1]; + mem_elem_t v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 0]; + mem_elem_t v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 0]; + mem_elem_t v2 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 1]; + mem_elem_t v3 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 1]; smem_a_tile_start[0 * num_threads_in_cluster] = v0; smem_a_tile_start[1 * num_threads_in_cluster] = v1; smem_a_tile_start[2 * num_threads_in_cluster] = v2; @@ -236,14 +264,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, smem_a_tile_start[10 * num_threads_in_cluster] = v2; smem_a_tile_start[11 * num_threads_in_cluster] = v3; - v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4]; - v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4]; - v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5]; - v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5]; - smem_b_tile_start[8 * num_threads_in_cluster] = v0; - smem_b_tile_start[9 * num_threads_in_cluster] = v1; - smem_b_tile_start[10 * num_threads_in_cluster] = v2; - smem_b_tile_start[11 * num_threads_in_cluster] = v3; + // v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 4]; + // v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 4]; + // v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 5]; + // v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 5]; + // smem_b_tile_start[8 * num_threads_in_cluster] = v0; + // smem_b_tile_start[9 * num_threads_in_cluster] = v1; + // smem_b_tile_start[10 * num_threads_in_cluster] = v2; + // smem_b_tile_start[11 * num_threads_in_cluster] = v3; v0 = dram_a_tile_start[every_iter * 0 + every_2iters_a * 6]; v1 = dram_a_tile_start[every_iter * 1 + every_2iters_a * 6]; @@ -254,14 +282,14 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, smem_a_tile_start[14 * num_threads_in_cluster] = v2; smem_a_tile_start[15 * num_threads_in_cluster] = v3; - v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6]; - v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6]; - v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7]; - v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7]; - smem_b_tile_start[12 * num_threads_in_cluster] = v0; - smem_b_tile_start[13 * num_threads_in_cluster] = v1; - smem_b_tile_start[14 * num_threads_in_cluster] = v2; - smem_b_tile_start[15 * num_threads_in_cluster] = v3; + // v0 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 6]; + // v1 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 6]; + // v2 = dram_b_tile_start[every_iter * 0 + every_2iters_b * 7]; + // v3 = dram_b_tile_start[every_iter * 1 + every_2iters_b * 7]; + // smem_b_tile_start[12 * num_threads_in_cluster] = v0; + // smem_b_tile_start[13 * num_threads_in_cluster] = v1; + // smem_b_tile_start[14 * num_threads_in_cluster] = v2; + // smem_b_tile_start[15 * num_threads_in_cluster] = v3; #endif } #else @@ -440,8 +468,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, #ifdef CISC GEMMINI_CISC_CMD_I(9); #else - ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_M / DIM) << 32) | - (((uint64_t) TILE_K / DIM) << 16) | ((uint64_t) TILE_N / DIM), k_LOOP_WS_CONFIG_BOUNDS) + ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (((uint64_t) TILE_K / DIM) << 32) | + (((uint64_t) TILE_N / DIM) << 16) | ((uint64_t) TILE_M / DIM), k_LOOP_WS_CONFIG_BOUNDS) ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, 0x278U, k_LOOP_WS) #endif gemmini_fence(); @@ -458,13 +486,13 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, constexpr uint32_t every_iter = j1_stride; const uint32_t every_2iters = i1_stride * dim_n; const uint32_t runtime_const = i0 * dim_n + j1_idx + j0; - float * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const; + mem_elem_t * const dram_c_tile_start = (mem_elem_t * const) (C + tile_i * TILE_M * dim_n + tile_j * TILE_N + runtime_const); #ifdef REGBLOCK - float v0 = smem_acc_tile_start[0 * num_threads_in_cluster]; - float v1 = smem_acc_tile_start[1 * num_threads_in_cluster]; - float v2 = smem_acc_tile_start[2 * num_threads_in_cluster]; - float v3 = smem_acc_tile_start[3 * num_threads_in_cluster]; + mem_elem_t v0 = smem_acc_tile_start[0 * num_threads_in_cluster]; + mem_elem_t v1 = smem_acc_tile_start[1 * num_threads_in_cluster]; + mem_elem_t v2 = smem_acc_tile_start[2 * num_threads_in_cluster]; + mem_elem_t v3 = smem_acc_tile_start[3 * num_threads_in_cluster]; #ifdef ACTIVATE uint32_t swish_start, swish_end; rd_cycles_force(swish_start); @@ -498,23 +526,23 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg, dram_c_tile_start[every_iter * 0 + every_2iters * 3] = v2; dram_c_tile_start[every_iter * 1 + every_2iters * 3] = v3; - v0 = smem_acc_tile_start[8 * num_threads_in_cluster]; - v1 = smem_acc_tile_start[9 * num_threads_in_cluster]; - v2 = smem_acc_tile_start[10 * num_threads_in_cluster]; - v3 = smem_acc_tile_start[11 * num_threads_in_cluster]; - dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0; - dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1; - dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2; - dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3; + // v0 = smem_acc_tile_start[8 * num_threads_in_cluster]; + // v1 = smem_acc_tile_start[9 * num_threads_in_cluster]; + // v2 = smem_acc_tile_start[10 * num_threads_in_cluster]; + // v3 = smem_acc_tile_start[11 * num_threads_in_cluster]; + // dram_c_tile_start[every_iter * 0 + every_2iters * 4] = v0; + // dram_c_tile_start[every_iter * 1 + every_2iters * 4] = v1; + // dram_c_tile_start[every_iter * 0 + every_2iters * 5] = v2; + // dram_c_tile_start[every_iter * 1 + every_2iters * 5] = v3; - v0 = smem_acc_tile_start[12 * num_threads_in_cluster]; - v1 = smem_acc_tile_start[13 * num_threads_in_cluster]; - v2 = smem_acc_tile_start[14 * num_threads_in_cluster]; - v3 = smem_acc_tile_start[15 * num_threads_in_cluster]; - dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0; - dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1; - dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2; - dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3; + // v0 = smem_acc_tile_start[12 * num_threads_in_cluster]; + // v1 = smem_acc_tile_start[13 * num_threads_in_cluster]; + // v2 = smem_acc_tile_start[14 * num_threads_in_cluster]; + // v3 = smem_acc_tile_start[15 * num_threads_in_cluster]; + // dram_c_tile_start[every_iter * 0 + every_2iters * 6] = v0; + // dram_c_tile_start[every_iter * 1 + every_2iters * 6] = v1; + // dram_c_tile_start[every_iter * 0 + every_2iters * 7] = v2; + // dram_c_tile_start[every_iter * 1 + every_2iters * 7] = v3; #else dram_c_tile_start[every_iter * 0 + every_2iters * 0] = \ smem_acc_tile_start[0 * num_threads_in_cluster];