Merge branch 'tensor_core' into rtl

This commit is contained in:
Hansung Kim
2024-05-01 16:18:14 -07:00
32 changed files with 6097 additions and 20 deletions

View File

@@ -23,6 +23,9 @@
// #include "verilated_vpi.h"
#include "VX_config.h"
#include <bit>
#include "half.hpp"
extern "C" {
void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
@@ -51,6 +54,8 @@ extern "C" {
void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile);
}
inline uint64_t nan_box(uint32_t value) {
@@ -338,3 +343,74 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
*result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags));
}
}
// Tile shapes used by dpi_hmma, which computes D = A * B + C:
// A is M * K, B is K * M, C is M * M, D is M * M
#define M 4
#define K 2
// all row major
// File-scope scratch copies of the operand and result tiles; dpi_hmma
// overwrites them on every call where enable is true.
float c_A_tile[M][K];
float c_B_tile[K][M];
float c_C_tile[M][M];
float c_D_tile[M][M];
// code assumes that svBitVecVal is basically a uint32_t
// (fill_float_array / write_float_array move exactly one float's worth of
// bytes per svBitVecVal element, so the two sizes must agree)
static_assert(sizeof(svBitVecVal) == 4);
// Unpacks a row-major tile of raw 32-bit words into floats.
//   sv_tile: source words; element (r, c) lives at offset r * cols + c
//   c_tile:  destination buffer, written at the same offsets
//   rows, cols: tile dimensions; nothing is touched if either is <= 0
// Each word is reinterpreted bit-for-bit as a float (no numeric conversion).
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const int flat = r * cols + c;
            // Stage the word in a uint32_t so the byte copy below has a
            // source of exactly sizeof(float) bytes.
            const uint32_t bits = sv_tile[flat];
            float value;
            memcpy(&value, &bits, sizeof(value));
            c_tile[flat] = value;
        }
    }
}
// Packs a row-major tile of floats back into raw 32-bit words.
//   sv_tile: destination words; element (r, c) lives at offset r * cols + c
//   c_tile:  source floats, read at the same offsets
//   rows, cols: tile dimensions; nothing is touched if either is <= 0
// The float's bytes are copied verbatim into the destination word.
void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
    for (int r = 0; r < rows; ++r) {
        for (int c = 0; c < cols; ++c) {
            const int flat = r * cols + c;
            float value = c_tile[flat];
            memcpy(sv_tile + flat, &value, sizeof(value));
        }
    }
}
// DPI entry point modelling one matrix multiply-accumulate: D = A * B + C.
//   enable: when false the call does nothing and D_tile is left untouched
//   A_tile: M x K operand, row major, one 32-bit word per float
//   B_tile: K x M operand, row major, one 32-bit word per float
//   C_tile: M x M addend, row major, one 32-bit word per float
//   D_tile: M x M result, written row major in the same word format
// Inputs are first copied into the file-scope c_*_tile scratch arrays, the
// product is accumulated in single-precision float, and the result is then
// written out from c_D_tile.
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) {
    if (enable) {
        // Stage all three inputs before computing anything.
        fill_float_array(A_tile, c_A_tile[0], M, K);
        fill_float_array(B_tile, c_B_tile[0], K, M);
        fill_float_array(C_tile, c_C_tile[0], M, M);

        for (int row = 0; row < M; ++row) {
            for (int col = 0; col < M; ++col) {
                // Start from the C element, then add the K products in
                // ascending order of the shared dimension.
                float sum = c_C_tile[row][col];
                for (int inner = 0; inner < K; ++inner) {
                    sum += c_A_tile[row][inner] * c_B_tile[inner][col];
                }
                c_D_tile[row][col] = sum;
            }
        }

        write_float_array(D_tile, c_D_tile[0], M, M);
    }
}