fp16 kernel
This commit is contained in:
@@ -42,7 +42,10 @@
|
||||
// Debug print helper: forwards its arguments to sprintf, writing the formatted
// text into PRINT_BUF instead of a console.
// NOTE(review): sprintf performs no bounds checking, and if PRINT_BUF is a
// fixed buffer expression each call writes from its start again (earlier
// output is overwritten) -- confirm PRINT_BUF's definition and capacity.
#define PRINTF(...) sprintf(PRINT_BUF, __VA_ARGS__)
|
||||
// #define PRINTF(...) vx_printf(__VA_ARGS__)
|
||||
// Swish activation: evaluates x / (1 + exp(-beta * x)).
// `x` is expanded twice, so do not pass an argument with side effects
// (e.g. `SWISH(b, *p++)`); `beta` is expanded once.
#define SWISH(beta, x) ((x) / (1 + exp(-(beta) * (x))))
|
||||
// Valueless compile-time flag.
// NOTE(review): presumably tested with #ifdef/#if defined elsewhere in this
// file; its use is not visible here -- confirm what it enables.
#define POWER
|
||||
// #define POWER
|
||||
|
||||
// Element type through which the kernel's A, B and C matrix pointers are
// accessed. Currently a 16-bit unsigned integer.
// NOTE(review): presumably carries raw fp16 bit patterns (the values are
// printed as hex, not converted as numbers) -- confirm against the data
// generator.
typedef uint16_t smem_elem_t;
|
||||
// typedef float smem_elem_t;
|
||||
|
||||
inline void threadblock_barrier(unsigned int barrier_id, unsigned int count) {
|
||||
vx_fence();
|
||||
@@ -53,9 +56,9 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t threadblock_id,
|
||||
const uint32_t tid_in_threadblock) {
|
||||
asm volatile ("matmul_start_%=:" :: );
|
||||
const float * const A = (const float * const) arg->addr_a;
|
||||
const float * const B = (const float * const) arg->addr_b;
|
||||
float * const C = (float * const) arg->addr_c;
|
||||
const smem_elem_t * const A = (const smem_elem_t * const) arg->addr_a;
|
||||
const smem_elem_t * const B = (const smem_elem_t * const) arg->addr_b;
|
||||
smem_elem_t * const C = (smem_elem_t * const) arg->addr_c;
|
||||
|
||||
if (HW_TID() == 0) {
|
||||
gemmini_extended_config_ex(WEIGHT_STATIONARY, 0, 0, 1, 0, 0);
|
||||
@@ -80,11 +83,13 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
const uint32_t num_tile_rows_per_tb = num_tiles_m / NUM_CLUSTERS;
|
||||
|
||||
constexpr scale_t MVIN_SCALE_IDENTITY_HEX = 0x3c00;
|
||||
|
||||
if (HW_TID() == 0) {
|
||||
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 0);
|
||||
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY, false, 1);
|
||||
gemmini_extended3_config_ld(dim_k * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 0);
|
||||
gemmini_extended3_config_ld(dim_n * sizeof(elem_t), MVIN_SCALE_IDENTITY_HEX, false, 1);
|
||||
// gemmini_extended3_config_ld(repeating_bias ? 0 : (stride_D * sizeof_D), D_scale_factor, low_D, 2);
|
||||
gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY);
|
||||
gemmini_extended_config_st(dim_n * sizeof(elem_t), 0, MVIN_SCALE_IDENTITY_HEX);
|
||||
// gemmini_extended_config_st(stride_C * sizeof_C, act & 3, scale);
|
||||
}
|
||||
|
||||
@@ -130,7 +135,7 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
|
||||
// // move out to dram
|
||||
// if (HW_TID() == 0) {
|
||||
float * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N;
|
||||
smem_elem_t * const dram_c_tile_start = C + tile_i * TILE_M * dim_n + tile_j * TILE_N;
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, BOUND_INST, k_LOOP_WS_CONFIG_BOUNDS)
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, (uint64_t) dram_c_tile_start, k_LOOP_WS_CONFIG_ADDRS_DC)
|
||||
ROCC_INSTRUCTION_RS1_RS2(XCUSTOM_ACC, 0, dim_n, k_LOOP_WS_CONFIG_STRIDES_DC)
|
||||
@@ -150,7 +155,8 @@ void thread_block_matmul_gemmini(kernel_arg_t *__UNIFORM__ arg,
|
||||
PRINTF("total cycles: %d\n", marker1 - marker0);
|
||||
for (int i = 0; i < dim_m; i += 8) {
|
||||
for (int j = 0; j < dim_n; j += 8) {
|
||||
PRINTF("%d %d ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
|
||||
// PRINTF("%d %d ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
|
||||
PRINTF("%04x %04x ", (int) (C[i * dim_n + j]), (int) (C[i * dim_n + j + 4]));
|
||||
}
|
||||
PRINTF("\n");
|
||||
}
|
||||
@@ -181,4 +187,4 @@ int main() {
|
||||
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
#endif
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user