Merge branch 'rtl' of https://github.com/hansungk/vortex-private into rtl
This commit is contained in:
@@ -23,6 +23,9 @@
|
||||
// #include "verilated_vpi.h"
|
||||
#include "VX_config.h"
|
||||
|
||||
#include <bit>
|
||||
#include "half.h"
|
||||
|
||||
extern "C" {
|
||||
void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
|
||||
@@ -51,6 +54,9 @@ extern "C" {
|
||||
void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
|
||||
|
||||
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile);
|
||||
void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile);
|
||||
}
|
||||
|
||||
inline uint64_t nan_box(uint32_t value) {
|
||||
@@ -338,3 +344,223 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
|
||||
*result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags));
|
||||
}
|
||||
}
|
||||
|
||||
// A is M * K, B is K * M, C is M * M, D is M * M
|
||||
#define M 4
|
||||
#define K 2
|
||||
|
||||
// all row major
|
||||
float c_A_tile[M][K];
|
||||
float c_B_tile[K][M];
|
||||
float c_C_tile[M][M];
|
||||
float c_D_tile[M][M];
|
||||
|
||||
// code assumes that svBitVecVal is basically a uint32_t
|
||||
static_assert(sizeof(svBitVecVal) == 4);
|
||||
|
||||
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
|
||||
|
||||
for (int i = 0; i < rows; i += 1) {
|
||||
for (int j = 0; j < cols; j += 1) {
|
||||
int index = i * cols + j;
|
||||
svBitVecVal sv_val = sv_tile[index];
|
||||
|
||||
uint32_t c_val = sv_val;
|
||||
float c_float;
|
||||
|
||||
memcpy(&c_float, &c_val, sizeof(c_float));
|
||||
c_tile[index] = c_float;
|
||||
|
||||
// std::cout << c_float << " ";
|
||||
}
|
||||
// std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
|
||||
for (int i = 0; i < rows; i += 1) {
|
||||
for (int j = 0; j < cols; j += 1) {
|
||||
int index = i * cols + j;
|
||||
svBitVecVal* sv_val = &sv_tile[index];
|
||||
|
||||
float c_float = c_tile[index];
|
||||
memcpy(sv_val, &c_float, sizeof(c_float));
|
||||
|
||||
// std::cout << c_float << " ";
|
||||
}
|
||||
// std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) {
|
||||
if (!enable) {
|
||||
return;
|
||||
}
|
||||
// std::cout << "A: " << std::endl;
|
||||
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
|
||||
// std::cout << "B: " << std::endl;
|
||||
fill_float_array(B_tile, &c_B_tile[0][0], K, M);
|
||||
// std::cout << "C: " << std::endl;
|
||||
fill_float_array(C_tile, &c_C_tile[0][0], M, M);
|
||||
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
float accum = c_C_tile[i][j];
|
||||
for (int k = 0; k < K; k += 1) {
|
||||
accum += c_A_tile[i][k] * c_B_tile[k][j];
|
||||
}
|
||||
c_D_tile[i][j] = accum;
|
||||
}
|
||||
}
|
||||
|
||||
write_float_array(D_tile, &c_D_tile[0][0], M, M);
|
||||
}
|
||||
|
||||
// 1 copy per warp
|
||||
float A_tile_full[4][16][8];
|
||||
float B_tile_full[4][8][16];
|
||||
float C_tile_full[4][16][16];
|
||||
float D_tile_full[4][16][16];
|
||||
int steps[4];
|
||||
|
||||
void print_array(float* array, int rows, int cols) {
|
||||
for (int i = 0; i < rows; i += 1) {
|
||||
for (int j = 0; j < cols; j += 1) {
|
||||
std::cout << array[i*cols+j] << " ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile) {
|
||||
// std::cout << "A: " << std::endl;
|
||||
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
|
||||
// std::cout << "B: " << std::endl;
|
||||
fill_float_array(B_tile, &c_B_tile[0][0], K, M);
|
||||
// std::cout << "C: " << std::endl;
|
||||
fill_float_array(C_tile, &c_C_tile[0][0], M, M);
|
||||
// for some reason this still holds onto old value? very strange
|
||||
// std::cout << "D: " << std::endl;
|
||||
fill_float_array(D_tile, &c_D_tile[0][0], M, M);
|
||||
|
||||
int octet_row_offset;
|
||||
int octet_col_offset;
|
||||
switch(octet) {
|
||||
case 0:
|
||||
octet_row_offset = 0;
|
||||
octet_col_offset = 0;
|
||||
break;
|
||||
case 1:
|
||||
octet_row_offset = 8;
|
||||
octet_col_offset = 0;
|
||||
break;
|
||||
case 2:
|
||||
octet_row_offset = 0;
|
||||
octet_col_offset = 8;
|
||||
break;
|
||||
case 3:
|
||||
octet_row_offset = 8;
|
||||
octet_col_offset = 8;
|
||||
break;
|
||||
}
|
||||
|
||||
int step_row_offset;
|
||||
int step_col_offset;
|
||||
int step = (steps[wid] % 16) / 4;
|
||||
int set = (steps[wid] / 16);
|
||||
switch(step) {
|
||||
case 0:
|
||||
step_row_offset = 0;
|
||||
step_col_offset = 0;
|
||||
break;
|
||||
case 1:
|
||||
step_row_offset = 2;
|
||||
step_col_offset = 0;
|
||||
break;
|
||||
case 2:
|
||||
step_row_offset = 0;
|
||||
step_col_offset = 4;
|
||||
break;
|
||||
case 3:
|
||||
step_row_offset = 2;
|
||||
step_col_offset = 4;
|
||||
break;
|
||||
}
|
||||
|
||||
if (steps[0] >= 48) {
|
||||
// std::cout << "octet " << octet << " step " << steps[0] << "\n";
|
||||
// print_array(&c_D_tile[0][0], 4, 4);
|
||||
}
|
||||
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+0] = c_D_tile[0][0];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+1] = c_D_tile[0][1];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+2] = c_D_tile[0][2];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+3] = c_D_tile[0][3];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+0] = c_D_tile[1][0];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+1] = c_D_tile[1][1];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+2] = c_D_tile[1][2];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+3] = c_D_tile[1][3];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+0] = c_D_tile[2][0];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+1] = c_D_tile[2][1];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+2] = c_D_tile[2][2];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+3] = c_D_tile[2][3];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+0] = c_D_tile[3][0];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+1] = c_D_tile[3][1];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+2] = c_D_tile[3][2];
|
||||
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+3] = c_D_tile[3][3];
|
||||
|
||||
if (octet == 0 || octet == 1) {
|
||||
octet_row_offset = octet * 8;
|
||||
if (step == 0) {
|
||||
step_row_offset = 0;
|
||||
}
|
||||
if (step == 1) {
|
||||
step_row_offset = 2;
|
||||
}
|
||||
if (step == 0 || step == 1) {
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+0] = c_A_tile[0][0];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+1] = c_A_tile[0][1];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+0] = c_A_tile[1][0];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+1] = c_A_tile[1][1];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+0] = c_A_tile[2][0];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+1] = c_A_tile[2][1];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+0] = c_A_tile[3][0];
|
||||
A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+1] = c_A_tile[3][1];
|
||||
}
|
||||
}
|
||||
|
||||
if (octet == 0 || octet == 2) {
|
||||
octet_col_offset = octet * 4;
|
||||
if (step == 0) {
|
||||
step_col_offset = 0;
|
||||
}
|
||||
else if (step == 2) {
|
||||
step_col_offset = 4;
|
||||
}
|
||||
if (step == 0 || step == 2) {
|
||||
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+0] = c_B_tile[0][0];
|
||||
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+1] = c_B_tile[0][1];
|
||||
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+2] = c_B_tile[0][2];
|
||||
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+3] = c_B_tile[0][3];
|
||||
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+0] = c_B_tile[1][0];
|
||||
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+1] = c_B_tile[1][1];
|
||||
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+2] = c_B_tile[1][2];
|
||||
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+3] = c_B_tile[1][3];
|
||||
}
|
||||
}
|
||||
|
||||
steps[wid] += 1;
|
||||
if (steps[wid] % 64 == 0) {
|
||||
steps[wid] = 0;
|
||||
std::cout << "warp " << wid << " finished wmma\n";
|
||||
std::cout << "A tile" << "\n";
|
||||
print_array(&A_tile_full[wid][0][0], 16, 8);
|
||||
std::cout << "B tile" << "\n";
|
||||
print_array(&B_tile_full[wid][0][0], 8, 16);
|
||||
// std::cout << "C tile" << "\n";
|
||||
// print_array(&C_tile_full[wid][0][0], 16, 16);
|
||||
std::cout << "D tile" << "\n";
|
||||
print_array(&D_tile_full[wid][0][0], 16, 16);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -44,4 +44,7 @@ import "DPI-C" function void dpi_feq(input logic enable, input int dst_fmt, inpu
|
||||
import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
|
||||
import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
|
||||
|
||||
import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile);
|
||||
import "DPI-C" function void dpi_print_results(input int wid, input int octet, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, input bit[3:0][3:0][31:0] D_tile);
|
||||
|
||||
`endif
|
||||
|
||||
4018
hw/dpi/half.h
Normal file
4018
hw/dpi/half.h
Normal file
File diff suppressed because it is too large
Load Diff
@@ -40,6 +40,10 @@
|
||||
`define EXT_F_ENABLE
|
||||
`endif
|
||||
|
||||
`ifndef EXT_T_DISABLE
|
||||
`define EXT_T_ENABLE
|
||||
`endif
|
||||
|
||||
`ifndef XLEN_32
|
||||
`ifndef XLEN_64
|
||||
`define XLEN_32
|
||||
@@ -309,7 +313,7 @@
|
||||
|
||||
// Size of FPU Request Queue
|
||||
`ifndef FPUQ_SIZE
|
||||
`define FPUQ_SIZE (2 * (`NUM_THREADS / `NUM_FPU_LANES))
|
||||
`define FPUQ_SIZE (8 * (`NUM_THREADS / `NUM_FPU_LANES))
|
||||
`endif
|
||||
|
||||
// FNCP Latency
|
||||
@@ -385,6 +389,11 @@
|
||||
`define LATENCY_FCVT 5
|
||||
`endif
|
||||
|
||||
// Tensor Core Latency
|
||||
`ifndef LATENCY_HMMA
|
||||
`define LATENCY_HMMA 8
|
||||
`endif
|
||||
|
||||
// Icache Configurable Knobs //////////////////////////////////////////////////
|
||||
|
||||
// Cache Enable
|
||||
@@ -613,6 +622,12 @@
|
||||
`define EXT_F_ENABLED 0
|
||||
`endif
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
`define EXT_T_ENABLED 1
|
||||
`else
|
||||
`define EXT_T_ENABLED 0
|
||||
`endif
|
||||
|
||||
`ifdef EXT_M_ENABLE
|
||||
`define EXT_M_ENABLED 1
|
||||
`else
|
||||
|
||||
@@ -58,8 +58,9 @@
|
||||
`define EX_LSU 1
|
||||
`define EX_SFU 2
|
||||
`define EX_FPU (`EX_SFU + `EXT_F_ENABLED)
|
||||
`define EX_TENSOR (`EX_FPU + `EXT_T_ENABLED)
|
||||
|
||||
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED)
|
||||
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED + `EXT_T_ENABLED)
|
||||
`define EX_BITS `CLOG2(`NUM_EX_UNITS)
|
||||
`define EX_WIDTH `UP(`EX_BITS)
|
||||
|
||||
@@ -115,7 +116,7 @@
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
`define INST_OP_BITS 4
|
||||
`define INST_MOD_BITS 3
|
||||
`define INST_MOD_BITS 4
|
||||
`define INST_FMT_BITS 2
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
@@ -140,6 +141,7 @@
|
||||
`define INST_ALU_IS_BR(mod) mod[0]
|
||||
`define INST_ALU_IS_M(mod) mod[1]
|
||||
`define INST_ALU_IS_W(mod) mod[2]
|
||||
`define INST_ALU_IS_RED(mod) mod[3]
|
||||
|
||||
`define INST_BR_EQ 4'b0000
|
||||
`define INST_BR_NE 4'b0010
|
||||
@@ -176,6 +178,17 @@
|
||||
`define INST_M_SIGNED_A(op) (op[1:0] != 1)
|
||||
`define INST_M_IS_REM(op) op[1]
|
||||
|
||||
`define INST_RED_ADD 4'b0000
|
||||
`define INST_RED_ADDU 4'b1000
|
||||
`define INST_RED_MIN 4'b0001
|
||||
`define INST_RED_MINU 4'b1001
|
||||
`define INST_RED_MAX 4'b0010
|
||||
`define INST_RED_MAXU 4'b1010
|
||||
`define INST_RED_AND 4'b0011
|
||||
`define INST_RED_OR 4'b0100
|
||||
`define INST_RED_XOR 4'b0101
|
||||
`define INST_RED_BITS 4
|
||||
|
||||
`define INST_FMT_B 3'b000
|
||||
`define INST_FMT_H 3'b001
|
||||
`define INST_FMT_W 3'b010
|
||||
@@ -241,6 +254,8 @@
|
||||
`define INST_SFU_IS_WCTL(op) (op <= 5)
|
||||
`define INST_SFU_IS_CSR(op) (op >= 6 && op <= 8)
|
||||
|
||||
`define INST_TENSOR_HMMA 4'b0000
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// non-cacheable tag bits
|
||||
|
||||
@@ -14,6 +14,16 @@
|
||||
`ifndef VX_PLATFORM_VH
|
||||
`define VX_PLATFORM_VH
|
||||
|
||||
// synthesis only
|
||||
`ifndef SIMULATION
|
||||
`define SYNTHESIS
|
||||
`define NDEBUG
|
||||
`define DPI_DISABLE
|
||||
`else
|
||||
`define SV_DPI
|
||||
`endif
|
||||
|
||||
`ifdef SYNTHESIS
|
||||
`define GPR_RESET
|
||||
`define LSU_DUP_DISABLE
|
||||
`define ICACHE_DISABLE
|
||||
|
||||
@@ -33,7 +33,7 @@ module VX_alu_unit #(
|
||||
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
|
||||
localparam PID_WIDTH = `UP(PID_BITS);
|
||||
localparam RSP_ARB_DATAW= `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + 1 + NUM_LANES * `XLEN + PID_WIDTH + 1 + 1;
|
||||
localparam RSP_ARB_SIZE = 1 + `EXT_M_ENABLED;
|
||||
localparam RSP_ARB_SIZE = 2 + `EXT_M_ENABLED;
|
||||
localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
|
||||
|
||||
VX_execute_if #(
|
||||
@@ -60,12 +60,13 @@ module VX_alu_unit #(
|
||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||
|
||||
wire is_muldiv_op;
|
||||
wire is_reduce_op;
|
||||
|
||||
VX_execute_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) int_execute_if();
|
||||
|
||||
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op;
|
||||
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op && ~is_reduce_op;
|
||||
assign int_execute_if.data = execute_if[block_idx].data;
|
||||
|
||||
VX_commit_if #(
|
||||
@@ -86,6 +87,31 @@ module VX_alu_unit #(
|
||||
.commit_if (int_commit_if)
|
||||
);
|
||||
|
||||
assign is_reduce_op = `INST_ALU_IS_RED(execute_if[block_idx].data.op_mod);
|
||||
|
||||
VX_execute_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) red_execute_if();
|
||||
|
||||
assign red_execute_if.valid = execute_if[block_idx].valid && is_reduce_op;
|
||||
assign red_execute_if.data = execute_if[block_idx].data;
|
||||
|
||||
VX_commit_if #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) red_commit_if();
|
||||
|
||||
`RESET_RELAY(red_reset, reset);
|
||||
|
||||
VX_reduce_unit #(
|
||||
.CORE_ID(CORE_ID),
|
||||
.NUM_LANES(NUM_LANES)
|
||||
) reduce_unit (
|
||||
.clk(clk),
|
||||
.reset(red_reset),
|
||||
.execute_if(red_execute_if),
|
||||
.commit_if(red_commit_if)
|
||||
);
|
||||
|
||||
`ifdef EXT_M_ENABLE
|
||||
|
||||
assign is_muldiv_op = `INST_ALU_IS_M(execute_if[block_idx].data.op_mod);
|
||||
@@ -96,7 +122,7 @@ module VX_alu_unit #(
|
||||
.NUM_LANES (NUM_LANES)
|
||||
) mdv_execute_if();
|
||||
|
||||
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op;
|
||||
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op && ~is_reduce_op;
|
||||
assign mdv_execute_if.data = execute_if[block_idx].data;
|
||||
|
||||
VX_commit_if #(
|
||||
@@ -113,12 +139,12 @@ module VX_alu_unit #(
|
||||
.commit_if (mdv_commit_if)
|
||||
);
|
||||
|
||||
assign execute_if[block_idx].ready = is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready;
|
||||
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : (is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready);
|
||||
|
||||
`else
|
||||
|
||||
assign is_muldiv_op = 0;
|
||||
assign execute_if[block_idx].ready = int_execute_if.ready;
|
||||
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : int_execute_if.ready;
|
||||
|
||||
`endif
|
||||
|
||||
@@ -135,19 +161,22 @@ module VX_alu_unit #(
|
||||
`ifdef EXT_M_ENABLE
|
||||
mdv_commit_if.valid,
|
||||
`endif
|
||||
int_commit_if.valid
|
||||
int_commit_if.valid,
|
||||
red_commit_if.valid
|
||||
}),
|
||||
.ready_in ({
|
||||
`ifdef EXT_M_ENABLE
|
||||
mdv_commit_if.ready,
|
||||
`endif
|
||||
int_commit_if.ready
|
||||
int_commit_if.ready,
|
||||
red_commit_if.ready
|
||||
}),
|
||||
.data_in ({
|
||||
`ifdef EXT_M_ENABLE
|
||||
mdv_commit_if.data,
|
||||
`endif
|
||||
int_commit_if.data
|
||||
int_commit_if.data,
|
||||
red_commit_if.data
|
||||
}),
|
||||
.data_out (commit_block_if[block_idx].data),
|
||||
.valid_out (commit_block_if[block_idx].valid),
|
||||
|
||||
@@ -28,6 +28,10 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
`endif
|
||||
VX_commit_if.slave sfu_commit_if [`ISSUE_WIDTH],
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_commit_if.slave tensor_commit_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
|
||||
// outputs
|
||||
VX_writeback_if.master writeback_if [`ISSUE_WIDTH],
|
||||
VX_commit_csr_if.master commit_csr_if,
|
||||
@@ -49,6 +53,8 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid;
|
||||
wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask;
|
||||
wire [`ISSUE_WIDTH-1:0] commit_eop;
|
||||
wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel;
|
||||
`UNUSED_VAR (commit_sel)
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
|
||||
@@ -66,6 +72,9 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
sfu_commit_if[i].valid,
|
||||
`ifdef EXT_F_ENABLE
|
||||
fpu_commit_if[i].valid,
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
tensor_commit_if[i].valid,
|
||||
`endif
|
||||
alu_commit_if[i].valid,
|
||||
lsu_commit_if[i].valid
|
||||
@@ -74,6 +83,9 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
sfu_commit_if[i].ready,
|
||||
`ifdef EXT_F_ENABLE
|
||||
fpu_commit_if[i].ready,
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
tensor_commit_if[i].ready,
|
||||
`endif
|
||||
alu_commit_if[i].ready,
|
||||
lsu_commit_if[i].ready
|
||||
@@ -82,6 +94,9 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
sfu_commit_if[i].data,
|
||||
`ifdef EXT_F_ENABLE
|
||||
fpu_commit_if[i].data,
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
tensor_commit_if[i].data,
|
||||
`endif
|
||||
alu_commit_if[i].data,
|
||||
lsu_commit_if[i].data
|
||||
@@ -89,7 +104,7 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
.data_out (commit_if[i].data),
|
||||
.valid_out (commit_if[i].valid),
|
||||
.ready_out (commit_if[i].ready),
|
||||
`UNUSED_PIN (sel_out)
|
||||
.sel_out (commit_sel[i])
|
||||
);
|
||||
|
||||
assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready;
|
||||
@@ -158,7 +173,36 @@ module VX_commit import VX_gpu_pkg::*; #(
|
||||
|
||||
// Committed instructions
|
||||
|
||||
wire [`ISSUE_WIDTH-1:0] committed = commit_fire & commit_eop;
|
||||
// temporary hack to not underflow the pending instructions buffer
|
||||
// relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions,
|
||||
// so probably want to change this at some point
|
||||
// (i.e. pass a "don't count this towards pending instructions" signal down the pipeline)
|
||||
// logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n;
|
||||
wire [`ISSUE_WIDTH-1:0] final_hmma;
|
||||
`ifdef EXT_T_ENABLE
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
// assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i];
|
||||
// assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0);
|
||||
// i suppose this is now a feature and not a bug
|
||||
// if PC is 0, this means it is not final step of a wmma, shouldn't be committed
|
||||
assign final_hmma[i] = (commit_if[i].data.PC != 32'b0);
|
||||
end
|
||||
/*
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
hmma_ctr <= '0;
|
||||
end
|
||||
else begin
|
||||
hmma_ctr <= hmma_ctr_n;
|
||||
end
|
||||
end
|
||||
*/
|
||||
`else
|
||||
assign final_hmma = '1;
|
||||
`endif
|
||||
|
||||
|
||||
wire [`ISSUE_WIDTH-1:0] committed = (commit_fire & commit_eop) & final_hmma;
|
||||
|
||||
VX_pipe_register #(
|
||||
.DATAW (`ISSUE_WIDTH * (1 + `NW_WIDTH)),
|
||||
|
||||
@@ -71,6 +71,10 @@ module VX_core import VX_gpu_pkg::*; #(
|
||||
`ifdef EXT_F_ENABLE
|
||||
VX_dispatch_if fpu_dispatch_if[`ISSUE_WIDTH]();
|
||||
VX_commit_if fpu_commit_if[`ISSUE_WIDTH]();
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_dispatch_if tensor_dispatch_if[`ISSUE_WIDTH]();
|
||||
VX_commit_if tensor_commit_if[`ISSUE_WIDTH]();
|
||||
`endif
|
||||
VX_dispatch_if sfu_dispatch_if[`ISSUE_WIDTH]();
|
||||
VX_commit_if sfu_commit_if[`ISSUE_WIDTH]();
|
||||
@@ -178,6 +182,9 @@ module VX_core import VX_gpu_pkg::*; #(
|
||||
.lsu_dispatch_if(lsu_dispatch_if),
|
||||
`ifdef EXT_F_ENABLE
|
||||
.fpu_dispatch_if(fpu_dispatch_if),
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
.tensor_dispatch_if(tensor_dispatch_if),
|
||||
`endif
|
||||
.sfu_dispatch_if(sfu_dispatch_if)
|
||||
);
|
||||
@@ -203,6 +210,10 @@ module VX_core import VX_gpu_pkg::*; #(
|
||||
.fpu_dispatch_if(fpu_dispatch_if),
|
||||
.fpu_commit_if (fpu_commit_if),
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
.tensor_dispatch_if (tensor_dispatch_if),
|
||||
.tensor_commit_if (tensor_commit_if),
|
||||
`endif
|
||||
|
||||
.commit_csr_if (commit_csr_if),
|
||||
.sched_csr_if (sched_csr_if),
|
||||
@@ -237,6 +248,9 @@ module VX_core import VX_gpu_pkg::*; #(
|
||||
.fpu_commit_if (fpu_commit_if),
|
||||
`endif
|
||||
.sfu_commit_if (sfu_commit_if),
|
||||
`ifdef EXT_T_ENABLE
|
||||
.tensor_commit_if (tensor_commit_if),
|
||||
`endif
|
||||
|
||||
.writeback_if (writeback_if),
|
||||
|
||||
|
||||
@@ -513,6 +513,40 @@ module VX_decode #(
|
||||
default:;
|
||||
endcase
|
||||
end
|
||||
`INST_EXT3: begin
|
||||
ex_type = `EX_ALU;
|
||||
op_mod[3] = 1;
|
||||
`USED_IREG(rs1);
|
||||
`USED_IREG(rd);
|
||||
|
||||
case (func7[5:0])
|
||||
6'h0: begin
|
||||
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
|
||||
end
|
||||
6'h1: begin
|
||||
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
|
||||
end
|
||||
6'h2: begin
|
||||
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
|
||||
end
|
||||
6'h3: begin
|
||||
op_type = `INST_RED_AND;
|
||||
end
|
||||
6'h4: begin
|
||||
op_type = `INST_RED_OR;
|
||||
end
|
||||
6'h5: begin
|
||||
op_type = `INST_RED_XOR;
|
||||
end
|
||||
default:;
|
||||
endcase
|
||||
end
|
||||
`ifdef EXT_T_ENABLE
|
||||
`INST_EXT4: begin
|
||||
ex_type = `EX_TENSOR;
|
||||
op_type = `INST_TENSOR_HMMA;
|
||||
end
|
||||
`endif
|
||||
default:;
|
||||
endcase
|
||||
end
|
||||
|
||||
@@ -34,6 +34,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
|
||||
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
|
||||
`ifdef EXT_F_ENABLE
|
||||
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
|
||||
);
|
||||
@@ -142,6 +145,35 @@ module VX_dispatch import VX_gpu_pkg::*; #(
|
||||
end
|
||||
`endif
|
||||
|
||||
// Tensor Core dispatch
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
|
||||
VX_operands_if tensor_operands_if[`ISSUE_WIDTH]();
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
assign tensor_operands_if[i].valid = operands_if[i].valid && (operands_if[i].data.ex_type == `EX_TENSOR);
|
||||
assign tensor_operands_if[i].data = operands_if[i].data;
|
||||
|
||||
`RESET_RELAY (tensor_reset, reset);
|
||||
|
||||
VX_elastic_buffer #(
|
||||
.DATAW (DATAW),
|
||||
.SIZE (2),
|
||||
.OUT_REG (2)
|
||||
) tensor_buffer (
|
||||
.clk (clk),
|
||||
.reset (tensor_reset),
|
||||
.valid_in (tensor_operands_if[i].valid),
|
||||
.ready_in (tensor_operands_if[i].ready),
|
||||
.data_in (`TO_DISPATCH_DATA(tensor_operands_if[i].data, last_active_tid[i])),
|
||||
.data_out (tensor_dispatch_if[i].data),
|
||||
.valid_out (tensor_dispatch_if[i].valid),
|
||||
.ready_out (tensor_dispatch_if[i].ready)
|
||||
);
|
||||
end
|
||||
`endif
|
||||
|
||||
// SFU dispatch
|
||||
|
||||
VX_operands_if sfu_operands_if[`ISSUE_WIDTH]();
|
||||
@@ -174,6 +206,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
|
||||
|| (lsu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_LSU))
|
||||
`ifdef EXT_F_ENABLE
|
||||
|| (fpu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_FPU))
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
|| (tensor_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_TENSOR))
|
||||
`endif
|
||||
|| (sfu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_SFU));
|
||||
end
|
||||
|
||||
@@ -41,7 +41,7 @@ module VX_execute import VX_gpu_pkg::*; #(
|
||||
VX_dispatch_if.slave fpu_dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master fpu_commit_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
|
||||
|
||||
VX_dispatch_if.slave alu_dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master alu_commit_if [`ISSUE_WIDTH],
|
||||
VX_branch_ctl_if.master branch_ctl_if [`NUM_ALU_BLOCKS],
|
||||
@@ -53,6 +53,11 @@ module VX_execute import VX_gpu_pkg::*; #(
|
||||
VX_commit_if.master sfu_commit_if [`ISSUE_WIDTH],
|
||||
VX_warp_ctl_if.master warp_ctl_if,
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_dispatch_if.slave tensor_dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master tensor_commit_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
|
||||
// simulation helper signals
|
||||
output wire sim_ebreak,
|
||||
|
||||
@@ -135,6 +140,18 @@ module VX_execute import VX_gpu_pkg::*; #(
|
||||
.acc_write_en (acc_write_en)
|
||||
);
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_tensor_core #(
|
||||
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.dispatch_if(tensor_dispatch_if),
|
||||
.commit_if(tensor_commit_if)
|
||||
);
|
||||
`endif
|
||||
|
||||
// simulation helper signal to get RISC-V tests Pass/Fail status
|
||||
assign sim_ebreak = alu_dispatch_if[0].valid && alu_dispatch_if[0].ready
|
||||
&& alu_dispatch_if[0].data.wis == 0
|
||||
|
||||
@@ -85,6 +85,25 @@ module VX_fpu_unit import VX_fpu_pkg::*; #(
|
||||
wire execute_fire = execute_if[block_idx].valid && execute_if[block_idx].ready;
|
||||
wire fpu_rsp_fire = fpu_rsp_valid && fpu_rsp_ready;
|
||||
|
||||
reg [63:0] perf_execute_fires;
|
||||
reg [63:0] perf_execute_valids;
|
||||
reg [63:0] perf_fpu_req_valids;
|
||||
reg [63:0] perf_fpu_req_readys;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
perf_execute_fires <= '0;
|
||||
perf_execute_valids <= '0;
|
||||
perf_fpu_req_valids <= '0;
|
||||
perf_fpu_req_readys <= '0;
|
||||
end else begin
|
||||
perf_execute_fires <= perf_execute_fires + 64'(execute_fire);
|
||||
perf_execute_valids <= perf_execute_valids + 64'(execute_if[block_idx].valid);
|
||||
perf_fpu_req_valids <= perf_fpu_req_valids + 64'(fpu_req_valid);
|
||||
perf_fpu_req_readys <= perf_fpu_req_readys + 64'(fpu_req_ready);
|
||||
end
|
||||
end
|
||||
|
||||
VX_index_buffer #(
|
||||
.DATAW (`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + PID_WIDTH + 1 + 1),
|
||||
.SIZE (`FPUQ_SIZE)
|
||||
|
||||
@@ -36,6 +36,8 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
|
||||
|
||||
assign decode_if.ready = ibuf_ready_in[decode_isw];
|
||||
|
||||
VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]();
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
VX_elastic_buffer #(
|
||||
.DATAW (DATAW),
|
||||
@@ -62,13 +64,29 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
|
||||
decode_if.data.rs1,
|
||||
decode_if.data.rs2,
|
||||
decode_if.data.rs3}),
|
||||
.data_out(ibuffer_if[i].data),
|
||||
.valid_out (ibuffer_if[i].valid),
|
||||
.ready_out(ibuffer_if[i].ready)
|
||||
);
|
||||
|
||||
.data_out (uop_sequencer_if[i].data),
|
||||
.valid_out (uop_sequencer_if[i].valid),
|
||||
.ready_out (uop_sequencer_if[i].ready)
|
||||
);
|
||||
|
||||
`ifndef L1_ENABLE
|
||||
assign decode_if.ibuf_pop[i] = ibuffer_if[i].valid && ibuffer_if[i].ready;
|
||||
assign decode_if.ibuf_pop[i] = uop_sequencer_if[i].valid && uop_sequencer_if[i].ready;
|
||||
`endif
|
||||
|
||||
// tensor-core operation is controlled by a single macro-instruction at
|
||||
// the ISA; internally, the uop_sequencer blitzs micro-ops (counterpart
|
||||
// to Volta SASS set/step instructions) into the ibuffer upon encountering
|
||||
// this macro-instruction. this becomes a pass-through for non-tensorcore
|
||||
// instructions.
|
||||
VX_uop_sequencer uop_sequencer (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.uop_sequencer_if(uop_sequencer_if[i]),
|
||||
.ibuffer_if(ibuffer_if[i])
|
||||
);
|
||||
|
||||
end
|
||||
|
||||
endmodule
|
||||
|
||||
@@ -33,6 +33,9 @@ module VX_issue #(
|
||||
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
|
||||
`ifdef EXT_F_ENABLE
|
||||
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
|
||||
`endif
|
||||
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
|
||||
);
|
||||
@@ -93,7 +96,6 @@ module VX_issue #(
|
||||
.clk (clk),
|
||||
.reset (dispatch_reset),
|
||||
`ifdef PERF_ENABLE
|
||||
`UNUSED_PIN (perf_stalls),
|
||||
.perf_stalls (perf_issue_if.dispatch_stalls),
|
||||
.perf_valids (perf_issue_if.dispatch_valids),
|
||||
.perf_fires (perf_issue_if.dispatch_fires),
|
||||
@@ -104,6 +106,9 @@ module VX_issue #(
|
||||
.lsu_dispatch_if(lsu_dispatch_if),
|
||||
`ifdef EXT_F_ENABLE
|
||||
.fpu_dispatch_if(fpu_dispatch_if),
|
||||
`endif
|
||||
`ifdef EXT_T_ENABLE
|
||||
.tensor_dispatch_if(tensor_dispatch_if),
|
||||
`endif
|
||||
.sfu_dispatch_if(sfu_dispatch_if)
|
||||
);
|
||||
|
||||
@@ -294,6 +294,27 @@ module VX_operands import VX_gpu_pkg::*; #(
|
||||
.raddr (gpr_rd_addr),
|
||||
.rdata (gpr_rd_data[j])
|
||||
);
|
||||
|
||||
// blast read register file because printf is slowge
|
||||
logic [31:0] cycle, cycle_n;
|
||||
assign cycle_n = cycle + 32'd1;
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
cycle <= '0;
|
||||
end
|
||||
else begin
|
||||
cycle <= cycle_n;
|
||||
end
|
||||
|
||||
// if (cycle == 32'd25000) begin
|
||||
// for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin
|
||||
// integer warp = i * ISSUE_RATIO + (k / `NUM_REGS);
|
||||
// integer thread = j;
|
||||
// integer register = k % `NUM_REGS;
|
||||
// $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]);
|
||||
// end
|
||||
// end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
@@ -31,9 +31,6 @@ module VX_operands_dup import VX_gpu_pkg::*; #(
|
||||
localparam RAM_ADDRW = `LOG2UP(`NUM_REGS * ISSUE_RATIO);
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
// NOTE(hansung): toggle_buffer is 1-reg pipe without flow, halving
|
||||
// throughput. Wouldn't this cap overall IPC? Or OK as long as
|
||||
// ISSUE_WIDTH > 1?
|
||||
VX_stream_buffer #(
|
||||
.DATAW (DATAW)
|
||||
) staging_buffer (
|
||||
|
||||
285
hw/rtl/core/VX_reduce_unit.sv
Normal file
285
hw/rtl/core/VX_reduce_unit.sv
Normal file
@@ -0,0 +1,285 @@
|
||||
`include "VX_define.vh"
|
||||
`include "VX_platform.vh"
|
||||
|
||||
|
||||
// Copyright © 2019-2023
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
`include "VX_platform.vh"
|
||||
|
||||
module VX_reduce_ext #(
|
||||
parameter DATAW_IN = 1,
|
||||
parameter DATAW_OUT = DATAW_IN,
|
||||
parameter N = 1
|
||||
) (
|
||||
input wire [N-1:0][DATAW_IN-1:0] data_in,
|
||||
input wire [N-1:0] mask,
|
||||
input wire [`INST_RED_BITS-1:0] op_type,
|
||||
output wire [DATAW_OUT-1:0] data_out
|
||||
);
|
||||
// recursive binary reduction
|
||||
if (N == 1) begin
|
||||
`UNUSED_VAR(op_type)
|
||||
`UNUSED_VAR(mask)
|
||||
assign data_out = DATAW_OUT'(data_in[0]);
|
||||
end else begin
|
||||
localparam int N_A = N / 2;
|
||||
localparam int N_B = N - N_A;
|
||||
|
||||
wire [N_A-1:0][DATAW_IN-1:0] in_A;
|
||||
wire [N_B-1:0][DATAW_IN-1:0] in_B;
|
||||
wire [DATAW_OUT-1:0] out_A, out_B;
|
||||
|
||||
wire [N_A-1:0] mask_A;
|
||||
wire [N_B-1:0] mask_B;
|
||||
wire any_A, any_B;
|
||||
|
||||
for (genvar i = 0; i < N_A; i++) begin
|
||||
assign in_A[i] = data_in[i];
|
||||
end
|
||||
|
||||
for (genvar i = 0; i < N_B; i++) begin
|
||||
assign in_B[i] = data_in[N_A + i];
|
||||
end
|
||||
|
||||
assign mask_A = mask[N_A-1:0];
|
||||
assign mask_B = mask[N-1:N_A];
|
||||
assign any_A = |mask_A;
|
||||
assign any_B = |mask_B;
|
||||
|
||||
VX_reduce_ext #(
|
||||
.DATAW_IN (DATAW_IN),
|
||||
.DATAW_OUT (DATAW_OUT),
|
||||
.N (N_A)
|
||||
) reduce_A (
|
||||
.data_in (in_A),
|
||||
.mask(mask_A),
|
||||
.op_type(op_type),
|
||||
.data_out (out_A)
|
||||
);
|
||||
|
||||
VX_reduce_ext #(
|
||||
.DATAW_IN (DATAW_IN),
|
||||
.DATAW_OUT (DATAW_OUT),
|
||||
.N (N_B)
|
||||
) reduce_B (
|
||||
.data_in (in_B),
|
||||
.mask(mask_B),
|
||||
.op_type(op_type),
|
||||
.data_out (out_B)
|
||||
);
|
||||
|
||||
logic [DATAW_OUT-1:0] _data_out;
|
||||
|
||||
always @(*) begin
|
||||
case (op_type)
|
||||
`INST_RED_ADD: _data_out = out_A + out_B;
|
||||
`INST_RED_ADDU: _data_out = out_A + out_B;
|
||||
`INST_RED_MIN: _data_out = ($signed(out_A) < $signed(out_B)) ? out_A : out_B;
|
||||
`INST_RED_MINU: _data_out = (out_A < out_B) ? out_A : out_B;
|
||||
`INST_RED_MAX: _data_out = ($signed(out_A) < $signed(out_B)) ? out_B : out_A;
|
||||
`INST_RED_MAXU: _data_out = (out_A < out_B) ? out_B : out_A;
|
||||
`INST_RED_AND: _data_out = out_A & out_B;
|
||||
`INST_RED_OR: _data_out = out_A | out_B;
|
||||
`INST_RED_XOR: _data_out = out_A ^ out_B;
|
||||
default: _data_out = out_A;
|
||||
endcase
|
||||
end
|
||||
|
||||
// if both sides are masked out, then it doesn't matter what we output
|
||||
assign data_out = (any_A && any_B) ? _data_out : (any_A ? out_A : out_B);
|
||||
|
||||
end
|
||||
|
||||
endmodule
|
||||
|
||||
module VX_reduce_unit #(
|
||||
parameter CORE_ID = 0,
|
||||
parameter NUM_LANES = 1
|
||||
) (
|
||||
input wire clk,
|
||||
input wire reset,
|
||||
|
||||
VX_execute_if.slave execute_if,
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
`UNUSED_PARAM(CORE_ID)
|
||||
|
||||
localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES;
|
||||
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
|
||||
localparam PID_WIDTH = `UP(PID_BITS);
|
||||
|
||||
logic [`XLEN-1:0] accumulator, accumulator_n, reduced_accumulator;
|
||||
wire [(NUM_LANES * `XLEN)-1:0] broadcasted_accumulator;
|
||||
|
||||
assign broadcasted_accumulator = {NUM_LANES{accumulator}};
|
||||
|
||||
wire eop;
|
||||
wire [NUM_LANES-1:0][`XLEN-1:0] data_in;
|
||||
wire [`XLEN-1:0] data_out;
|
||||
|
||||
assign eop = execute_if.data.eop;
|
||||
assign data_in = execute_if.data.rs1_data;
|
||||
|
||||
logic execute_if_valid;
|
||||
logic execute_if_ready;
|
||||
logic commit_if_valid;
|
||||
logic commit_if_ready;
|
||||
|
||||
wire execute_if_fire;
|
||||
wire commit_if_fire;
|
||||
|
||||
assign execute_if_valid = execute_if.valid;
|
||||
assign execute_if.ready = execute_if_ready;
|
||||
|
||||
assign execute_if_fire = execute_if.ready && execute_if.valid;
|
||||
assign commit_if_fire = commit_if_ready && commit_if_valid;
|
||||
|
||||
logic store_tmask_pid;
|
||||
logic read_tmask_pid;
|
||||
wire [PID_WIDTH-1:0] stored_pid;
|
||||
wire [NUM_LANES-1:0] stored_tmask;
|
||||
wire stored_sop;
|
||||
wire stored_eop;
|
||||
|
||||
logic [PID_BITS:0] size;
|
||||
logic [PID_BITS:0] size_n;
|
||||
|
||||
// 1. idle state - wait for execute_if to be valid
|
||||
// 2. accumulate - continue accumulating until eop, store packet id + thread mask for broadcast phase
|
||||
// 3. broadcast - broadcast to rds
|
||||
localparam IDLE = 2'b00;
|
||||
localparam ACCUMULATE = 2'b01;
|
||||
localparam BROADCAST = 2'b10;
|
||||
localparam FINISH = 2'b11;
|
||||
|
||||
logic [1:0] state, state_n;
|
||||
|
||||
always @(*) begin
|
||||
state_n = state;
|
||||
accumulator_n = accumulator;
|
||||
execute_if_ready = '0;
|
||||
commit_if_valid = '0;
|
||||
store_tmask_pid = '0;
|
||||
read_tmask_pid = '0;
|
||||
size_n = store_tmask_pid ? size + 1 : (read_tmask_pid ? size - 1 : size);
|
||||
|
||||
case (state)
|
||||
IDLE: begin
|
||||
if (execute_if_valid) begin
|
||||
accumulator_n = data_out;
|
||||
store_tmask_pid = '1;
|
||||
if (eop) begin
|
||||
state_n = BROADCAST;
|
||||
end
|
||||
else begin
|
||||
execute_if_ready = '1;
|
||||
state_n = ACCUMULATE;
|
||||
end
|
||||
end
|
||||
end
|
||||
ACCUMULATE: begin
|
||||
execute_if_ready = '1;
|
||||
if (eop) begin
|
||||
execute_if_ready = '0;
|
||||
state_n = BROADCAST;
|
||||
end
|
||||
if (eop || execute_if_fire) begin
|
||||
accumulator_n = reduced_accumulator;
|
||||
store_tmask_pid = '1;
|
||||
end
|
||||
end
|
||||
BROADCAST: begin
|
||||
execute_if_ready = '0;
|
||||
commit_if_valid = '1;
|
||||
|
||||
if (commit_if_fire) begin
|
||||
read_tmask_pid = '1;
|
||||
end
|
||||
if (size_n == '0) begin
|
||||
state_n = FINISH;
|
||||
end
|
||||
end
|
||||
FINISH: begin
|
||||
execute_if_ready = '1;
|
||||
if (execute_if_fire) begin
|
||||
state_n = IDLE;
|
||||
end
|
||||
end
|
||||
endcase
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
accumulator <= '0;
|
||||
state <= IDLE;
|
||||
size <= '0;
|
||||
end
|
||||
else begin
|
||||
accumulator <= accumulator_n;
|
||||
state <= state_n;
|
||||
size <= size_n;
|
||||
end
|
||||
end
|
||||
|
||||
VX_reduce_ext #(
|
||||
.DATAW_IN(`XLEN),
|
||||
.N(NUM_LANES)
|
||||
) reducer (
|
||||
.data_in(data_in),
|
||||
.mask(execute_if.data.tmask),
|
||||
.op_type(execute_if.data.op_type),
|
||||
.data_out(data_out)
|
||||
);
|
||||
|
||||
VX_reduce_ext #(
|
||||
.DATAW_IN(`XLEN),
|
||||
.N(2)
|
||||
) accumulator_reducer (
|
||||
.data_in({accumulator, data_out}),
|
||||
.mask(2'b11),
|
||||
.op_type(execute_if.data.op_type),
|
||||
.data_out(reduced_accumulator)
|
||||
);
|
||||
|
||||
VX_elastic_buffer #(
|
||||
.DATAW(NUM_LANES + PID_WIDTH + 1 + 1),
|
||||
.SIZE(NUM_PACKETS)
|
||||
) tmask_pid_store (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.valid_in(store_tmask_pid),
|
||||
`UNUSED_PIN(ready_in),
|
||||
.data_in({execute_if.data.tmask, execute_if.data.pid, execute_if.data.sop, execute_if.data.eop}),
|
||||
|
||||
.data_out({stored_tmask, stored_pid, stored_sop, stored_eop}),
|
||||
.ready_out(read_tmask_pid),
|
||||
`UNUSED_PIN(valid_out)
|
||||
);
|
||||
|
||||
VX_elastic_buffer #(
|
||||
.DATAW(`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + 1 + `NR_BITS + (`XLEN * NUM_LANES) + PID_WIDTH + 1 + 1)
|
||||
) output_buffer (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.valid_in(commit_if_valid),
|
||||
.ready_in(commit_if_ready),
|
||||
.data_in({execute_if.data.uuid, execute_if.data.wid, stored_tmask, execute_if.data.PC, execute_if.data.wb, execute_if.data.rd, broadcasted_accumulator, stored_pid, stored_sop, stored_eop}),
|
||||
|
||||
.data_out({commit_if.data.uuid, commit_if.data.wid, commit_if.data.tmask, commit_if.data.PC, commit_if.data.wb, commit_if.data.rd, commit_if.data.data, commit_if.data.pid, commit_if.data.sop, commit_if.data.eop}),
|
||||
.ready_out(commit_if.ready),
|
||||
.valid_out(commit_if.valid)
|
||||
);
|
||||
|
||||
endmodule
|
||||
316
hw/rtl/core/VX_tensor_core.sv
Normal file
316
hw/rtl/core/VX_tensor_core.sv
Normal file
@@ -0,0 +1,316 @@
|
||||
`include "VX_fpu_define.vh"
|
||||
|
||||
module VX_tensor_core #(
|
||||
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
|
||||
VX_commit_if.master commit_if [`ISSUE_WIDTH]
|
||||
);
|
||||
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
|
||||
|
||||
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
|
||||
VX_tensor_core_warp #(
|
||||
.ISW(i)
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.dispatch_if(dispatch_if[i]),
|
||||
.commit_if(commit_if[i])
|
||||
);
|
||||
end
|
||||
|
||||
endmodule
|
||||
|
||||
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
parameter ISW
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
VX_dispatch_if.slave dispatch_if,
|
||||
VX_commit_if.master commit_if
|
||||
);
|
||||
wire [1:0] step = 2'(dispatch_if.data.op_type);
|
||||
logic [3:0] octet_results_valid;
|
||||
logic [3:0] octet_results_ready;
|
||||
logic [3:0] octet_operands_ready;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
|
||||
assign dispatch_if.ready = &octet_operands_ready;
|
||||
|
||||
for (genvar i = 0; i < 4/*octets*/; ++i) begin
|
||||
// lane-to-octet mapping; see figure 13 of the paper
|
||||
wire [7:0][31:0] octet_A = {
|
||||
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_B = {
|
||||
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
|
||||
};
|
||||
wire [7:0][31:0] octet_C = {
|
||||
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
|
||||
};
|
||||
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
logic result_valid;
|
||||
logic result_ready;
|
||||
VX_tensor_octet #(
|
||||
.ISW(ISW),
|
||||
.OCTET(i)
|
||||
) octet (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.A_in(octet_A),
|
||||
.B_in(octet_B),
|
||||
.C_in(octet_C),
|
||||
.operands_valid(dispatch_if.valid),
|
||||
.operands_ready(octet_operands_ready[i]),
|
||||
|
||||
.step(step),
|
||||
|
||||
.D_out(octet_D),
|
||||
.result_valid(result_valid),
|
||||
.result_ready(result_ready)
|
||||
);
|
||||
|
||||
// these should always be in lockstep
|
||||
assign octet_results_valid[i] = result_valid;
|
||||
assign result_ready = octet_results_ready[i];
|
||||
|
||||
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
|
||||
// to the octet can only do 8 fp32 writeback at a time; so we need to
|
||||
// split writeback over two cycles
|
||||
assign wb_data_0[4*i+0] = octet_D[0][0];
|
||||
assign wb_data_0[4*i+1] = octet_D[1][0];
|
||||
assign wb_data_0[4*i+2] = octet_D[0][2];
|
||||
assign wb_data_0[4*i+3] = octet_D[1][2];
|
||||
|
||||
assign wb_data_1[4*i+0] = octet_D[0][1];
|
||||
assign wb_data_1[4*i+1] = octet_D[1][1];
|
||||
assign wb_data_1[4*i+2] = octet_D[0][3];
|
||||
assign wb_data_1[4*i+3] = octet_D[1][3];
|
||||
|
||||
assign wb_data_0[4*i+16+0] = octet_D[2][0];
|
||||
assign wb_data_0[4*i+16+1] = octet_D[3][0];
|
||||
assign wb_data_0[4*i+16+2] = octet_D[2][2];
|
||||
assign wb_data_0[4*i+16+3] = octet_D[3][2];
|
||||
|
||||
assign wb_data_1[4*i+16+0] = octet_D[2][1];
|
||||
assign wb_data_1[4*i+16+1] = octet_D[3][1];
|
||||
assign wb_data_1[4*i+16+2] = octet_D[2][3];
|
||||
assign wb_data_1[4*i+16+3] = octet_D[3][3];
|
||||
end
|
||||
|
||||
/* commit_if.data_t parts that we need to keep around:
|
||||
- uuid
|
||||
- wid
|
||||
- tmask
|
||||
- PC
|
||||
- wb
|
||||
- rd
|
||||
*/
|
||||
|
||||
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
|
||||
|
||||
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
|
||||
wire commit_if_fire = commit_if.valid && commit_if.ready;
|
||||
wire [DATAW-1:0] dispatch_if_data_enq = {
|
||||
dispatch_if.data.uuid,
|
||||
wis_to_wid(dispatch_if.data.wis, ISW),
|
||||
dispatch_if.data.tmask,
|
||||
dispatch_if.data.PC,
|
||||
dispatch_if.data.wb,
|
||||
dispatch_if.data.rd
|
||||
};
|
||||
|
||||
wire [DATAW-1:0] dispatch_if_data_deq;
|
||||
|
||||
// this is probably a little oversized
|
||||
VX_fifo_queue #(
|
||||
.DATAW(DATAW),
|
||||
.DEPTH(16)
|
||||
) pending_uops (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.push(dispatch_if_fire),
|
||||
.pop(commit_if_fire),
|
||||
.data_in(dispatch_if_data_enq),
|
||||
.data_out(dispatch_if_data_deq),
|
||||
`UNUSED_PIN(empty),
|
||||
`UNUSED_PIN(alm_empty),
|
||||
`UNUSED_PIN(full), // should be impossible to overflow
|
||||
`UNUSED_PIN(alm_full),
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
|
||||
logic subcommit, subcommit_n;
|
||||
wire all_valid = (& octet_results_valid);
|
||||
assign commit_if.valid = all_valid;
|
||||
|
||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||
dispatch_if_data_deq, /* uuid ~ rd */
|
||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||
1'b0, /* pid */
|
||||
1'b1, /* sop */
|
||||
1'b1 /* eop */
|
||||
};
|
||||
|
||||
assign commit_if.data = commit_if_data;
|
||||
|
||||
always @(*) begin
|
||||
subcommit_n = commit_if_fire ? ~subcommit : subcommit;
|
||||
if (commit_if_fire && subcommit == 1'b1) begin
|
||||
octet_results_ready = '1;
|
||||
end
|
||||
else begin
|
||||
octet_results_ready = '0;
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
subcommit <= '0;
|
||||
end
|
||||
else begin
|
||||
subcommit <= subcommit_n;
|
||||
end
|
||||
end
|
||||
|
||||
endmodule
|
||||
|
||||
module VX_tensor_octet #(
|
||||
parameter ISW,
|
||||
parameter OCTET
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input [7:0][31:0] A_in,
|
||||
input [7:0][31:0] B_in,
|
||||
input [7:0][31:0] C_in,
|
||||
input operands_valid, // we have to backpressure due to there potentially being contention over commit
|
||||
output operands_ready,
|
||||
|
||||
input [1:0] step,
|
||||
|
||||
output [3:0][3:0][31:0] D_out,
|
||||
output result_valid,
|
||||
input result_ready
|
||||
);
|
||||
// 512 bits/octet * 4 octets per warp
|
||||
logic [3:0][31:0] A_buffer, A_buffer_n;
|
||||
logic [3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [7:0][31:0] C_buffer, C_buffer_n;
|
||||
|
||||
// half the inputs are buffered, half are not (instead coming straight
|
||||
// from operand bus) unlike the real tensor core.
|
||||
// the banks are only 32 bit rather than 64 bit (a pair of fp32 regs).
|
||||
logic [3:0][31:0] A_half;
|
||||
logic [3:0][31:0] B_half;
|
||||
logic [7:0][31:0] C_half;
|
||||
always @(*) begin
|
||||
// note that not all lanes participate at every step
|
||||
case (step)
|
||||
2'b00: begin
|
||||
A_half = { A_in[5:4], A_in[1:0] };
|
||||
B_half = B_in[3:0];
|
||||
end
|
||||
2'b01: begin
|
||||
A_half = { A_in[7:6], A_in[3:2] };
|
||||
B_half = B_in[3:0];
|
||||
end
|
||||
2'b10: begin
|
||||
A_half = { A_in[5:4], A_in[1:0] };
|
||||
B_half = B_in[7:4];
|
||||
end
|
||||
2'b11: begin
|
||||
A_half = { A_in[7:6], A_in[3:2] };
|
||||
B_half = B_in[7:4];
|
||||
end
|
||||
endcase
|
||||
C_half = C_in;
|
||||
end
|
||||
|
||||
logic substep;
|
||||
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
|
||||
|
||||
always @(*) begin
|
||||
A_buffer_n = A_buffer;
|
||||
B_buffer_n = B_buffer;
|
||||
C_buffer_n = C_buffer;
|
||||
|
||||
if (substep == 1'b0) begin
|
||||
A_buffer_n = A_half;
|
||||
B_buffer_n = B_half;
|
||||
C_buffer_n = C_half;
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
A_buffer <= '0;
|
||||
B_buffer <= '0;
|
||||
C_buffer <= '0;
|
||||
substep <= '0;
|
||||
end
|
||||
else begin
|
||||
A_buffer <= A_buffer_n;
|
||||
B_buffer <= B_buffer_n;
|
||||
C_buffer <= C_buffer_n;
|
||||
substep <= substep_n;
|
||||
end
|
||||
end
|
||||
|
||||
wire stall = result_valid && ~result_ready;
|
||||
assign operands_ready = ~stall;
|
||||
|
||||
// A is 4x2 fp32 matrix
|
||||
wire [3:0][1:0][31:0] A_tile = {
|
||||
{ A_half[3], A_buffer[3] },
|
||||
{ A_half[2], A_buffer[2] },
|
||||
{ A_half[1], A_buffer[1] },
|
||||
{ A_half[0], A_buffer[0] }
|
||||
};
|
||||
// B is 2x4 fp32 matrix
|
||||
wire [1:0][3:0][31:0] B_tile = {
|
||||
B_half, B_buffer
|
||||
};
|
||||
// C is 4x4 fp32 matrix
|
||||
logic [3:0][3:0][31:0] C_tile;
|
||||
|
||||
always @(*) begin
|
||||
C_tile = {
|
||||
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
|
||||
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
|
||||
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
|
||||
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
|
||||
};
|
||||
end
|
||||
|
||||
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
||||
|
||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||
VX_tensor_dpu #(
|
||||
.ISW(ISW),
|
||||
.OCTET(OCTET)
|
||||
) dpu (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.stall(stall),
|
||||
|
||||
.valid_in(do_hmma),
|
||||
.A_tile(A_tile),
|
||||
.B_tile(B_tile),
|
||||
.C_tile(C_tile),
|
||||
|
||||
.valid_out(result_valid),
|
||||
.D_tile(D_out)
|
||||
);
|
||||
endmodule
|
||||
97
hw/rtl/core/VX_tensor_ucode.vh
Normal file
97
hw/rtl/core/VX_tensor_ucode.vh
Normal file
@@ -0,0 +1,97 @@
|
||||
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
|
||||
HMMA_SET0_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
|
||||
end
|
||||
HMMA_SET0_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
|
||||
end
|
||||
HMMA_SET0_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
|
||||
end
|
||||
HMMA_SET0_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
|
||||
end
|
||||
HMMA_SET0_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
|
||||
end
|
||||
HMMA_SET0_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
|
||||
end
|
||||
HMMA_SET0_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
|
||||
end
|
||||
HMMA_SET0_STEP3_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
|
||||
end
|
||||
HMMA_SET1_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
|
||||
end
|
||||
HMMA_SET1_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
|
||||
end
|
||||
HMMA_SET1_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
|
||||
end
|
||||
HMMA_SET1_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
|
||||
end
|
||||
HMMA_SET1_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
|
||||
end
|
||||
HMMA_SET1_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
|
||||
end
|
||||
HMMA_SET1_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
|
||||
end
|
||||
HMMA_SET1_STEP3_1: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
|
||||
end
|
||||
HMMA_SET2_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(4), `FREG(12), `FREG(16)};
|
||||
end
|
||||
HMMA_SET2_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(5), `FREG(13), `FREG(17)};
|
||||
end
|
||||
HMMA_SET2_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(4), `FREG(12), `FREG(18)};
|
||||
end
|
||||
HMMA_SET2_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(5), `FREG(13), `FREG(19)};
|
||||
end
|
||||
HMMA_SET2_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(4), `FREG(12), `FREG(20)};
|
||||
end
|
||||
HMMA_SET2_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(5), `FREG(13), `FREG(21)};
|
||||
end
|
||||
HMMA_SET2_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET2_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(4), `FREG(12), `FREG(22)};
|
||||
end
|
||||
HMMA_SET2_STEP3_1: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(5), `FREG(13), `FREG(23)};
|
||||
end
|
||||
HMMA_SET3_STEP0_0: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(6), `FREG(14), `FREG(16)};
|
||||
end
|
||||
HMMA_SET3_STEP0_1: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(7), `FREG(15), `FREG(17)};
|
||||
end
|
||||
HMMA_SET3_STEP1_0: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(6), `FREG(14), `FREG(18)};
|
||||
end
|
||||
HMMA_SET3_STEP1_1: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(7), `FREG(15), `FREG(19)};
|
||||
end
|
||||
HMMA_SET3_STEP2_0: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(6), `FREG(14), `FREG(20)};
|
||||
end
|
||||
HMMA_SET3_STEP2_1: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(7), `FREG(15), `FREG(21)};
|
||||
end
|
||||
HMMA_SET3_STEP3_0: begin
|
||||
uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)};
|
||||
end
|
||||
HMMA_SET3_STEP3_1: begin
|
||||
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(7), `FREG(15), `FREG(23)};
|
||||
end
|
||||
164
hw/rtl/core/VX_uop_sequencer.sv
Normal file
164
hw/rtl/core/VX_uop_sequencer.sv
Normal file
@@ -0,0 +1,164 @@
|
||||
`include "VX_define.vh"
|
||||
|
||||
`define FREG(x) {1'b1, `NRI_BITS'(x)}
|
||||
|
||||
module VX_uop_sequencer import VX_gpu_pkg::*; (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
VX_ibuffer_if.slave uop_sequencer_if,
|
||||
VX_ibuffer_if.master ibuffer_if
|
||||
);
|
||||
|
||||
`ifdef EXT_T_ENABLE
|
||||
localparam UOP_TABLE_SIZE = 64;
|
||||
localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE);
|
||||
|
||||
localparam NEXT = 2'b00;
|
||||
localparam FINISH = 2'b01;
|
||||
|
||||
localparam UBR_BITS = 2;
|
||||
|
||||
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
|
||||
localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4);
|
||||
localparam IBUFFER_IF_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `XLEN + 1 + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + `XLEN + (`NR_BITS * 4);
|
||||
|
||||
logic [UOP_TABLE_WIDTH-1:0] uop;
|
||||
|
||||
// reserve space at start of table for more uop sequences
|
||||
localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0);
|
||||
localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8);
|
||||
localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9);
|
||||
localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10);
|
||||
localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11);
|
||||
localparam HMMA_SET0_STEP2_1 = UPC_BITS'(12);
|
||||
localparam HMMA_SET0_STEP3_0 = UPC_BITS'(13);
|
||||
localparam HMMA_SET0_STEP3_1 = UPC_BITS'(14);
|
||||
|
||||
localparam HMMA_SET1_STEP0_0 = UPC_BITS'(15);
|
||||
localparam HMMA_SET1_STEP0_1 = UPC_BITS'(16);
|
||||
localparam HMMA_SET1_STEP1_0 = UPC_BITS'(17);
|
||||
localparam HMMA_SET1_STEP1_1 = UPC_BITS'(18);
|
||||
localparam HMMA_SET1_STEP2_0 = UPC_BITS'(19);
|
||||
localparam HMMA_SET1_STEP2_1 = UPC_BITS'(20);
|
||||
localparam HMMA_SET1_STEP3_0 = UPC_BITS'(21);
|
||||
localparam HMMA_SET1_STEP3_1 = UPC_BITS'(22);
|
||||
|
||||
localparam HMMA_SET2_STEP0_0 = UPC_BITS'(23);
|
||||
localparam HMMA_SET2_STEP0_1 = UPC_BITS'(24);
|
||||
localparam HMMA_SET2_STEP1_0 = UPC_BITS'(25);
|
||||
localparam HMMA_SET2_STEP1_1 = UPC_BITS'(26);
|
||||
localparam HMMA_SET2_STEP2_0 = UPC_BITS'(27);
|
||||
localparam HMMA_SET2_STEP2_1 = UPC_BITS'(28);
|
||||
localparam HMMA_SET2_STEP3_0 = UPC_BITS'(29);
|
||||
localparam HMMA_SET2_STEP3_1 = UPC_BITS'(30);
|
||||
|
||||
localparam HMMA_SET3_STEP0_0 = UPC_BITS'(31);
|
||||
localparam HMMA_SET3_STEP0_1 = UPC_BITS'(32);
|
||||
localparam HMMA_SET3_STEP1_0 = UPC_BITS'(33);
|
||||
localparam HMMA_SET3_STEP1_1 = UPC_BITS'(34);
|
||||
localparam HMMA_SET3_STEP2_0 = UPC_BITS'(35);
|
||||
localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36);
|
||||
localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37);
|
||||
localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38);
|
||||
// register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C
|
||||
|
||||
logic [UPC_BITS-1:0] upc, upc_r, upc_n;
|
||||
|
||||
always @(*) begin
|
||||
case (upc)
|
||||
`include "VX_tensor_ucode.vh"
|
||||
default: begin
|
||||
uop = '0;
|
||||
end
|
||||
endcase
|
||||
end
|
||||
|
||||
wire [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS];
|
||||
wire [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS];
|
||||
|
||||
logic use_uop, use_uop_1d;
|
||||
wire uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready;
|
||||
|
||||
wire uop_start = ~use_uop_1d && use_uop;
|
||||
wire uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready;
|
||||
|
||||
// merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint,
|
||||
// but conceptually they should be linked
|
||||
always @(*) begin
|
||||
use_uop = uop_sequencer_if.valid && uop_sequencer_if.data.ex_type == `EX_BITS'(`EX_TENSOR);
|
||||
|
||||
if (uop_start) begin
|
||||
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
|
||||
upc_n = UPC_BITS'(uop_sequencer_if.data.op_type);
|
||||
end
|
||||
else begin
|
||||
upc_n = upc;
|
||||
end
|
||||
|
||||
if (uop_fire) begin
|
||||
upc_n = next_upc;
|
||||
end
|
||||
end
|
||||
|
||||
always @(*) begin
|
||||
if (uop_start) begin
|
||||
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
|
||||
upc = UPC_BITS'(uop_sequencer_if.data.op_type);
|
||||
end
|
||||
else begin
|
||||
upc = upc_r;
|
||||
end
|
||||
end
|
||||
|
||||
// copy UUID, wis, tmask from microcoded instruction
|
||||
wire [IBUFFER_IF_DATAW-1:0] ibuffer_output = {
|
||||
uop_sequencer_if.data.uuid,
|
||||
uop_sequencer_if.data.wis,
|
||||
uop_sequencer_if.data.tmask,
|
||||
uop[UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS-1:0]
|
||||
};
|
||||
|
||||
// passthrough when !use_uop
|
||||
assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid;
|
||||
assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready;
|
||||
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (uop_start) begin
|
||||
// $display("UOP start @ %t", $time);
|
||||
// $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready);
|
||||
end
|
||||
|
||||
if (uop_fire) begin
|
||||
// $display("UOP fire @ %t", $time);
|
||||
end
|
||||
|
||||
if (uop_finish) begin
|
||||
// $display("UOP finish @ %t", $time);
|
||||
end
|
||||
|
||||
if (reset) begin
|
||||
upc_r <= '0;
|
||||
use_uop_1d <= '0;
|
||||
end
|
||||
else begin
|
||||
upc_r <= upc_n;
|
||||
if (uop_finish) begin
|
||||
use_uop_1d <= 1'b0; // allow microcoded instructions to start immediately after eachother
|
||||
end
|
||||
else begin
|
||||
use_uop_1d <= use_uop;
|
||||
end
|
||||
end
|
||||
end
|
||||
`else
|
||||
`UNUSED_VAR(clk);
|
||||
`UNUSED_VAR(reset);
|
||||
assign ibuffer_if.valid = uop_sequencer_if.valid;
|
||||
assign uop_sequencer_if.ready = ibuffer_if.ready;
|
||||
assign ibuffer_if.data = uop_sequencer_if.data;
|
||||
`endif
|
||||
|
||||
|
||||
endmodule
|
||||
85
hw/rtl/core/generate_ucode.py
Normal file
85
hw/rtl/core/generate_ucode.py
Normal file
@@ -0,0 +1,85 @@
|
||||
num_sets = 4
|
||||
num_steps = 4
|
||||
num_substeps = 2
|
||||
|
||||
|
||||
def set_step_substep(sequence_number):
|
||||
set_num, step = divmod(sequence_number, num_steps * num_substeps)
|
||||
step //= num_substeps
|
||||
substep = sequence_number % 2
|
||||
|
||||
return set_num, step, substep
|
||||
|
||||
# set + substep -> rs1, rs2
|
||||
rs1 = {
|
||||
(0, 0): 0,
|
||||
(0, 1): 1,
|
||||
(1, 0): 2,
|
||||
(1, 1): 3,
|
||||
(2, 0): 4,
|
||||
(2, 1): 5,
|
||||
(3, 0): 6,
|
||||
(3, 1): 7
|
||||
}
|
||||
|
||||
rs2 = {
|
||||
(0, 0): 8,
|
||||
(0, 1): 9,
|
||||
(1, 0): 10,
|
||||
(1, 1): 11,
|
||||
(2, 0): 12,
|
||||
(2, 1): 13,
|
||||
(3, 0): 14,
|
||||
(3, 1): 15
|
||||
}
|
||||
|
||||
# step + substep -> rs3, rd
|
||||
rs3_rd = {
|
||||
(0, 0): 16,
|
||||
(0, 1): 17,
|
||||
(1, 0): 18,
|
||||
(1, 1): 19,
|
||||
(2, 0): 20,
|
||||
(2, 1): 21,
|
||||
(3, 0): 22,
|
||||
(3, 1): 23
|
||||
}
|
||||
|
||||
with open('VX_tensor_ucode.vh', 'w') as f:
|
||||
for sequence_number in range(num_sets * num_steps * num_substeps):
|
||||
set_num, step, substep = set_step_substep(sequence_number)
|
||||
|
||||
|
||||
next_sequence_num = (sequence_number + 1) % (num_sets * num_steps * num_substeps)
|
||||
next_set_num, next_step, next_substep = set_step_substep(next_sequence_num)
|
||||
finish = (next_sequence_num == 0)
|
||||
|
||||
name = "HMMA_SET{}_STEP{}_{}"
|
||||
ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b{}, 32'b{}, `FREG({}), `FREG({}), `FREG({}), `FREG({})"
|
||||
|
||||
name = name.format(
|
||||
set_num, step, substep,
|
||||
)
|
||||
|
||||
pc_imm = 1 if finish else 0
|
||||
|
||||
ucode = ucode.format(
|
||||
"FINISH" if finish else "NEXT",
|
||||
next_set_num, next_step, next_substep,
|
||||
step,
|
||||
substep,
|
||||
pc_imm,
|
||||
pc_imm,
|
||||
rs3_rd[(step, substep)],
|
||||
rs1[(set_num, substep)],
|
||||
rs2[(set_num, substep)],
|
||||
rs3_rd[(step, substep)],
|
||||
)
|
||||
|
||||
entry = f"{name}: begin \n"
|
||||
entry += "\tuop = {"
|
||||
entry += ucode
|
||||
entry += "}; \n"
|
||||
entry += "end \n"
|
||||
|
||||
f.write(entry)
|
||||
44
hw/rtl/fpu/VX_tensor_dpu.sv
Normal file
44
hw/rtl/fpu/VX_tensor_dpu.sv
Normal file
@@ -0,0 +1,44 @@
|
||||
`include "VX_fpu_define.vh"
|
||||
|
||||
module VX_tensor_dpu #(
|
||||
parameter ISW,
|
||||
parameter OCTET
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input stall,
|
||||
|
||||
input valid_in,
|
||||
input [3:0][1:0][31:0] A_tile,
|
||||
input [1:0][3:0][31:0] B_tile,
|
||||
input [3:0][3:0][31:0] C_tile,
|
||||
|
||||
output valid_out,
|
||||
output [3:0][3:0][31:0] D_tile
|
||||
);
|
||||
logic [3:0][3:0][31:0] result_hmma;
|
||||
|
||||
always @(*) begin
|
||||
dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (~reset && valid_in) begin
|
||||
dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
VX_shift_register #(
|
||||
.DATAW (1 + $bits(D_tile)),
|
||||
.DEPTH (`LATENCY_HMMA),
|
||||
.RESETW (1)
|
||||
) shift_reg (
|
||||
.clk (clk),
|
||||
.reset (reset),
|
||||
.enable (~stall),
|
||||
.data_in ({valid_in, result_hmma}),
|
||||
.data_out ({valid_out, D_tile})
|
||||
);
|
||||
endmodule
|
||||
30
hw/rtl/fpu/VX_tensor_tb.sv
Normal file
30
hw/rtl/fpu/VX_tensor_tb.sv
Normal file
@@ -0,0 +1,30 @@
|
||||
`include "VX_fpu_define.vh"
|
||||
|
||||
module VX_tensor_tb(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input valid_in,
|
||||
input [3:0][1:0][31:0] A_tile,
|
||||
input [1:0][3:0][31:0] B_tile,
|
||||
input [3:0][3:0][31:0] C_tile,
|
||||
|
||||
output valid_out,
|
||||
output [3:0][3:0][31:0] D_tile
|
||||
);
|
||||
|
||||
VX_tensor_dpu #() tensor_core (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
|
||||
.stall(1'b0),
|
||||
|
||||
.valid_in(valid_in),
|
||||
.A_tile(A_tile),
|
||||
.B_tile(B_tile),
|
||||
.C_tile(C_tile),
|
||||
|
||||
.valid_out(valid_out),
|
||||
.D_tile(D_tile)
|
||||
);
|
||||
endmodule
|
||||
89
hw/unittest/tensor/Makefile
Normal file
89
hw/unittest/tensor/Makefile
Normal file
@@ -0,0 +1,89 @@
|
||||
DESTDIR ?= .
|
||||
RTL_DIR = ../../rtl
|
||||
DPI_DIR = $(abspath ../../dpi)
|
||||
SIM_DIR = ../../../sim
|
||||
THIRD_PARTY_DIR = $(abspath ../../../third_party)
|
||||
|
||||
CONFIGS +=
|
||||
PARAMS +=
|
||||
|
||||
CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds
|
||||
CXXFLAGS += -fPIC -Wno-maybe-uninitialized
|
||||
CXXFLAGS += -fcoroutines
|
||||
CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common
|
||||
CXXFLAGS += -I/$(THIRD_PARTY_DIR)/softfloat/source/include
|
||||
CXXFLAGS += -I/$(DPI_DIR)
|
||||
CXXFLAGS += $(CONFIGS)
|
||||
|
||||
LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a
|
||||
|
||||
# control RTL debug tracing states
|
||||
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_BANK
|
||||
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_MSHR
|
||||
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_TAG
|
||||
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_DATA
|
||||
|
||||
DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS)
|
||||
|
||||
RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv
|
||||
|
||||
RTL_INCLUDE = -I$(RTL_DIR) -I$(DPI_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/fpu
|
||||
|
||||
# SRCS = cachesim.cpp testbench.cpp
|
||||
SRCS += $(DPI_DIR)/util_dpi.cpp
|
||||
SRCS += $(DPI_DIR)/float_dpi.cpp
|
||||
SRCS += $(SIM_DIR)/common/rvfloats.cpp
|
||||
SRCS += ./main.cpp
|
||||
|
||||
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_dpu.sv
|
||||
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv
|
||||
|
||||
TOP = VX_tensor_tb
|
||||
|
||||
VL_FLAGS = --exe
|
||||
VL_FLAGS += --language 1800-2009 # -Wall -Wpedantic # --assert
|
||||
VL_FLAGS += -Wno-DECLFILENAME -Wno-REDEFMACRO
|
||||
VL_FLAGS += --x-initial unique --x-assign unique
|
||||
VL_FLAGS += -DSIMULATION -DSV_DPI
|
||||
VL_FLAGS += $(CONFIGS)
|
||||
VL_FLAGS += $(PARAMS)
|
||||
VL_FLAGS += $(RTL_INCLUDE)
|
||||
VL_FLAGS += $(RTL_PKGS)
|
||||
VL_FLAGS += --cc $(TOP) --top-module $(TOP)
|
||||
VL_FLAGS += --timing
|
||||
|
||||
# Enable Verilator multithreaded simulation
|
||||
THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())')
|
||||
VL_FLAGS += -j $(THREADS)
|
||||
#VL_FLAGS += --threads $(THREADS)
|
||||
|
||||
# Debugigng
|
||||
ifdef DEBUG
|
||||
VL_FLAGS += --trace --trace-structs $(DBG_FLAGS)
|
||||
CXXFLAGS += -g -O0 $(DBG_FLAGS)
|
||||
else
|
||||
VL_FLAGS += -DNDEBUG
|
||||
CXXFLAGS += -O2 -DNDEBUG
|
||||
endif
|
||||
|
||||
# Enable perf counters
|
||||
ifdef PERF
|
||||
VL_FLAGS += -DPERF_ENABLE
|
||||
CXXFLAGS += -DPERF_ENABLE
|
||||
endif
|
||||
|
||||
PROJECT = tensor
|
||||
|
||||
all: $(DESTDIR)/$(PROJECT)
|
||||
|
||||
$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS)
|
||||
verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@
|
||||
|
||||
run: $(DESTDIR)/$(PROJECT)
|
||||
$(DESTDIR)/$(PROJECT)
|
||||
|
||||
waves: trace.vcd
|
||||
gtkwave -o trace.vcd
|
||||
|
||||
clean:
|
||||
rm -rf obj_dir $(DESTDIR)/$(PROJECT)
|
||||
197
hw/unittest/tensor/main.cpp
Normal file
197
hw/unittest/tensor/main.cpp
Normal file
@@ -0,0 +1,197 @@
|
||||
// Copyright © 2019-2023
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "vl_simulator.h"
|
||||
#include "VVX_tensor_tb.h"
|
||||
#include <iostream>
|
||||
|
||||
#include <half.hpp>
|
||||
|
||||
#define MAX_TICKS 20
|
||||
|
||||
#ifndef TRACE_START_TIME
|
||||
#define TRACE_START_TIME 0ull
|
||||
#endif
|
||||
|
||||
#ifndef TRACE_STOP_TIME
|
||||
#define TRACE_STOP_TIME -1ull
|
||||
#endif
|
||||
|
||||
#define CHECK(x) \
|
||||
do { \
|
||||
if (x) \
|
||||
break; \
|
||||
std::cout << "FAILED: " << #x << std::endl; \
|
||||
std::abort(); \
|
||||
} while (false)
|
||||
|
||||
static uint64_t timestamp = 0;
|
||||
static bool trace_enabled = false;
|
||||
static uint64_t trace_start_time = TRACE_START_TIME;
|
||||
static uint64_t trace_stop_time = TRACE_STOP_TIME;
|
||||
|
||||
double sc_time_stamp() {
|
||||
return timestamp;
|
||||
}
|
||||
|
||||
bool sim_trace_enabled() {
|
||||
if (timestamp >= trace_start_time
|
||||
&& timestamp < trace_stop_time)
|
||||
return true;
|
||||
return trace_enabled;
|
||||
}
|
||||
|
||||
void sim_trace_enable(bool enable) {
|
||||
trace_enabled = enable;
|
||||
}
|
||||
|
||||
using Device = VVX_tensor_tb;
|
||||
using half_float::half;
|
||||
|
||||
static_assert(sizeof(half) == 2);
|
||||
uint32_t half2bits(half h) {
|
||||
uint16_t half_bits;
|
||||
memcpy(&half_bits, &h, sizeof(half));
|
||||
return half_bits;
|
||||
}
|
||||
|
||||
uint32_t float2bits(float f) {
|
||||
uint32_t float_bits;
|
||||
memcpy(&float_bits, &f, sizeof(f));
|
||||
return float_bits;
|
||||
}
|
||||
|
||||
float bits2float(uint32_t b) {
|
||||
float f;
|
||||
memcpy(&f, &b, sizeof(b));
|
||||
return f;
|
||||
}
|
||||
|
||||
// A is M * K, B is K * K * M, C is M * M, D is M * M
|
||||
#define M 4
|
||||
#define K 2
|
||||
|
||||
// row, column
|
||||
float A_tile[M][K];
|
||||
float B_tile[K][M];
|
||||
float C_tile[M][M];
|
||||
float D_tile[M][M];
|
||||
|
||||
void initialize_test_data() {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < K; j += 1) {
|
||||
A_tile[i][j] = (float) (i * K + j);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < K; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
B_tile[i][j] = (float) (j * K + i);
|
||||
}
|
||||
}
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
C_tile[i][j] = (float) (i * j);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void write_test_data(vl_simulator<Device>& sim) {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < K; j += 1) {
|
||||
int index = (i * K + j);
|
||||
uint32_t A_bits = float2bits(A_tile[i][j]);
|
||||
sim->A_tile[index] = A_bits;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < K; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
int index = (i * M + j);
|
||||
uint32_t B_bits = float2bits(B_tile[i][j]);
|
||||
sim->B_tile[index] = B_bits;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
int index = (i * M + j);
|
||||
uint32_t C_bits = float2bits(C_tile[i][j]);
|
||||
sim->C_tile[index] = C_bits;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void read_result(vl_simulator<Device>& sim) {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
int index = (i * M + j);
|
||||
|
||||
uint32_t D_bits = sim->D_tile[index];
|
||||
float f = bits2float(D_bits);
|
||||
D_tile[i][j] = f;
|
||||
std::cout << f << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void expected() {
|
||||
for (int i = 0; i < M; i += 1) {
|
||||
for (int j = 0; j < M; j += 1) {
|
||||
float accum = C_tile[i][j];
|
||||
for (int k = 0; k < K; k += 1) {
|
||||
accum += A_tile[i][k] * B_tile[k][j];
|
||||
}
|
||||
|
||||
std::cout << accum << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
// Initialize Verilators variables
|
||||
Verilated::commandArgs(argc, argv);
|
||||
|
||||
vl_simulator<Device> sim;
|
||||
|
||||
initialize_test_data();
|
||||
// run test
|
||||
timestamp = sim.reset(0);
|
||||
|
||||
|
||||
// advance clock
|
||||
timestamp = sim.step(timestamp, 10);
|
||||
sim->valid_in = 1;
|
||||
write_test_data(sim);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
sim->valid_in = 0;
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 1);
|
||||
read_result(sim);
|
||||
timestamp = sim.step(timestamp, 2);
|
||||
CHECK(sim->valid_out == 0);
|
||||
|
||||
expected();
|
||||
|
||||
std::cout << "PASSED!" << std::endl;
|
||||
std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -48,6 +48,7 @@ void vx_wspawn_wait();
|
||||
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
|
||||
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg);
|
||||
|
||||
void vx_serial(vx_serial_cb callback, void * arg);
|
||||
|
||||
|
||||
@@ -83,6 +83,38 @@ static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
|
||||
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
|
||||
int NT = vx_num_threads();
|
||||
int NW = vx_num_warps();
|
||||
int cid = vx_core_id();
|
||||
int wid = vx_warp_id();
|
||||
int tid = vx_thread_id();
|
||||
|
||||
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
|
||||
|
||||
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
|
||||
int offset = p_wspawn_args->offset + (NT * wid + tid);
|
||||
|
||||
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
|
||||
void* arg = p_wspawn_args->arg;
|
||||
for (int wave_id = 0; wave_id < waves; ++wave_id) {
|
||||
int task_id = offset + (wave_id * NT * NW);
|
||||
callback(task_id, arg);
|
||||
}
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_contiguous_all_stub();
|
||||
|
||||
// disable warp
|
||||
// deadlock here on warps 1, 2, 3
|
||||
vx_tmc_zero();
|
||||
}
|
||||
|
||||
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
@@ -94,6 +126,79 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
|
||||
vx_tmc_zero();
|
||||
}
|
||||
|
||||
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
int NW = vx_num_warps();
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// current core id
|
||||
int core_id = vx_core_id();
|
||||
if (core_id >= NUM_CORES_MAX)
|
||||
return;
|
||||
|
||||
// calculate necessary active cores
|
||||
int WT = NW * NT;
|
||||
int nC = (num_tasks > WT) ? (num_tasks / WT) : 1;
|
||||
int nc = MIN(nC, NC);
|
||||
if (core_id >= nc)
|
||||
return; // terminate extra cores
|
||||
|
||||
// number of tasks per core
|
||||
int tasks_per_core = num_tasks / nc;
|
||||
int tasks_per_core_n1 = tasks_per_core;
|
||||
if (core_id == (nc-1)) {
|
||||
int rem = num_tasks - (nc * tasks_per_core);
|
||||
tasks_per_core_n1 += rem; // last core also executes remaining tasks
|
||||
}
|
||||
|
||||
// number of tasks per warp
|
||||
int TW = tasks_per_core_n1 / NT; // occupied warps
|
||||
int rT = tasks_per_core_n1 - TW * NT; // remaining threads
|
||||
int fW = 1, rW = 0;
|
||||
if (TW >= NW) {
|
||||
fW = TW / NW; // full warps iterations
|
||||
rW = TW - fW * NW; // remaining warps
|
||||
}
|
||||
|
||||
wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW };
|
||||
g_wspawn_args[core_id] = &wspawn_args;
|
||||
|
||||
if (TW >= 1) {
|
||||
// execute callback on other warps
|
||||
int nw = MIN(TW, NW);
|
||||
vx_wspawn(nw, spawn_tasks_contiguous_all_cb);
|
||||
|
||||
// activate all threads
|
||||
vx_tmc(-1);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_contiguous_all_stub();
|
||||
|
||||
// back to single-threaded
|
||||
vx_tmc_one();
|
||||
|
||||
// wait for spawn warps to terminate
|
||||
// deadlock here on warp 0!
|
||||
vx_wspawn_wait();
|
||||
}
|
||||
|
||||
if (rT != 0) {
|
||||
// adjust offset
|
||||
wspawn_args.offset += (tasks_per_core_n1 - rT);
|
||||
|
||||
// activate remaining threads
|
||||
int tmask = (1 << rT) - 1;
|
||||
vx_tmc(tmask);
|
||||
|
||||
// call stub routine
|
||||
spawn_tasks_rem_stub();
|
||||
|
||||
// back to single-threaded
|
||||
vx_tmc_one();
|
||||
}
|
||||
}
|
||||
|
||||
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
|
||||
// device specs
|
||||
int NC = vx_num_cores();
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
all:
|
||||
$(MAKE) -C conform
|
||||
$(MAKE) -C hello
|
||||
$(MAKE) -C fibonacci
|
||||
$(MAKE) -C fibonacci
|
||||
$(MAKE) -C reductions
|
||||
|
||||
run-simx:
|
||||
$(MAKE) -C conform run-simx
|
||||
$(MAKE) -C hello run-simx
|
||||
$(MAKE) -C fibonacci run-simx
|
||||
$(MAKE) -C reductions run-simx
|
||||
|
||||
run-rtlsim:
|
||||
$(MAKE) -C conform run-rtlsim
|
||||
$(MAKE) -C hello run-rtlsim
|
||||
$(MAKE) -C fibonacci run-rtlsim
|
||||
$(MAKE) -C reductions run-rtlsim
|
||||
|
||||
clean:
|
||||
$(MAKE) -C conform clean
|
||||
$(MAKE) -C hello clean
|
||||
$(MAKE) -C fibonacci clean
|
||||
$(MAKE) -C reductions clean
|
||||
|
||||
@@ -33,7 +33,7 @@ $(PROJECT).dump: $(PROJECT).elf
|
||||
$(PROJECT).bin: $(PROJECT).elf
|
||||
$(CP) -O binary $(PROJECT).elf $(PROJECT).bin
|
||||
|
||||
$(PROJECT).elf: $(SRCS)
|
||||
$(PROJECT).elf: $(SRCS) $(DEPS)
|
||||
$(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf
|
||||
|
||||
run-rtlsim: $(PROJECT).bin
|
||||
|
||||
5
tests/kernel/reductions/Makefile
Normal file
5
tests/kernel/reductions/Makefile
Normal file
@@ -0,0 +1,5 @@
|
||||
PROJECT = reductions
|
||||
|
||||
SRCS = main.cpp
|
||||
|
||||
include ../common.mk
|
||||
216
tests/kernel/reductions/main.cpp
Normal file
216
tests/kernel/reductions/main.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
#define RISCV_CUSTOM2 0x5B
|
||||
#define ADD_FUNC7 0b0000000
|
||||
#define ADDU_FUNC7 0b1000000
|
||||
#define MIN_FUNC7 0b0000001
|
||||
#define MINU_FUNC7 0b1000001
|
||||
#define MAX_FUNC7 0b0000010
|
||||
#define MAXU_FUNC7 0b1000010
|
||||
#define AND_FUNC7 0b0000011
|
||||
#define OR_FUNC7 0b0000100
|
||||
#define XOR_FUNC7 0b0000101
|
||||
|
||||
/*
|
||||
6'h0: begin
|
||||
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
|
||||
end
|
||||
6'h1: begin
|
||||
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
|
||||
end
|
||||
6'h2: begin
|
||||
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
|
||||
end
|
||||
6'h3: begin
|
||||
op_type = `INST_RED_AND;
|
||||
end
|
||||
6'h4: begin
|
||||
op_type = `INST_RED_OR;
|
||||
end
|
||||
6'h5: begin
|
||||
op_type = `INST_RED_XOR;
|
||||
end
|
||||
*/
|
||||
|
||||
#include <vx_intrinsics.h>
|
||||
#include <stdio.h>
|
||||
#include <vx_print.h>
|
||||
|
||||
int x[4] = {3, 7, 2, 5};
|
||||
int y = -1;
|
||||
|
||||
inline int vx_add_reduce(int v) {
|
||||
int ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(ADD_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline int vx_min_reduce(int v) {
|
||||
int ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MIN_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline unsigned vx_minu_reduce(unsigned v) {
|
||||
unsigned ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MINU_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline int vx_max_reduce(int v) {
|
||||
int ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAX_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline unsigned vx_maxu_reduce(unsigned v) {
|
||||
unsigned ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAXU_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
||||
inline unsigned vx_and_reduce(unsigned v) {
|
||||
unsigned ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(AND_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline unsigned vx_or_reduce(unsigned v) {
|
||||
unsigned ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(OR_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline unsigned vx_xor_reduce(unsigned v) {
|
||||
unsigned ret;
|
||||
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(XOR_FUNC7));
|
||||
return ret;
|
||||
}
|
||||
|
||||
void test_add_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
int v = x[tid];
|
||||
int reduced = vx_add_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
unsigned unsigned_vector[4] = {(unsigned)-1, 0, (unsigned)-2, 5};
|
||||
|
||||
void test_min_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
int v = unsigned_vector[tid];
|
||||
int reduced = vx_min_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
void test_max_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
int v = unsigned_vector[tid];
|
||||
int reduced = vx_max_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
void test_minu_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
unsigned v = unsigned_vector[tid];
|
||||
unsigned reduced = vx_minu_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
void test_maxu_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
unsigned v = unsigned_vector[tid];
|
||||
unsigned reduced = vx_maxu_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111};
|
||||
|
||||
void test_and_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
unsigned v = bit_vectors[tid];
|
||||
unsigned reduced = vx_and_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
void test_or_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
unsigned v = bit_vectors[tid];
|
||||
unsigned reduced = vx_or_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
void test_xor_reduce() {
|
||||
vx_tmc(-1);
|
||||
int tid = vx_thread_id();
|
||||
unsigned v = bit_vectors[tid];
|
||||
unsigned reduced = vx_xor_reduce(v);
|
||||
vx_tmc(1);
|
||||
|
||||
y = reduced;
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
int expected;
|
||||
|
||||
test_add_reduce();
|
||||
vx_printf("add reduce result: %d\n", y);
|
||||
vx_printf("expected: %d\n", x[0] + x[1] + x[2] + x[3]);
|
||||
|
||||
test_min_reduce();
|
||||
vx_printf("min reduce result: %d\n", y);
|
||||
expected = MIN((int)unsigned_vector[0], MIN((int)unsigned_vector[1], MIN((int)unsigned_vector[2], (int)unsigned_vector[3])));
|
||||
vx_printf("expected: %d\n", expected);
|
||||
|
||||
test_max_reduce();
|
||||
vx_printf("max reduce result: %d\n", y);
|
||||
expected = MAX((int)unsigned_vector[0], MAX((int)unsigned_vector[1], MAX((int)unsigned_vector[2], (int)unsigned_vector[3])));
|
||||
vx_printf("expected: %d\n", expected);
|
||||
|
||||
test_minu_reduce();
|
||||
vx_printf("minu reduce result: %d\n", y);
|
||||
expected = MIN(unsigned_vector[0], MIN(unsigned_vector[1], MIN(unsigned_vector[2], unsigned_vector[3])));
|
||||
vx_printf("expected: %d\n", expected);
|
||||
|
||||
test_maxu_reduce();
|
||||
vx_printf("maxu reduce result: %d\n", y);
|
||||
expected = MAX(unsigned_vector[0], MAX(unsigned_vector[1], MAX(unsigned_vector[2], unsigned_vector[3])));
|
||||
vx_printf("expected: %d\n", expected);
|
||||
|
||||
test_and_reduce();
|
||||
vx_printf("and reduce result: %d\n", y);
|
||||
vx_printf("expected: %d\n", bit_vectors[0] & bit_vectors[1] & bit_vectors[2] & bit_vectors[3]);
|
||||
|
||||
|
||||
test_or_reduce();
|
||||
vx_printf("or reduce result: %d\n", y);
|
||||
vx_printf("expected: %d\n", bit_vectors[0] | bit_vectors[1] | bit_vectors[2] | bit_vectors[3]);
|
||||
|
||||
test_xor_reduce();
|
||||
vx_printf("xor reduce result: %d\n", y);
|
||||
vx_printf("expected: %d\n", bit_vectors[0] ^ bit_vectors[1] ^ bit_vectors[2] ^ bit_vectors[3]);
|
||||
|
||||
|
||||
return 0;
|
||||
}
|
||||
8
tests/kernel/tensor/Makefile
Normal file
8
tests/kernel/tensor/Makefile
Normal file
@@ -0,0 +1,8 @@
|
||||
PROJECT = tensor
|
||||
|
||||
SRCS = main.cpp
|
||||
DEPS = a_matrix.h
|
||||
DEPS += b_matrix.h
|
||||
DEPS += c_matrix.h
|
||||
|
||||
include ../common.mk
|
||||
95
tests/kernel/tensor/check_correctness.py
Normal file
95
tests/kernel/tensor/check_correctness.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import numpy as np
|
||||
import struct
|
||||
|
||||
A_array = np.zeros((16, 8))
|
||||
B_array = np.zeros((8, 16))
|
||||
C_array = np.zeros((16, 16))
|
||||
|
||||
file = input("simulator output filename: ")
|
||||
|
||||
def hex2float(float_hex_str):
|
||||
# print(float_hex_str.strip())
|
||||
return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0]
|
||||
|
||||
def C_index(threadgroup, thread, register):
|
||||
"""
|
||||
col = ((tg % 4) / 2) * 8;
|
||||
row = (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
|
||||
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
|
||||
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
|
||||
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
|
||||
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
|
||||
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
||||
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
||||
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
||||
"""
|
||||
|
||||
col = ((threadgroup % 4) // 2) * 8
|
||||
row = (threadgroup * 8) % 16
|
||||
row += (threadgroup // 4) * 4
|
||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||
offset = offsets[register-16]
|
||||
row += offset[0]
|
||||
col += offset[1]
|
||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||
thread_offset = thread_offsets[thread % 4]
|
||||
row += thread_offset[0]
|
||||
col += thread_offset[1]
|
||||
if C_array[row, col] != 0:
|
||||
print("bad")
|
||||
return (row, col)
|
||||
|
||||
|
||||
with open(file) as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if "warp" in line:
|
||||
a, b, c = line.split(',')
|
||||
_, a = a.split(' ')
|
||||
_, b = b.strip().split(' ')
|
||||
c, d = c.strip().split(':')
|
||||
_, c = c.split(' ')
|
||||
warp = int(a)
|
||||
thread = int(b)
|
||||
register = int(c)
|
||||
value = d.strip()
|
||||
|
||||
if warp != 0:
|
||||
continue
|
||||
if not (32 <= register < 32+24):
|
||||
continue
|
||||
|
||||
register = register - 32
|
||||
|
||||
# threadgroups 0, 4, 1, 5 have all elements of A
|
||||
threadgroup = thread // 4
|
||||
if threadgroup in [0, 4, 1, 5]:
|
||||
row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4
|
||||
if 0 <= register < 8:
|
||||
A_array[row, register] = hex2float(value)
|
||||
|
||||
if threadgroup in [0, 4, 2, 6]:
|
||||
col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4
|
||||
if 8 <= register < 16:
|
||||
B_array[register-8, col] = hex2float(value)
|
||||
|
||||
if 16 <= register < 24:
|
||||
# print(value)
|
||||
C_array[C_index(threadgroup, thread, register)] = hex2float(value)
|
||||
|
||||
|
||||
expected = np.load("abc.npz")
|
||||
expected_A = expected['A_array']
|
||||
expected_B = expected['B_array']
|
||||
expected_C = expected['C_array']
|
||||
expected_C = expected_C + expected_A @ expected_B
|
||||
print(expected_C[0:8, 0:8])
|
||||
print(C_array[0:8, 0:8])
|
||||
print((expected_C - C_array)[0:8, 0:8])
|
||||
|
||||
assert np.allclose(expected_A, A_array)
|
||||
assert np.allclose(expected_B, B_array)
|
||||
assert np.allclose(expected_C, C_array)
|
||||
32
tests/kernel/tensor/create_test_case.py
Normal file
32
tests/kernel/tensor/create_test_case.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
A_array = np.random.rand(16, 8)
|
||||
B_array = np.random.rand(8, 16)
|
||||
C_array = np.random.rand(16, 16)
|
||||
# A_array = np.zeros((16, 8))
|
||||
# B_array = np.zeros((8, 16))
|
||||
# A_array[0,:] = 1.0
|
||||
# B_array[:,4] = 1.0
|
||||
# C_array = np.zeros((16, 16))
|
||||
# for i in range(16):
|
||||
# for j in range(16):
|
||||
# C_array[i,j] = i * 16 + j
|
||||
|
||||
with open('a_matrix.h', 'w') as f:
|
||||
for i in range(A_array.shape[0]):
|
||||
for j in range(A_array.shape[1]):
|
||||
f.write(f'{A_array[i,j]}f, ')
|
||||
f.write('\n')
|
||||
|
||||
with open('b_matrix.h', 'w') as f:
|
||||
for i in range(B_array.shape[0]):
|
||||
for j in range(B_array.shape[1]):
|
||||
f.write(f'{B_array[i,j]}f, ')
|
||||
f.write('\n')
|
||||
|
||||
with open('c_matrix.h', 'w') as f:
|
||||
for i in range(C_array.shape[0]):
|
||||
for j in range(C_array.shape[1]):
|
||||
f.write(f'{C_array[i,j]}f, ')
|
||||
f.write('\n')
|
||||
|
||||
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)
|
||||
96
tests/kernel/tensor/main.cpp
Normal file
96
tests/kernel/tensor/main.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
#define RISCV_CUSTOM3 0x7B
|
||||
|
||||
#include <vx_intrinsics.h>
|
||||
#include <stdio.h>
|
||||
#include <vx_print.h>
|
||||
|
||||
inline void vx_wmma() {
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
|
||||
#include "test_data.h"
|
||||
|
||||
void vx_wmma_load() {
|
||||
int tid = vx_thread_id();
|
||||
int tg = tid / 4;
|
||||
|
||||
// load A
|
||||
int row = tid % 4;
|
||||
row += (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
asm volatile ("flw f0, %0" :: "m"(A[row][0]));
|
||||
asm volatile ("flw f1, %0" :: "m"(A[row][1]));
|
||||
asm volatile ("flw f2, %0" :: "m"(A[row][2]));
|
||||
asm volatile ("flw f3, %0" :: "m"(A[row][3]));
|
||||
asm volatile ("flw f4, %0" :: "m"(A[row][4]));
|
||||
asm volatile ("flw f5, %0" :: "m"(A[row][5]));
|
||||
asm volatile ("flw f6, %0" :: "m"(A[row][6]));
|
||||
asm volatile ("flw f7, %0" :: "m"(A[row][7]));
|
||||
|
||||
// load B
|
||||
int col = tid % 4;
|
||||
col += ((tg % 4) / 2) * 8;
|
||||
col += (tg / 4) * 4;
|
||||
|
||||
asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
|
||||
asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
|
||||
asm volatile ("flw f10, %0" :: "m"(B[2][col]));
|
||||
asm volatile ("flw f11, %0" :: "m"(B[3][col]));
|
||||
asm volatile ("flw f12, %0" :: "m"(B[4][col]));
|
||||
asm volatile ("flw f13, %0" :: "m"(B[5][col]));
|
||||
asm volatile ("flw f14, %0" :: "m"(B[6][col]));
|
||||
asm volatile ("flw f15, %0" :: "m"(B[7][col]));
|
||||
|
||||
// load C
|
||||
col = ((tg % 4) / 2) * 8;
|
||||
row = (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
row += (tid % 4) % 2;
|
||||
col += ((tid % 4) / 2) * 2;
|
||||
|
||||
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
|
||||
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
|
||||
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
|
||||
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
|
||||
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
|
||||
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
|
||||
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
|
||||
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
|
||||
}
|
||||
|
||||
float results[32*8];
|
||||
|
||||
void store_wmma_result() {
|
||||
int tid = vx_thread_id();
|
||||
|
||||
asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
|
||||
asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
|
||||
asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
|
||||
asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
|
||||
asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
|
||||
asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
|
||||
asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
|
||||
asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
|
||||
}
|
||||
|
||||
void print_wmma_result() {
|
||||
for (int tid = 0; tid < 32; tid += 1) {
|
||||
for (int reg = 0; reg < 8; reg += 1) {
|
||||
vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int main()
|
||||
{
|
||||
vx_tmc(-1);
|
||||
vx_wmma_load();
|
||||
vx_wmma();
|
||||
store_wmma_result();
|
||||
vx_tmc(1);
|
||||
// print_wmma_result();
|
||||
|
||||
return 0;
|
||||
}
|
||||
11
tests/kernel/tensor/test_data.h
Normal file
11
tests/kernel/tensor/test_data.h
Normal file
@@ -0,0 +1,11 @@
|
||||
float A[16][8] = {
|
||||
#include "a_matrix.h"
|
||||
};
|
||||
|
||||
float B[8][16] = {
|
||||
#include "b_matrix.h"
|
||||
};
|
||||
|
||||
float C[16][16] = {
|
||||
#include "c_matrix.h"
|
||||
};
|
||||
9
tests/regression/sgemm_tcore/Makefile
Normal file
9
tests/regression/sgemm_tcore/Makefile
Normal file
@@ -0,0 +1,9 @@
|
||||
PROJECT = sgemm_tcore
|
||||
|
||||
SRCS = main.cpp common.h
|
||||
|
||||
VX_SRCS = kernel.cpp
|
||||
|
||||
OPTS ?= -n16
|
||||
|
||||
include ../common.mk
|
||||
18
tests/regression/sgemm_tcore/common.h
Normal file
18
tests/regression/sgemm_tcore/common.h
Normal file
@@ -0,0 +1,18 @@
|
||||
#ifndef _COMMON_H_
|
||||
#define _COMMON_H_
|
||||
|
||||
#include <cstdint>
|
||||
|
||||
#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000
|
||||
#define DEV_SMEM_START_ADDR 0xff000000
|
||||
|
||||
typedef struct {
|
||||
uint32_t dim_m;
|
||||
uint32_t dim_n;
|
||||
uint32_t dim_k;
|
||||
uint64_t addr_a;
|
||||
uint64_t addr_b;
|
||||
uint64_t addr_c;
|
||||
} kernel_arg_t;
|
||||
|
||||
#endif
|
||||
285
tests/regression/sgemm_tcore/kernel.cpp
Normal file
285
tests/regression/sgemm_tcore/kernel.cpp
Normal file
@@ -0,0 +1,285 @@
|
||||
#define RISCV_CUSTOM3 0x7B
|
||||
|
||||
#include <stdint.h>
|
||||
#include <vx_intrinsics.h>
|
||||
#include <vx_print.h>
|
||||
#include <vx_spawn.h>
|
||||
#include "common.h"
|
||||
|
||||
#define BM 16
|
||||
#define BN 16
|
||||
#define BK 8
|
||||
|
||||
inline void vx_wmma() {
|
||||
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
|
||||
}
|
||||
|
||||
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) {
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
|
||||
// load A
|
||||
int row = tid % 4;
|
||||
row += (tg * 8) % 16;
|
||||
row += (tg / 4) * 4;
|
||||
|
||||
int smem_A_m = 32;
|
||||
int smem_A_n = 8;
|
||||
int smem_B_m = 8;
|
||||
int smem_B_n = 32;
|
||||
|
||||
int A_offset = (row + BM * warp_y) * smem_A_n;
|
||||
|
||||
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0]));
|
||||
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1]));
|
||||
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2]));
|
||||
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3]));
|
||||
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4]));
|
||||
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5]));
|
||||
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6]));
|
||||
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7]));
|
||||
|
||||
// load B
|
||||
int col = tid % 4;
|
||||
col += ((tg % 4) / 2) * 8;
|
||||
col += (tg / 4) * 4;
|
||||
|
||||
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
|
||||
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
|
||||
}
|
||||
|
||||
inline void initialize_C() {
|
||||
// initialize C to zeros
|
||||
asm volatile ("fmv.w.x f16, x0");
|
||||
asm volatile ("fmv.w.x f17, x0");
|
||||
asm volatile ("fmv.w.x f18, x0");
|
||||
asm volatile ("fmv.w.x f19, x0");
|
||||
asm volatile ("fmv.w.x f20, x0");
|
||||
asm volatile ("fmv.w.x f21, x0");
|
||||
asm volatile ("fmv.w.x f22, x0");
|
||||
asm volatile ("fmv.w.x f23, x0");
|
||||
}
|
||||
|
||||
inline void write_results(
|
||||
volatile float *local_warp_results,
|
||||
int thread_in_warp,
|
||||
int warp_x,
|
||||
int warp_y,
|
||||
int dim_m,
|
||||
int dim_n,
|
||||
float *C,
|
||||
int threadblock_id_x,
|
||||
int threadblock_id_y
|
||||
) {
|
||||
int tid = thread_in_warp;
|
||||
int tg = tid / 4;
|
||||
|
||||
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0]));
|
||||
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1]));
|
||||
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2]));
|
||||
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3]));
|
||||
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4]));
|
||||
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5]));
|
||||
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6]));
|
||||
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7]));
|
||||
|
||||
/*
|
||||
col = ((threadgroup % 4) // 2) * 8
|
||||
row = (threadgroup * 8) % 16
|
||||
row += (threadgroup // 4) * 4
|
||||
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
|
||||
offset = offsets[register-16]
|
||||
row += offset[0]
|
||||
col += offset[1]
|
||||
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
|
||||
thread_offset = thread_offsets[thread % 4]
|
||||
row += thread_offset[0]
|
||||
col += thread_offset[1]
|
||||
return (row, col)
|
||||
*/
|
||||
|
||||
int local_col = ((tg % 4) / 2) * 8;
|
||||
int local_row = (tg * 8) % 16;
|
||||
local_row += (tg / 4) * 4;
|
||||
|
||||
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2};
|
||||
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5};
|
||||
|
||||
// int thread_row_offsets[4] = {0, 1, 0, 1};
|
||||
// int thread_col_offsets[4] = {0, 0, 2, 2};
|
||||
int thread_row_offset = (tid % 4) % 2;
|
||||
int thread_col_offset = ((tid % 4) / 2) * 2;
|
||||
|
||||
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
|
||||
for (int i = 0; i < 8; i += 1) {
|
||||
int row_offset = ((i / 2) % 2) * 2;
|
||||
int col_offset = (i / 4) * 4 + i % 2;
|
||||
|
||||
int adjusted_local_row = local_row + thread_row_offset + row_offset;
|
||||
int adjusted_local_col = local_col + thread_col_offset + col_offset;
|
||||
|
||||
float v = local_warp_results[tid*8+i];
|
||||
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
|
||||
}
|
||||
}
|
||||
|
||||
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
|
||||
vx_fence();
|
||||
vx_barrier(barrier_id, count);
|
||||
}
|
||||
|
||||
void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
|
||||
const uint32_t tid_in_threadblock,
|
||||
const uint32_t threadblock_dim_x,
|
||||
const uint32_t threadblock_dim_y,
|
||||
const uint32_t threadblock_id_x,
|
||||
const uint32_t threadblock_id_y,
|
||||
const uint32_t threadblock_id,
|
||||
float *sharedmem_per_threadblock) {
|
||||
const float *A = (const float *)arg->addr_a;
|
||||
const float *B = (const float *)arg->addr_b;
|
||||
float *C = (float *)arg->addr_c;
|
||||
|
||||
const uint32_t dim_m = arg->dim_m;
|
||||
const uint32_t dim_n = arg->dim_n;
|
||||
const uint32_t dim_k = arg->dim_k;
|
||||
|
||||
// FIXME: Output block size is assumed to be square, i.e. BM == BN
|
||||
// const uint32_t BM = threadblock_dim_y;
|
||||
// const uint32_t BN = threadblock_dim_y;
|
||||
// const uint32_t BK = threadblock_dim_x;
|
||||
// constexpr uint32_t BM = 8;
|
||||
// constexpr uint32_t BN = 8;
|
||||
// constexpr uint32_t BK = 2;
|
||||
|
||||
const uint32_t warp_in_threadblock = tid_in_threadblock / 32;
|
||||
const uint32_t tid_in_warp = tid_in_threadblock % 32;
|
||||
const uint32_t warp_x = warp_in_threadblock % 2;
|
||||
const uint32_t warp_y = warp_in_threadblock / 2;
|
||||
|
||||
const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y;
|
||||
|
||||
// 32 * 8 block of A, being loaded by 4 warps
|
||||
const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK);
|
||||
const uint32_t local_a_col = tid_in_warp % BK;
|
||||
|
||||
// 8 * 32 block of B, being loaded by 4 warps
|
||||
// do a fat coalesced load
|
||||
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
|
||||
const uint32_t local_b_row = warp_in_threadblock;
|
||||
const uint32_t local_b_col = tid_in_warp;
|
||||
|
||||
|
||||
volatile float *local_a = sharedmem_per_threadblock;
|
||||
const size_t local_a_elems = (threadblock_dim_y * BK);
|
||||
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
|
||||
const size_t local_b_elems = (threadblock_dim_x * BK);
|
||||
volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN);
|
||||
|
||||
// clear out C
|
||||
initialize_C();
|
||||
|
||||
for (uint32_t k = 0; k < dim_k; k += BK) {
|
||||
// Data move from GMEM to SMEM
|
||||
//
|
||||
// Make sure global offset values for A and B are contiguous between
|
||||
// neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable)
|
||||
|
||||
// coalesced load from global memory -> unit-stride store into shared memory
|
||||
uint32_t global_a_offset =
|
||||
dim_k * (global_a_row + local_a_row) + (k + local_a_col);
|
||||
uint32_t shared_a_offset =
|
||||
BK * local_a_row + local_a_col;
|
||||
|
||||
local_a[shared_a_offset] = A[global_a_offset];
|
||||
|
||||
global_a_offset += dim_k * (BM / 4);
|
||||
shared_a_offset += BK * (BM / 4);
|
||||
|
||||
local_a[shared_a_offset] = A[global_a_offset];
|
||||
|
||||
uint32_t global_b_offset =
|
||||
dim_n * (k + local_b_row) + (global_b_col + local_b_col);
|
||||
uint32_t shared_b_offset =
|
||||
(BN * 2) * (local_b_row) + local_b_col;
|
||||
|
||||
local_b[shared_b_offset] = B[global_b_offset];
|
||||
|
||||
global_b_offset += dim_n * (BK / 2);
|
||||
shared_b_offset += (BN * 2) * (BK / 2);
|
||||
|
||||
local_b[shared_b_offset] = B[global_b_offset];
|
||||
|
||||
// want all 4 warps to reach barrier before moving on (just use barrier 0 for now)
|
||||
threadblock_barrier(tid_in_threadblock, 0, 4);
|
||||
|
||||
// perform wmma
|
||||
vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp);
|
||||
vx_wmma();
|
||||
|
||||
// same as above
|
||||
threadblock_barrier(tid_in_threadblock, 0, 4);
|
||||
}
|
||||
|
||||
write_results(
|
||||
local_warp_results,
|
||||
tid_in_warp,
|
||||
warp_x,
|
||||
warp_y,
|
||||
dim_m,
|
||||
dim_n,
|
||||
C,
|
||||
threadblock_id_x,
|
||||
threadblock_id_y
|
||||
);
|
||||
}
|
||||
|
||||
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
|
||||
// @perf: All threads are running these compute whose result is mostly same
|
||||
// across the threadblock
|
||||
const int NT = 32; // vx_num_threads();
|
||||
const int NW = 4; // vx_num_warps();
|
||||
const uint32_t threads_per_threadblock = NT * NW;
|
||||
|
||||
// matches 4 warp capacity
|
||||
const uint32_t threadblock_dim_x = 2 * BN;
|
||||
const uint32_t threadblock_dim_y = 2 * BM;
|
||||
const int threadblock_id = task_id / threads_per_threadblock;
|
||||
const int tid_in_threadblock = task_id % threads_per_threadblock;
|
||||
|
||||
const uint32_t dim_m = arg->dim_m;
|
||||
const uint32_t dim_n = arg->dim_n;
|
||||
const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x;
|
||||
const int threadblock_id_x = threadblock_id % dim_n_in_blocks;
|
||||
const int threadblock_id_y = threadblock_id / dim_n_in_blocks;
|
||||
|
||||
// "static" shared memory allocation. This would determine threadblock
|
||||
// occupancy of a single cluster
|
||||
// only 1 threadblock running at a time, so this is ok
|
||||
float *sharedmem_per_threadblock =
|
||||
(float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id;
|
||||
|
||||
thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x,
|
||||
threadblock_dim_y, threadblock_id_x, threadblock_id_y,
|
||||
threadblock_id, sharedmem_per_threadblock);
|
||||
}
|
||||
|
||||
int main() {
|
||||
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
|
||||
int NT = vx_num_threads();
|
||||
|
||||
// TODO: add support for edge-case (m, n not divisible by 16)
|
||||
const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN);
|
||||
|
||||
// for now, simplifying assumption of just 1 core
|
||||
// vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc.
|
||||
// we can thus treat 1 through NW as a single threadblock for the purposes of the optimization.
|
||||
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
|
||||
return 0;
|
||||
}
|
||||
270
tests/regression/sgemm_tcore/main.cpp
Normal file
270
tests/regression/sgemm_tcore/main.cpp
Normal file
@@ -0,0 +1,270 @@
|
||||
#include <iostream>
|
||||
#include <fstream>
|
||||
#include <unistd.h>
|
||||
#include <string.h>
|
||||
#include <vortex.h>
|
||||
#include <vector>
|
||||
#include "common.h"
|
||||
|
||||
#define RT_CHECK(_expr) \
|
||||
do { \
|
||||
int _ret = _expr; \
|
||||
if (0 == _ret) \
|
||||
break; \
|
||||
printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \
|
||||
cleanup(); \
|
||||
exit(-1); \
|
||||
} while (false)
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
const char* kernel_file = "kernel.bin";
|
||||
uint32_t count = 0;
|
||||
|
||||
std::vector<float> src_a_data;
|
||||
std::vector<float> src_b_data;
|
||||
std::vector<float> ref_data;
|
||||
|
||||
vx_device_h device = nullptr;
|
||||
std::vector<uint8_t> staging_buf;
|
||||
kernel_arg_t kernel_arg = {};
|
||||
|
||||
static void show_usage() {
|
||||
std::cout << "Vortex Test." << std::endl;
|
||||
std::cout << "Usage: [-k: kernel] [-n words] [-h: help]" << std::endl;
|
||||
}
|
||||
|
||||
static void parse_args(int argc, char **argv) {
|
||||
int c;
|
||||
while ((c = getopt(argc, argv, "n:k:h?")) != -1) {
|
||||
switch (c) {
|
||||
case 'n':
|
||||
count = atoi(optarg);
|
||||
break;
|
||||
case 'k':
|
||||
kernel_file = optarg;
|
||||
break;
|
||||
case 'h':
|
||||
case '?': {
|
||||
show_usage();
|
||||
exit(0);
|
||||
} break;
|
||||
default:
|
||||
show_usage();
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void cleanup() {
|
||||
if (device) {
|
||||
vx_mem_free(device, kernel_arg.addr_a);
|
||||
vx_mem_free(device, kernel_arg.addr_b);
|
||||
vx_mem_free(device, kernel_arg.addr_c);
|
||||
vx_dev_close(device);
|
||||
}
|
||||
}
|
||||
|
||||
void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
|
||||
src_a_data.resize(dim_m * dim_k);
|
||||
src_b_data.resize(dim_k * dim_n);
|
||||
|
||||
for (uint32_t i = 0; i < src_a_data.size(); ++i) {
|
||||
src_a_data[i] = static_cast<float>(i);
|
||||
std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl;
|
||||
}
|
||||
for (uint32_t i = 0; i < src_b_data.size(); ++i) {
|
||||
src_b_data[i] = static_cast<float>(i);
|
||||
std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
|
||||
ref_data.resize(dim_m * dim_n);
|
||||
|
||||
for (uint32_t i = 0; i < dim_m; ++i) {
|
||||
for (uint32_t j = 0; j < dim_n; ++j) {
|
||||
float ref = 0.0f;
|
||||
for (uint32_t k = 0; k < dim_k; ++k) {
|
||||
ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j];
|
||||
}
|
||||
ref_data.at(dim_n * i + j) = ref;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int run_test(const kernel_arg_t& kernel_arg,
|
||||
uint32_t buf_size,
|
||||
uint32_t dim_m, uint32_t dim_n) {
|
||||
// start device
|
||||
std::cout << "start device" << std::endl;
|
||||
RT_CHECK(vx_start(device));
|
||||
|
||||
// wait for completion
|
||||
std::cout << "wait for completion" << std::endl;
|
||||
RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT));
|
||||
|
||||
// download destination buffer
|
||||
std::cout << "download destination buffer" << std::endl;
|
||||
RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size));
|
||||
|
||||
// verify result
|
||||
std::cout << "verify result" << std::endl;
|
||||
{
|
||||
int errors = 0;
|
||||
auto buf_ptr = (float*)staging_buf.data();
|
||||
for (uint32_t i = 0; i < dim_m * dim_n; ++i) {
|
||||
float ref = ref_data.at(i);
|
||||
float cur = buf_ptr[i];
|
||||
if (std::abs((cur - ref) / ref) > 1e-6) {
|
||||
std::cout << "error at result #" << std::dec << i
|
||||
<< std::hex << ": actual=" << cur << ", expected=" << ref << std::endl;
|
||||
++errors;
|
||||
}
|
||||
}
|
||||
if (errors != 0) {
|
||||
std::cout << "Found " << std::dec << errors << " errors!" << std::endl;
|
||||
std::cout << "FAILED!" << std::endl;
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
// parse command arguments
|
||||
parse_args(argc, argv);
|
||||
|
||||
if (count == 0) {
|
||||
count = 1;
|
||||
}
|
||||
|
||||
std::srand(50);
|
||||
|
||||
// open device connection
|
||||
std::cout << "open device connection" << std::endl;
|
||||
RT_CHECK(vx_dev_open(&device));
|
||||
|
||||
// FIXME: hardcoded
|
||||
uint32_t dim_m = 64;
|
||||
uint32_t dim_n = 64;
|
||||
uint32_t dim_k = 64;
|
||||
|
||||
generate_source_matrix(dim_m, dim_n, dim_k);
|
||||
generate_reference_matmul(dim_m, dim_n, dim_k);
|
||||
|
||||
uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]);
|
||||
uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]);
|
||||
uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]);
|
||||
|
||||
std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl;
|
||||
|
||||
// upload program
|
||||
std::cout << "upload program" << std::endl;
|
||||
RT_CHECK(vx_upload_kernel_file(device, kernel_file));
|
||||
|
||||
// allocate device memory
|
||||
std::cout << "allocate device memory" << std::endl;
|
||||
RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a));
|
||||
RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b));
|
||||
RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c));
|
||||
|
||||
kernel_arg.dim_m = dim_m;
|
||||
kernel_arg.dim_n = dim_n;
|
||||
kernel_arg.dim_k = dim_k;
|
||||
|
||||
std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl;
|
||||
std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl;
|
||||
std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl;
|
||||
|
||||
// allocate staging buffer
|
||||
{
|
||||
std::cout << "allocate staging buffer" << std::endl;
|
||||
uint32_t staging_buf_size = std::max<uint32_t>(
|
||||
src_a_buf_size,
|
||||
std::max<uint32_t>(
|
||||
src_b_buf_size,
|
||||
std::max<uint32_t>(dst_buf_size, sizeof(kernel_arg_t))));
|
||||
staging_buf.resize(staging_buf_size);
|
||||
}
|
||||
|
||||
// upload kernel argument
|
||||
{
|
||||
std::cout << "upload kernel argument" << std::endl;
|
||||
auto buf_ptr = staging_buf.data();
|
||||
memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t));
|
||||
RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t)));
|
||||
|
||||
std::cout << "uploading argument buffer to device, device mem address="
|
||||
<< std::hex << KERNEL_ARG_DEV_MEM_ADDR << ", size=" << std::dec
|
||||
<< sizeof(kernel_arg_t) << " bytes\n";
|
||||
std::ofstream file("args.bin", std::ios::binary | std::ios::out);
|
||||
if (!file) {
|
||||
std::cerr << "error: failed to open args.bin for writing\n";
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
file.write(reinterpret_cast<char *>(staging_buf.data()),
|
||||
sizeof(kernel_arg_t));
|
||||
file.close();
|
||||
}
|
||||
|
||||
// upload source buffer
|
||||
{
|
||||
{
|
||||
auto buf_ptr = staging_buf.data();
|
||||
memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float));
|
||||
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(),
|
||||
src_a_buf_size));
|
||||
|
||||
std::cout << "uploading source A matrix to device, device mem address="
|
||||
<< std::hex << kernel_arg.addr_a << ", size=" << std::dec
|
||||
<< src_a_buf_size << " bytes\n";
|
||||
std::ofstream file("input.a.bin", std::ios::binary | std::ios::out);
|
||||
if (!file) {
|
||||
std::cerr << "error: failed to open args.bin for writing\n";
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
file.write(reinterpret_cast<char *>(buf_ptr), src_a_buf_size);
|
||||
file.close();
|
||||
}
|
||||
{
|
||||
auto buf_ptr = staging_buf.data();
|
||||
memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float));
|
||||
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(),
|
||||
src_b_buf_size));
|
||||
|
||||
std::cout << "uploading source B matrix to device, device mem address="
|
||||
<< std::hex << kernel_arg.addr_b << ", size=" << std::dec
|
||||
<< src_b_buf_size << " bytes\n";
|
||||
std::ofstream file("input.b.bin", std::ios::binary | std::ios::out);
|
||||
if (!file) {
|
||||
std::cerr << "error: failed to open args.bin for writing\n";
|
||||
exit(EXIT_FAILURE);
|
||||
}
|
||||
file.write(reinterpret_cast<char *>(buf_ptr), src_b_buf_size);
|
||||
file.close();
|
||||
}
|
||||
}
|
||||
|
||||
// clear destination buffer
|
||||
{
|
||||
std::cout << "clear destination buffer" << std::endl;
|
||||
auto buf_ptr = (int32_t*)staging_buf.data();
|
||||
for (uint32_t i = 0; i < ref_data.size(); ++i) {
|
||||
buf_ptr[i] = 0xdeadbeef;
|
||||
}
|
||||
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size));
|
||||
}
|
||||
|
||||
// run tests
|
||||
std::cout << "run tests" << std::endl;
|
||||
RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n));
|
||||
std::cout << "PASSED!" << std::endl;
|
||||
|
||||
// cleanup
|
||||
std::cout << "cleanup" << std::endl;
|
||||
cleanup();
|
||||
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user