Merge branch 'tensor_core' into rtl
This commit is contained in:
@@ -23,6 +23,9 @@
|
||||
// #include "verilated_vpi.h"
|
||||
#include "VX_config.h"
|
||||
|
||||
#include <bit>
|
||||
#include "half.hpp"
|
||||
|
||||
extern "C" {
|
||||
void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
|
||||
@@ -51,6 +54,8 @@ extern "C" {
|
||||
void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
|
||||
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile);
|
||||
}
|
||||
|
||||
inline uint64_t nan_box(uint32_t value) {
|
||||
@@ -338,3 +343,74 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
|
||||
*result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags));
|
||||
}
|
||||
}
|
||||
|
||||
// Tile shapes used by dpi_hmma (computes D = A * B + C):
//   A is M * K, B is K * M, C is M * M, D is M * M
// NOTE(review): M and K are very short macro names and are visible to everything
// that follows in this translation unit — watch for collisions.
|
||||
#define M 4
|
||||
#define K 2
|
||||
|
||||
// Scratch copies of the four tiles, converted to float. All are stored row major.
// dpi_hmma fills A/B/C from its inputs on every enabled call and writes D back out.
|
||||
float c_A_tile[M][K];
|
||||
float c_B_tile[K][M];
|
||||
float c_C_tile[M][M];
|
||||
float c_D_tile[M][M];
|
||||
|
||||
// The tile helpers copy one svBitVecVal element per float with memcpy, so the
// code assumes that svBitVecVal is basically a uint32_t (4 bytes wide).
|
||||
static_assert(sizeof(svBitVecVal) == 4);
|
||||
|
||||
// Unpacks a tile of raw 32-bit words into floats.
//
//   sv_tile : source words, one svBitVecVal per element
//   c_tile  : destination floats, same element count and layout as sv_tile
//   rows    : number of rows in the tile
//   cols    : number of columns in the tile
//
// Both buffers are addressed as flat row-major arrays: element (row, col) lives
// at offset row * cols + col. Each word is reinterpreted bit-for-bit as a float
// (memcpy), not numerically converted.
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < cols; ++col) {
            const int flat = row * cols + col;

            // Take a local copy of the word so the bit copy never reads and
            // writes overlapping storage.
            const uint32_t raw_bits = sv_tile[flat];
            memcpy(&c_tile[flat], &raw_bits, sizeof(float));
        }
    }
}
|
||||
|
||||
// Packs a tile of floats back into raw 32-bit words.
//
//   sv_tile : destination words, one svBitVecVal per element
//   c_tile  : source floats (only read here), same element count and layout
//   rows    : number of rows in the tile
//   cols    : number of columns in the tile
//
// Both buffers are addressed as flat row-major arrays: element (row, col) lives
// at offset row * cols + col. Each float's bit pattern is copied unchanged
// (memcpy) into the corresponding word.
void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
    for (int row = 0; row < rows; ++row) {
        for (int col = 0; col < cols; ++col) {
            const int flat = row * cols + col;

            const float value = c_tile[flat];
            memcpy(&sv_tile[flat], &value, sizeof(value));
        }
    }
}
|
||||
|
||||
// DPI entry point: computes the tile product D = A * B + C in float arithmetic.
//
//   enable : when false the call does nothing and D_tile is left untouched
//   A_tile : M x K input tile, one 32-bit word per element, row major
//   B_tile : K x M input tile, same encoding
//   C_tile : M x M accumulator input tile, same encoding
//   D_tile : M x M output tile, overwritten with the result
//
// The inputs are first unpacked into the file-level scratch tiles
// (c_A_tile / c_B_tile / c_C_tile); the result is built in c_D_tile and only
// then packed into D_tile.
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) {
    if (!enable) {
        return;
    }

    // Unpack the three operand tiles into float scratch storage.
    fill_float_array(A_tile, &c_A_tile[0][0], M, K);
    fill_float_array(B_tile, &c_B_tile[0][0], K, M);
    fill_float_array(C_tile, &c_C_tile[0][0], M, M);

    for (int row = 0; row < M; ++row) {
        const float* a_row = c_A_tile[row];
        const float* c_row = c_C_tile[row];
        float* d_row = c_D_tile[row];

        for (int col = 0; col < M; ++col) {
            // Start from the C element, then add the products in increasing
            // inner-index order so the float rounding sequence is fixed.
            float sum = c_row[col];
            for (int inner = 0; inner < K; ++inner) {
                sum += a_row[inner] * c_B_tile[inner][col];
            }
            d_row[col] = sum;
        }
    }

    // Pack the finished result tile into the caller's output buffer.
    write_float_array(D_tile, &c_D_tile[0][0], M, M);
}
|
||||
|
||||
Reference in New Issue
Block a user