fp16 kernel

This commit is contained in:
Richard Yan
2024-08-06 02:43:44 -07:00
parent ea4819702e
commit 4fddca3d1a
2 changed files with 51 additions and 10 deletions

View File

@@ -42,7 +42,10 @@
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
// #define PRINTF(...) vx_printf(__VA_ARGS__)
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
#define POWER
// #define POWER
typedef uint16_t smem_elem_t;
// typedef float smem_elem_t;
inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
vx_fence();
@@ -53,9 +56,9 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
const uint32_t threadblock_id,
const uint32_t tid_in_threadblock) {
asm volatile ("matmul_start_%=:" :: );
const float * const A = (const float * const) arg->addr_a;
const float * const B = (const float * const) arg->addr_b;
float * const C = (float * const) arg->addr_c;
const smem_elem_t * const A = (const smem_elem_t * const) arg->addr_a;
const smem_elem_t * const B = (const smem_elem_t * const) arg->addr_b;
smem_elem_t * const C = (smem_elem_t * const) arg->addr_c;
if (HW_TID() == 0) {
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
@@ -80,11 +83,13 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
constexpr scale_t MVIN_SCALE_IDENTITY_HEX = 0x3c00;
if (HW_TID() == 0) {
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0);
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1);
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 0);
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 1);
// gemmini_extended3_config_ld(repeating_bias ? 0 : (stride_D * sizeof_D), D_scale_factor, low_D, 2);
gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY);
gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY_HEX);
// gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale);
}
@@ -130,7 +135,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
// // move out to dram
// if (HW_TID() == 0) {
float * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N;
smem_elem_t * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N;
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST, k_LOOP_WS_CONFIG_BOUNDS)
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (uint64_t) dram_c_tile_start, k_LOOP_WS_CONFIG_ADDRS_DC)
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, dim_n, k_LOOP_WS_CONFIG_STRIDES_DC)
@@ -150,7 +155,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
PRINTF("total cycles: %d\n", marker1 - marker0);
for (int i = 0; i < dim_m; i += 8) {
for (int j = 0; j < dim_n; j += 8) {
PRINTF("%d %d ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
// PRINTF("%d %d ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
PRINTF("%04x %04x ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
}
PRINTF("\n");
}
@@ -181,4 +187,4 @@ int main() {
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
#endif
return 0;
}
}