This commit is contained in:
Richard Yan
2024-05-07 14:00:31 -07:00
41 changed files with 7007 additions and 25 deletions

View File

@@ -23,6 +23,9 @@
// #include "verilated_vpi.h"
#include "VX_config.h"
#include <bit>
#include "half.h"
extern "C" {
void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
@@ -51,6 +54,9 @@ extern "C" {
void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile);
void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile);
}
inline uint64_t nan_box(uint32_t value) {
@@ -338,3 +344,223 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
*result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags));
}
}
// A is M * K, B is K * M, C is M * M, D is M * M
#define M 4
#define K 2
// all row major
float c_A_tile[M][K];
float c_B_tile[K][M];
float c_C_tile[M][M];
float c_D_tile[M][M];
// code assumes that svBitVecVal is basically a uint32_t
static_assert(sizeof(svBitVecVal) == 4);
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
for (int j = 0; j < cols; j += 1) {
int index = i * cols + j;
svBitVecVal sv_val = sv_tile[index];
uint32_t c_val = sv_val;
float c_float;
memcpy(&c_float, &c_val, sizeof(c_float));
c_tile[index] = c_float;
// std::cout << c_float << " ";
}
// std::cout << std::endl;
}
}
void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
for (int j = 0; j < cols; j += 1) {
int index = i * cols + j;
svBitVecVal* sv_val = &sv_tile[index];
float c_float = c_tile[index];
memcpy(sv_val, &c_float, sizeof(c_float));
// std::cout << c_float << " ";
}
// std::cout << std::endl;
}
}
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) {
if (!enable) {
return;
}
// std::cout << "A: " << std::endl;
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
// std::cout << "B: " << std::endl;
fill_float_array(B_tile, &c_B_tile[0][0], K, M);
// std::cout << "C: " << std::endl;
fill_float_array(C_tile, &c_C_tile[0][0], M, M);
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
float accum = c_C_tile[i][j];
for (int k = 0; k < K; k += 1) {
accum += c_A_tile[i][k] * c_B_tile[k][j];
}
c_D_tile[i][j] = accum;
}
}
write_float_array(D_tile, &c_D_tile[0][0], M, M);
}
// 1 copy per warp
float A_tile_full[4][16][8];
float B_tile_full[4][8][16];
float C_tile_full[4][16][16];
float D_tile_full[4][16][16];
int steps[4];
void print_array(float* array, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
for (int j = 0; j < cols; j += 1) {
std::cout << array[i*cols+j] << " ";
}
std::cout << "\n";
}
std::cout << std::endl;
}
void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, const svBitVecVal* D_tile) {
// std::cout << "A: " << std::endl;
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
// std::cout << "B: " << std::endl;
fill_float_array(B_tile, &c_B_tile[0][0], K, M);
// std::cout << "C: " << std::endl;
fill_float_array(C_tile, &c_C_tile[0][0], M, M);
// for some reason this still holds onto old value? very strange
// std::cout << "D: " << std::endl;
fill_float_array(D_tile, &c_D_tile[0][0], M, M);
int octet_row_offset;
int octet_col_offset;
switch(octet) {
case 0:
octet_row_offset = 0;
octet_col_offset = 0;
break;
case 1:
octet_row_offset = 8;
octet_col_offset = 0;
break;
case 2:
octet_row_offset = 0;
octet_col_offset = 8;
break;
case 3:
octet_row_offset = 8;
octet_col_offset = 8;
break;
}
int step_row_offset;
int step_col_offset;
int step = (steps[wid] % 16) / 4;
int set = (steps[wid] / 16);
switch(step) {
case 0:
step_row_offset = 0;
step_col_offset = 0;
break;
case 1:
step_row_offset = 2;
step_col_offset = 0;
break;
case 2:
step_row_offset = 0;
step_col_offset = 4;
break;
case 3:
step_row_offset = 2;
step_col_offset = 4;
break;
}
if (steps[0] >= 48) {
// std::cout << "octet " << octet << " step " << steps[0] << "\n";
// print_array(&c_D_tile[0][0], 4, 4);
}
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+0] = c_D_tile[0][0];
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+1] = c_D_tile[0][1];
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+2] = c_D_tile[0][2];
D_tile_full[wid][octet_row_offset+step_row_offset+0][octet_col_offset+step_col_offset+3] = c_D_tile[0][3];
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+0] = c_D_tile[1][0];
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+1] = c_D_tile[1][1];
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+2] = c_D_tile[1][2];
D_tile_full[wid][octet_row_offset+step_row_offset+1][octet_col_offset+step_col_offset+3] = c_D_tile[1][3];
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+0] = c_D_tile[2][0];
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+1] = c_D_tile[2][1];
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+2] = c_D_tile[2][2];
D_tile_full[wid][octet_row_offset+step_row_offset+4][octet_col_offset+step_col_offset+3] = c_D_tile[2][3];
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+0] = c_D_tile[3][0];
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+1] = c_D_tile[3][1];
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+2] = c_D_tile[3][2];
D_tile_full[wid][octet_row_offset+step_row_offset+5][octet_col_offset+step_col_offset+3] = c_D_tile[3][3];
if (octet == 0 || octet == 1) {
octet_row_offset = octet * 8;
if (step == 0) {
step_row_offset = 0;
}
if (step == 1) {
step_row_offset = 2;
}
if (step == 0 || step == 1) {
A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+0] = c_A_tile[0][0];
A_tile_full[wid][octet_row_offset+step_row_offset+0][set*2+1] = c_A_tile[0][1];
A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+0] = c_A_tile[1][0];
A_tile_full[wid][octet_row_offset+step_row_offset+1][set*2+1] = c_A_tile[1][1];
A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+0] = c_A_tile[2][0];
A_tile_full[wid][octet_row_offset+step_row_offset+4][set*2+1] = c_A_tile[2][1];
A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+0] = c_A_tile[3][0];
A_tile_full[wid][octet_row_offset+step_row_offset+5][set*2+1] = c_A_tile[3][1];
}
}
if (octet == 0 || octet == 2) {
octet_col_offset = octet * 4;
if (step == 0) {
step_col_offset = 0;
}
else if (step == 2) {
step_col_offset = 4;
}
if (step == 0 || step == 2) {
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+0] = c_B_tile[0][0];
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+1] = c_B_tile[0][1];
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+2] = c_B_tile[0][2];
B_tile_full[wid][set*2+0][octet_col_offset+step_col_offset+3] = c_B_tile[0][3];
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+0] = c_B_tile[1][0];
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+1] = c_B_tile[1][1];
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+2] = c_B_tile[1][2];
B_tile_full[wid][set*2+1][octet_col_offset+step_col_offset+3] = c_B_tile[1][3];
}
}
steps[wid] += 1;
if (steps[wid] % 64 == 0) {
steps[wid] = 0;
std::cout << "warp " << wid << " finished wmma\n";
std::cout << "A tile" << "\n";
print_array(&A_tile_full[wid][0][0], 16, 8);
std::cout << "B tile" << "\n";
print_array(&B_tile_full[wid][0][0], 8, 16);
// std::cout << "C tile" << "\n";
// print_array(&C_tile_full[wid][0][0], 16, 16);
std::cout << "D tile" << "\n";
print_array(&D_tile_full[wid][0][0], 16, 16);
}
}

View File

@@ -44,4 +44,7 @@ import "DPI-C" function void dpi_feq(input logic enable, input int dst_fmt, inpu
import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile);
import "DPI-C" function void dpi_print_results(input int wid, input int octet, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, input bit[3:0][3:0][31:0] D_tile);
`endif

4018
hw/dpi/half.h Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -40,6 +40,10 @@
`define EXT_F_ENABLE
`endif
`ifndef EXT_T_DISABLE
`define EXT_T_ENABLE
`endif
`ifndef XLEN_32
`ifndef XLEN_64
`define XLEN_32
@@ -309,7 +313,7 @@
// Size of FPU Request Queue
`ifndef FPUQ_SIZE
`define FPUQ_SIZE (2 * (`NUM_THREADS / `NUM_FPU_LANES))
`define FPUQ_SIZE (8 * (`NUM_THREADS / `NUM_FPU_LANES))
`endif
// FNCP Latency
@@ -385,6 +389,11 @@
`define LATENCY_FCVT 5
`endif
// Tensor Core Latency
`ifndef LATENCY_HMMA
`define LATENCY_HMMA 8
`endif
// Icache Configurable Knobs //////////////////////////////////////////////////
// Cache Enable
@@ -613,6 +622,12 @@
`define EXT_F_ENABLED 0
`endif
`ifdef EXT_T_ENABLE
`define EXT_T_ENABLED 1
`else
`define EXT_T_ENABLED 0
`endif
`ifdef EXT_M_ENABLE
`define EXT_M_ENABLED 1
`else

View File

@@ -58,8 +58,9 @@
`define EX_LSU 1
`define EX_SFU 2
`define EX_FPU (`EX_SFU + `EXT_F_ENABLED)
`define EX_TENSOR (`EX_FPU + `EXT_T_ENABLED)
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED)
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED + `EXT_T_ENABLED)
`define EX_BITS `CLOG2(`NUM_EX_UNITS)
`define EX_WIDTH `UP(`EX_BITS)
@@ -115,7 +116,7 @@
///////////////////////////////////////////////////////////////////////////////
`define INST_OP_BITS 4
`define INST_MOD_BITS 3
`define INST_MOD_BITS 4
`define INST_FMT_BITS 2
///////////////////////////////////////////////////////////////////////////////
@@ -140,6 +141,7 @@
`define INST_ALU_IS_BR(mod) mod[0]
`define INST_ALU_IS_M(mod) mod[1]
`define INST_ALU_IS_W(mod) mod[2]
`define INST_ALU_IS_RED(mod) mod[3]
`define INST_BR_EQ 4'b0000
`define INST_BR_NE 4'b0010
@@ -176,6 +178,17 @@
`define INST_M_SIGNED_A(op) (op[1:0] != 1)
`define INST_M_IS_REM(op) op[1]
`define INST_RED_ADD 4'b0000
`define INST_RED_ADDU 4'b1000
`define INST_RED_MIN 4'b0001
`define INST_RED_MINU 4'b1001
`define INST_RED_MAX 4'b0010
`define INST_RED_MAXU 4'b1010
`define INST_RED_AND 4'b0011
`define INST_RED_OR 4'b0100
`define INST_RED_XOR 4'b0101
`define INST_RED_BITS 4
`define INST_FMT_B 3'b000
`define INST_FMT_H 3'b001
`define INST_FMT_W 3'b010
@@ -241,6 +254,8 @@
`define INST_SFU_IS_WCTL(op) (op <= 5)
`define INST_SFU_IS_CSR(op) (op >= 6 && op <= 8)
`define INST_TENSOR_HMMA 4'b0000
///////////////////////////////////////////////////////////////////////////////
// non-cacheable tag bits

View File

@@ -14,6 +14,16 @@
`ifndef VX_PLATFORM_VH
`define VX_PLATFORM_VH
// synthesis only
`ifndef SIMULATION
`define SYNTHESIS
`define NDEBUG
`define DPI_DISABLE
`else
`define SV_DPI
`endif
`ifdef SYNTHESIS
`define GPR_RESET
`define LSU_DUP_DISABLE
`define ICACHE_DISABLE

View File

@@ -33,7 +33,7 @@ module VX_alu_unit #(
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
localparam PID_WIDTH = `UP(PID_BITS);
localparam RSP_ARB_DATAW= `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + 1 + NUM_LANES * `XLEN + PID_WIDTH + 1 + 1;
localparam RSP_ARB_SIZE = 1 + `EXT_M_ENABLED;
localparam RSP_ARB_SIZE = 2 + `EXT_M_ENABLED;
localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
VX_execute_if #(
@@ -60,12 +60,13 @@ module VX_alu_unit #(
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
wire is_muldiv_op;
wire is_reduce_op;
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) int_execute_if();
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op;
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op && ~is_reduce_op;
assign int_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
@@ -86,6 +87,31 @@ module VX_alu_unit #(
.commit_if (int_commit_if)
);
assign is_reduce_op = `INST_ALU_IS_RED(execute_if[block_idx].data.op_mod);
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) red_execute_if();
assign red_execute_if.valid = execute_if[block_idx].valid && is_reduce_op;
assign red_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
.NUM_LANES (NUM_LANES)
) red_commit_if();
`RESET_RELAY(red_reset, reset);
VX_reduce_unit #(
.CORE_ID(CORE_ID),
.NUM_LANES(NUM_LANES)
) reduce_unit (
.clk(clk),
.reset(red_reset),
.execute_if(red_execute_if),
.commit_if(red_commit_if)
);
`ifdef EXT_M_ENABLE
assign is_muldiv_op = `INST_ALU_IS_M(execute_if[block_idx].data.op_mod);
@@ -96,7 +122,7 @@ module VX_alu_unit #(
.NUM_LANES (NUM_LANES)
) mdv_execute_if();
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op;
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op && ~is_reduce_op;
assign mdv_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
@@ -113,12 +139,12 @@ module VX_alu_unit #(
.commit_if (mdv_commit_if)
);
assign execute_if[block_idx].ready = is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready;
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : (is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready);
`else
assign is_muldiv_op = 0;
assign execute_if[block_idx].ready = int_execute_if.ready;
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : int_execute_if.ready;
`endif
@@ -135,19 +161,22 @@ module VX_alu_unit #(
`ifdef EXT_M_ENABLE
mdv_commit_if.valid,
`endif
int_commit_if.valid
int_commit_if.valid,
red_commit_if.valid
}),
.ready_in ({
`ifdef EXT_M_ENABLE
mdv_commit_if.ready,
`endif
int_commit_if.ready
int_commit_if.ready,
red_commit_if.ready
}),
.data_in ({
`ifdef EXT_M_ENABLE
mdv_commit_if.data,
`endif
int_commit_if.data
int_commit_if.data,
red_commit_if.data
}),
.data_out (commit_block_if[block_idx].data),
.valid_out (commit_block_if[block_idx].valid),

View File

@@ -28,6 +28,10 @@ module VX_commit import VX_gpu_pkg::*; #(
`endif
VX_commit_if.slave sfu_commit_if [`ISSUE_WIDTH],
`ifdef EXT_T_ENABLE
VX_commit_if.slave tensor_commit_if [`ISSUE_WIDTH],
`endif
// outputs
VX_writeback_if.master writeback_if [`ISSUE_WIDTH],
VX_commit_csr_if.master commit_csr_if,
@@ -49,6 +53,8 @@ module VX_commit import VX_gpu_pkg::*; #(
wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid;
wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask;
wire [`ISSUE_WIDTH-1:0] commit_eop;
wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel;
`UNUSED_VAR (commit_sel)
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
@@ -66,6 +72,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].valid,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].valid,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].valid,
`endif
alu_commit_if[i].valid,
lsu_commit_if[i].valid
@@ -74,6 +83,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].ready,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].ready,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].ready,
`endif
alu_commit_if[i].ready,
lsu_commit_if[i].ready
@@ -82,6 +94,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].data,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].data,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].data,
`endif
alu_commit_if[i].data,
lsu_commit_if[i].data
@@ -89,7 +104,7 @@ module VX_commit import VX_gpu_pkg::*; #(
.data_out (commit_if[i].data),
.valid_out (commit_if[i].valid),
.ready_out (commit_if[i].ready),
`UNUSED_PIN (sel_out)
.sel_out (commit_sel[i])
);
assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready;
@@ -158,7 +173,36 @@ module VX_commit import VX_gpu_pkg::*; #(
// Committed instructions
wire [`ISSUE_WIDTH-1:0] committed = commit_fire & commit_eop;
// temporary hack to not underflow the pending instructions buffer
// relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions,
// so probably want to change this at some point
// (i.e. pass a "don't count this towards pending instructions" signal down the pipeline)
// logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n;
wire [`ISSUE_WIDTH-1:0] final_hmma;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
// assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i];
// assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0);
// i suppose this is now a feature and not a bug
// if PC is 0, this means it is not final step of a wmma, shouldn't be committed
assign final_hmma[i] = (commit_if[i].data.PC != 32'b0);
end
/*
always @(posedge clk) begin
if (reset) begin
hmma_ctr <= '0;
end
else begin
hmma_ctr <= hmma_ctr_n;
end
end
*/
`else
assign final_hmma = '1;
`endif
wire [`ISSUE_WIDTH-1:0] committed = (commit_fire & commit_eop) & final_hmma;
VX_pipe_register #(
.DATAW (`ISSUE_WIDTH * (1 + `NW_WIDTH)),

View File

@@ -71,6 +71,10 @@ module VX_core import VX_gpu_pkg::*; #(
`ifdef EXT_F_ENABLE
VX_dispatch_if fpu_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if fpu_commit_if[`ISSUE_WIDTH]();
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if tensor_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if tensor_commit_if[`ISSUE_WIDTH]();
`endif
VX_dispatch_if sfu_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if sfu_commit_if[`ISSUE_WIDTH]();
@@ -178,6 +182,9 @@ module VX_core import VX_gpu_pkg::*; #(
.lsu_dispatch_if(lsu_dispatch_if),
`ifdef EXT_F_ENABLE
.fpu_dispatch_if(fpu_dispatch_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if(tensor_dispatch_if),
`endif
.sfu_dispatch_if(sfu_dispatch_if)
);
@@ -203,6 +210,10 @@ module VX_core import VX_gpu_pkg::*; #(
.fpu_dispatch_if(fpu_dispatch_if),
.fpu_commit_if (fpu_commit_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if (tensor_dispatch_if),
.tensor_commit_if (tensor_commit_if),
`endif
.commit_csr_if (commit_csr_if),
.sched_csr_if (sched_csr_if),
@@ -237,6 +248,9 @@ module VX_core import VX_gpu_pkg::*; #(
.fpu_commit_if (fpu_commit_if),
`endif
.sfu_commit_if (sfu_commit_if),
`ifdef EXT_T_ENABLE
.tensor_commit_if (tensor_commit_if),
`endif
.writeback_if (writeback_if),

View File

@@ -513,6 +513,40 @@ module VX_decode #(
default:;
endcase
end
`INST_EXT3: begin
ex_type = `EX_ALU;
op_mod[3] = 1;
`USED_IREG(rs1);
`USED_IREG(rd);
case (func7[5:0])
6'h0: begin
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
end
6'h1: begin
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
end
6'h2: begin
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
end
6'h3: begin
op_type = `INST_RED_AND;
end
6'h4: begin
op_type = `INST_RED_OR;
end
6'h5: begin
op_type = `INST_RED_XOR;
end
default:;
endcase
end
`ifdef EXT_T_ENABLE
`INST_EXT4: begin
ex_type = `EX_TENSOR;
op_type = `INST_TENSOR_HMMA;
end
`endif
default:;
endcase
end

View File

@@ -34,6 +34,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
`ifdef EXT_F_ENABLE
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
);
@@ -142,6 +145,35 @@ module VX_dispatch import VX_gpu_pkg::*; #(
end
`endif
// Tensor Core dispatch
`ifdef EXT_T_ENABLE
VX_operands_if tensor_operands_if[`ISSUE_WIDTH]();
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
assign tensor_operands_if[i].valid = operands_if[i].valid && (operands_if[i].data.ex_type == `EX_TENSOR);
assign tensor_operands_if[i].data = operands_if[i].data;
`RESET_RELAY (tensor_reset, reset);
VX_elastic_buffer #(
.DATAW (DATAW),
.SIZE (2),
.OUT_REG (2)
) tensor_buffer (
.clk (clk),
.reset (tensor_reset),
.valid_in (tensor_operands_if[i].valid),
.ready_in (tensor_operands_if[i].ready),
.data_in (`TO_DISPATCH_DATA(tensor_operands_if[i].data, last_active_tid[i])),
.data_out (tensor_dispatch_if[i].data),
.valid_out (tensor_dispatch_if[i].valid),
.ready_out (tensor_dispatch_if[i].ready)
);
end
`endif
// SFU dispatch
VX_operands_if sfu_operands_if[`ISSUE_WIDTH]();
@@ -174,6 +206,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
|| (lsu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_LSU))
`ifdef EXT_F_ENABLE
|| (fpu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_FPU))
`endif
`ifdef EXT_T_ENABLE
|| (tensor_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_TENSOR))
`endif
|| (sfu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_SFU));
end

View File

@@ -41,7 +41,7 @@ module VX_execute import VX_gpu_pkg::*; #(
VX_dispatch_if.slave fpu_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master fpu_commit_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.slave alu_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master alu_commit_if [`ISSUE_WIDTH],
VX_branch_ctl_if.master branch_ctl_if [`NUM_ALU_BLOCKS],
@@ -53,6 +53,11 @@ module VX_execute import VX_gpu_pkg::*; #(
VX_commit_if.master sfu_commit_if [`ISSUE_WIDTH],
VX_warp_ctl_if.master warp_ctl_if,
`ifdef EXT_T_ENABLE
VX_dispatch_if.slave tensor_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master tensor_commit_if [`ISSUE_WIDTH],
`endif
// simulation helper signals
output wire sim_ebreak,
@@ -135,6 +140,18 @@ module VX_execute import VX_gpu_pkg::*; #(
.acc_write_en (acc_write_en)
);
`ifdef EXT_T_ENABLE
VX_tensor_core #(
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(tensor_dispatch_if),
.commit_if(tensor_commit_if)
);
`endif
// simulation helper signal to get RISC-V tests Pass/Fail status
assign sim_ebreak = alu_dispatch_if[0].valid && alu_dispatch_if[0].ready
&& alu_dispatch_if[0].data.wis == 0

View File

@@ -85,6 +85,25 @@ module VX_fpu_unit import VX_fpu_pkg::*; #(
wire execute_fire = execute_if[block_idx].valid && execute_if[block_idx].ready;
wire fpu_rsp_fire = fpu_rsp_valid && fpu_rsp_ready;
reg [63:0] perf_execute_fires;
reg [63:0] perf_execute_valids;
reg [63:0] perf_fpu_req_valids;
reg [63:0] perf_fpu_req_readys;
always @(posedge clk) begin
if (reset) begin
perf_execute_fires <= '0;
perf_execute_valids <= '0;
perf_fpu_req_valids <= '0;
perf_fpu_req_readys <= '0;
end else begin
perf_execute_fires <= perf_execute_fires + 64'(execute_fire);
perf_execute_valids <= perf_execute_valids + 64'(execute_if[block_idx].valid);
perf_fpu_req_valids <= perf_fpu_req_valids + 64'(fpu_req_valid);
perf_fpu_req_readys <= perf_fpu_req_readys + 64'(fpu_req_ready);
end
end
VX_index_buffer #(
.DATAW (`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + PID_WIDTH + 1 + 1),
.SIZE (`FPUQ_SIZE)

View File

@@ -36,6 +36,8 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
assign decode_if.ready = ibuf_ready_in[decode_isw];
VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]();
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_elastic_buffer #(
.DATAW (DATAW),
@@ -62,13 +64,29 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
decode_if.data.rs1,
decode_if.data.rs2,
decode_if.data.rs3}),
.data_out(ibuffer_if[i].data),
.valid_out (ibuffer_if[i].valid),
.ready_out(ibuffer_if[i].ready)
);
.data_out (uop_sequencer_if[i].data),
.valid_out (uop_sequencer_if[i].valid),
.ready_out (uop_sequencer_if[i].ready)
);
`ifndef L1_ENABLE
assign decode_if.ibuf_pop[i] = ibuffer_if[i].valid && ibuffer_if[i].ready;
assign decode_if.ibuf_pop[i] = uop_sequencer_if[i].valid && uop_sequencer_if[i].ready;
`endif
// tensor-core operation is controlled by a single macro-instruction at
// the ISA; internally, the uop_sequencer blitzs micro-ops (counterpart
// to Volta SASS set/step instructions) into the ibuffer upon encountering
// this macro-instruction. this becomes a pass-through for non-tensorcore
// instructions.
VX_uop_sequencer uop_sequencer (
.clk(clk),
.reset(reset),
.uop_sequencer_if(uop_sequencer_if[i]),
.ibuffer_if(ibuffer_if[i])
);
end
endmodule

View File

@@ -33,6 +33,9 @@ module VX_issue #(
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
`ifdef EXT_F_ENABLE
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
);
@@ -93,7 +96,6 @@ module VX_issue #(
.clk (clk),
.reset (dispatch_reset),
`ifdef PERF_ENABLE
`UNUSED_PIN (perf_stalls),
.perf_stalls (perf_issue_if.dispatch_stalls),
.perf_valids (perf_issue_if.dispatch_valids),
.perf_fires (perf_issue_if.dispatch_fires),
@@ -104,6 +106,9 @@ module VX_issue #(
.lsu_dispatch_if(lsu_dispatch_if),
`ifdef EXT_F_ENABLE
.fpu_dispatch_if(fpu_dispatch_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if(tensor_dispatch_if),
`endif
.sfu_dispatch_if(sfu_dispatch_if)
);

View File

@@ -294,6 +294,27 @@ module VX_operands import VX_gpu_pkg::*; #(
.raddr (gpr_rd_addr),
.rdata (gpr_rd_data[j])
);
// blast read register file because printf is slowge
logic [31:0] cycle, cycle_n;
assign cycle_n = cycle + 32'd1;
always @(posedge clk) begin
if (reset) begin
cycle <= '0;
end
else begin
cycle <= cycle_n;
end
// if (cycle == 32'd25000) begin
// for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin
// integer warp = i * ISSUE_RATIO + (k / `NUM_REGS);
// integer thread = j;
// integer register = k % `NUM_REGS;
// $display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]);
// end
// end
end
end
end

View File

@@ -31,9 +31,6 @@ module VX_operands_dup import VX_gpu_pkg::*; #(
localparam RAM_ADDRW = `LOG2UP(`NUM_REGS * ISSUE_RATIO);
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
// NOTE(hansung): toggle_buffer is 1-reg pipe without flow, halving
// throughput. Wouldn't this cap overall IPC? Or OK as long as
// ISSUE_WIDTH > 1?
VX_stream_buffer #(
.DATAW (DATAW)
) staging_buffer (

View File

@@ -0,0 +1,285 @@
`include "VX_define.vh"
`include "VX_platform.vh"
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
`include "VX_platform.vh"
module VX_reduce_ext #(
parameter DATAW_IN = 1,
parameter DATAW_OUT = DATAW_IN,
parameter N = 1
) (
input wire [N-1:0][DATAW_IN-1:0] data_in,
input wire [N-1:0] mask,
input wire [`INST_RED_BITS-1:0] op_type,
output wire [DATAW_OUT-1:0] data_out
);
// recursive binary reduction
if (N == 1) begin
`UNUSED_VAR(op_type)
`UNUSED_VAR(mask)
assign data_out = DATAW_OUT'(data_in[0]);
end else begin
localparam int N_A = N / 2;
localparam int N_B = N - N_A;
wire [N_A-1:0][DATAW_IN-1:0] in_A;
wire [N_B-1:0][DATAW_IN-1:0] in_B;
wire [DATAW_OUT-1:0] out_A, out_B;
wire [N_A-1:0] mask_A;
wire [N_B-1:0] mask_B;
wire any_A, any_B;
for (genvar i = 0; i < N_A; i++) begin
assign in_A[i] = data_in[i];
end
for (genvar i = 0; i < N_B; i++) begin
assign in_B[i] = data_in[N_A + i];
end
assign mask_A = mask[N_A-1:0];
assign mask_B = mask[N-1:N_A];
assign any_A = |mask_A;
assign any_B = |mask_B;
VX_reduce_ext #(
.DATAW_IN (DATAW_IN),
.DATAW_OUT (DATAW_OUT),
.N (N_A)
) reduce_A (
.data_in (in_A),
.mask(mask_A),
.op_type(op_type),
.data_out (out_A)
);
VX_reduce_ext #(
.DATAW_IN (DATAW_IN),
.DATAW_OUT (DATAW_OUT),
.N (N_B)
) reduce_B (
.data_in (in_B),
.mask(mask_B),
.op_type(op_type),
.data_out (out_B)
);
logic [DATAW_OUT-1:0] _data_out;
always @(*) begin
case (op_type)
`INST_RED_ADD: _data_out = out_A + out_B;
`INST_RED_ADDU: _data_out = out_A + out_B;
`INST_RED_MIN: _data_out = ($signed(out_A) < $signed(out_B)) ? out_A : out_B;
`INST_RED_MINU: _data_out = (out_A < out_B) ? out_A : out_B;
`INST_RED_MAX: _data_out = ($signed(out_A) < $signed(out_B)) ? out_B : out_A;
`INST_RED_MAXU: _data_out = (out_A < out_B) ? out_B : out_A;
`INST_RED_AND: _data_out = out_A & out_B;
`INST_RED_OR: _data_out = out_A | out_B;
`INST_RED_XOR: _data_out = out_A ^ out_B;
default: _data_out = out_A;
endcase
end
// if both sides are masked out, then it doesn't matter what we output
assign data_out = (any_A && any_B) ? _data_out : (any_A ? out_A : out_B);
end
endmodule
module VX_reduce_unit #(
parameter CORE_ID = 0,
parameter NUM_LANES = 1
) (
input wire clk,
input wire reset,
VX_execute_if.slave execute_if,
VX_commit_if.master commit_if
);
`UNUSED_PARAM(CORE_ID)
localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES;
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
localparam PID_WIDTH = `UP(PID_BITS);
logic [`XLEN-1:0] accumulator, accumulator_n, reduced_accumulator;
wire [(NUM_LANES * `XLEN)-1:0] broadcasted_accumulator;
assign broadcasted_accumulator = {NUM_LANES{accumulator}};
wire eop;
wire [NUM_LANES-1:0][`XLEN-1:0] data_in;
wire [`XLEN-1:0] data_out;
assign eop = execute_if.data.eop;
assign data_in = execute_if.data.rs1_data;
logic execute_if_valid;
logic execute_if_ready;
logic commit_if_valid;
logic commit_if_ready;
wire execute_if_fire;
wire commit_if_fire;
assign execute_if_valid = execute_if.valid;
assign execute_if.ready = execute_if_ready;
assign execute_if_fire = execute_if.ready && execute_if.valid;
assign commit_if_fire = commit_if_ready && commit_if_valid;
logic store_tmask_pid;
logic read_tmask_pid;
wire [PID_WIDTH-1:0] stored_pid;
wire [NUM_LANES-1:0] stored_tmask;
wire stored_sop;
wire stored_eop;
logic [PID_BITS:0] size;
logic [PID_BITS:0] size_n;
// 1. idle state - wait for execute_if to be valid
// 2. accumulate - continue accumulating until eop, store packet id + thread mask for broadcast phase
// 3. broadcast - broadcast to rds
localparam IDLE = 2'b00;
localparam ACCUMULATE = 2'b01;
localparam BROADCAST = 2'b10;
localparam FINISH = 2'b11;
logic [1:0] state, state_n;
always @(*) begin
state_n = state;
accumulator_n = accumulator;
execute_if_ready = '0;
commit_if_valid = '0;
store_tmask_pid = '0;
read_tmask_pid = '0;
size_n = store_tmask_pid ? size + 1 : (read_tmask_pid ? size - 1 : size);
case (state)
IDLE: begin
if (execute_if_valid) begin
accumulator_n = data_out;
store_tmask_pid = '1;
if (eop) begin
state_n = BROADCAST;
end
else begin
execute_if_ready = '1;
state_n = ACCUMULATE;
end
end
end
ACCUMULATE: begin
execute_if_ready = '1;
if (eop) begin
execute_if_ready = '0;
state_n = BROADCAST;
end
if (eop || execute_if_fire) begin
accumulator_n = reduced_accumulator;
store_tmask_pid = '1;
end
end
BROADCAST: begin
execute_if_ready = '0;
commit_if_valid = '1;
if (commit_if_fire) begin
read_tmask_pid = '1;
end
if (size_n == '0) begin
state_n = FINISH;
end
end
FINISH: begin
execute_if_ready = '1;
if (execute_if_fire) begin
state_n = IDLE;
end
end
endcase
end
always @(posedge clk) begin
if (reset) begin
accumulator <= '0;
state <= IDLE;
size <= '0;
end
else begin
accumulator <= accumulator_n;
state <= state_n;
size <= size_n;
end
end
VX_reduce_ext #(
.DATAW_IN(`XLEN),
.N(NUM_LANES)
) reducer (
.data_in(data_in),
.mask(execute_if.data.tmask),
.op_type(execute_if.data.op_type),
.data_out(data_out)
);
VX_reduce_ext #(
.DATAW_IN(`XLEN),
.N(2)
) accumulator_reducer (
.data_in({accumulator, data_out}),
.mask(2'b11),
.op_type(execute_if.data.op_type),
.data_out(reduced_accumulator)
);
VX_elastic_buffer #(
.DATAW(NUM_LANES + PID_WIDTH + 1 + 1),
.SIZE(NUM_PACKETS)
) tmask_pid_store (
.clk(clk),
.reset(reset),
.valid_in(store_tmask_pid),
`UNUSED_PIN(ready_in),
.data_in({execute_if.data.tmask, execute_if.data.pid, execute_if.data.sop, execute_if.data.eop}),
.data_out({stored_tmask, stored_pid, stored_sop, stored_eop}),
.ready_out(read_tmask_pid),
`UNUSED_PIN(valid_out)
);
VX_elastic_buffer #(
.DATAW(`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + 1 + `NR_BITS + (`XLEN * NUM_LANES) + PID_WIDTH + 1 + 1)
) output_buffer (
.clk(clk),
.reset(reset),
.valid_in(commit_if_valid),
.ready_in(commit_if_ready),
.data_in({execute_if.data.uuid, execute_if.data.wid, stored_tmask, execute_if.data.PC, execute_if.data.wb, execute_if.data.rd, broadcasted_accumulator, stored_pid, stored_sop, stored_eop}),
.data_out({commit_if.data.uuid, commit_if.data.wid, commit_if.data.tmask, commit_if.data.PC, commit_if.data.wb, commit_if.data.rd, commit_if.data.data, commit_if.data.pid, commit_if.data.sop, commit_if.data.eop}),
.ready_out(commit_if.ready),
.valid_out(commit_if.valid)
);
endmodule

View File

@@ -0,0 +1,316 @@
`include "VX_fpu_define.vh"
module VX_tensor_core #(
) (
input clk,
input reset,
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_tensor_core_warp #(
.ISW(i)
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(dispatch_if[i]),
.commit_if(commit_if[i])
);
end
endmodule
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
parameter ISW
) (
input clk,
input reset,
VX_dispatch_if.slave dispatch_if,
VX_commit_if.master commit_if
);
wire [1:0] step = 2'(dispatch_if.data.op_type);
logic [3:0] octet_results_valid;
logic [3:0] octet_results_ready;
logic [3:0] octet_operands_ready;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
assign dispatch_if.ready = &octet_operands_ready;
for (genvar i = 0; i < 4/*octets*/; ++i) begin
// lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
};
wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
};
wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
logic result_valid;
logic result_ready;
VX_tensor_octet #(
.ISW(ISW),
.OCTET(i)
) octet (
.clk(clk),
.reset(reset),
.A_in(octet_A),
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(dispatch_if.valid),
.operands_ready(octet_operands_ready[i]),
.step(step),
.D_out(octet_D),
.result_valid(result_valid),
.result_ready(result_ready)
);
// these should always be in lockstep
assign octet_results_valid[i] = result_valid;
assign result_ready = octet_results_ready[i];
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
// to the octet can only do 8 fp32 writeback at a time; so we need to
// split writeback over two cycles
assign wb_data_0[4*i+0] = octet_D[0][0];
assign wb_data_0[4*i+1] = octet_D[1][0];
assign wb_data_0[4*i+2] = octet_D[0][2];
assign wb_data_0[4*i+3] = octet_D[1][2];
assign wb_data_1[4*i+0] = octet_D[0][1];
assign wb_data_1[4*i+1] = octet_D[1][1];
assign wb_data_1[4*i+2] = octet_D[0][3];
assign wb_data_1[4*i+3] = octet_D[1][3];
assign wb_data_0[4*i+16+0] = octet_D[2][0];
assign wb_data_0[4*i+16+1] = octet_D[3][0];
assign wb_data_0[4*i+16+2] = octet_D[2][2];
assign wb_data_0[4*i+16+3] = octet_D[3][2];
assign wb_data_1[4*i+16+0] = octet_D[2][1];
assign wb_data_1[4*i+16+1] = octet_D[3][1];
assign wb_data_1[4*i+16+2] = octet_D[2][3];
assign wb_data_1[4*i+16+3] = octet_D[3][3];
end
/* commit_if.data_t parts that we need to keep around:
- uuid
- wid
- tmask
- PC
- wb
- rd
*/
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [DATAW-1:0] dispatch_if_data_enq = {
dispatch_if.data.uuid,
wis_to_wid(dispatch_if.data.wis, ISW),
dispatch_if.data.tmask,
dispatch_if.data.PC,
dispatch_if.data.wb,
dispatch_if.data.rd
};
wire [DATAW-1:0] dispatch_if_data_deq;
// this is probably a little oversized
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(16)
) pending_uops (
.clk(clk),
.reset(reset),
.push(dispatch_if_fire),
.pop(commit_if_fire),
.data_in(dispatch_if_data_enq),
.data_out(dispatch_if_data_deq),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
logic subcommit, subcommit_n;
wire all_valid = (& octet_results_valid);
assign commit_if.valid = all_valid;
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq, /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */
1'b1, /* sop */
1'b1 /* eop */
};
assign commit_if.data = commit_if_data;
always @(*) begin
subcommit_n = commit_if_fire ? ~subcommit : subcommit;
if (commit_if_fire && subcommit == 1'b1) begin
octet_results_ready = '1;
end
else begin
octet_results_ready = '0;
end
end
always @(posedge clk) begin
if (reset) begin
subcommit <= '0;
end
else begin
subcommit <= subcommit_n;
end
end
endmodule
module VX_tensor_octet #(
parameter ISW,
parameter OCTET
) (
input clk,
input reset,
input [7:0][31:0] A_in,
input [7:0][31:0] B_in,
input [7:0][31:0] C_in,
input operands_valid, // we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [1:0] step,
output [3:0][3:0][31:0] D_out,
output result_valid,
input result_ready
);
// 512 bits/octet * 4 octets per warp
logic [3:0][31:0] A_buffer, A_buffer_n;
logic [3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight
// from operand bus) unlike the real tensor core.
// the banks are only 32 bit rather than 64 bit (a pair of fp32 regs).
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
always @(*) begin
// note that not all lanes participate at every step
case (step)
2'b00: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0];
end
2'b01: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[3:0];
end
2'b10: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[7:4];
end
2'b11: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[7:4];
end
endcase
C_half = C_in;
end
logic substep;
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
always @(*) begin
A_buffer_n = A_buffer;
B_buffer_n = B_buffer;
C_buffer_n = C_buffer;
if (substep == 1'b0) begin
A_buffer_n = A_half;
B_buffer_n = B_half;
C_buffer_n = C_half;
end
end
always @(posedge clk) begin
if (reset) begin
A_buffer <= '0;
B_buffer <= '0;
C_buffer <= '0;
substep <= '0;
end
else begin
A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substep <= substep_n;
end
end
wire stall = result_valid && ~result_ready;
assign operands_ready = ~stall;
// A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] },
{ A_half[2], A_buffer[2] },
{ A_half[1], A_buffer[1] },
{ A_half[0], A_buffer[0] }
};
// B is 2x4 fp32 matrix
wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
always @(*) begin
C_tile = {
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
};
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
VX_tensor_dpu #(
.ISW(ISW),
.OCTET(OCTET)
) dpu (
.clk(clk),
.reset(reset),
.stall(stall),
.valid_in(do_hmma),
.A_tile(A_tile),
.B_tile(B_tile),
.C_tile(C_tile),
.valid_out(result_valid),
.D_tile(D_out)
);
endmodule

View File

@@ -0,0 +1,97 @@
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
HMMA_SET0_STEP0_0: begin
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
end
HMMA_SET0_STEP0_1: begin
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
end
HMMA_SET0_STEP1_0: begin
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
end
HMMA_SET0_STEP1_1: begin
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
end
HMMA_SET0_STEP2_0: begin
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
end
HMMA_SET0_STEP2_1: begin
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
end
HMMA_SET0_STEP3_0: begin
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
end
HMMA_SET0_STEP3_1: begin
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
end
HMMA_SET1_STEP0_0: begin
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
end
HMMA_SET1_STEP0_1: begin
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
end
HMMA_SET1_STEP1_0: begin
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
end
HMMA_SET1_STEP1_1: begin
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
end
HMMA_SET1_STEP2_0: begin
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
end
HMMA_SET1_STEP2_1: begin
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
end
HMMA_SET1_STEP3_0: begin
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
end
HMMA_SET1_STEP3_1: begin
uop = {NEXT, HMMA_SET2_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
end
HMMA_SET2_STEP0_0: begin
uop = {NEXT, HMMA_SET2_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(4), `FREG(12), `FREG(16)};
end
HMMA_SET2_STEP0_1: begin
uop = {NEXT, HMMA_SET2_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(5), `FREG(13), `FREG(17)};
end
HMMA_SET2_STEP1_0: begin
uop = {NEXT, HMMA_SET2_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(4), `FREG(12), `FREG(18)};
end
HMMA_SET2_STEP1_1: begin
uop = {NEXT, HMMA_SET2_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(5), `FREG(13), `FREG(19)};
end
HMMA_SET2_STEP2_0: begin
uop = {NEXT, HMMA_SET2_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(4), `FREG(12), `FREG(20)};
end
HMMA_SET2_STEP2_1: begin
uop = {NEXT, HMMA_SET2_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(5), `FREG(13), `FREG(21)};
end
HMMA_SET2_STEP3_0: begin
uop = {NEXT, HMMA_SET2_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(4), `FREG(12), `FREG(22)};
end
HMMA_SET2_STEP3_1: begin
uop = {NEXT, HMMA_SET3_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(5), `FREG(13), `FREG(23)};
end
HMMA_SET3_STEP0_0: begin
uop = {NEXT, HMMA_SET3_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(6), `FREG(14), `FREG(16)};
end
HMMA_SET3_STEP0_1: begin
uop = {NEXT, HMMA_SET3_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(7), `FREG(15), `FREG(17)};
end
HMMA_SET3_STEP1_0: begin
uop = {NEXT, HMMA_SET3_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(6), `FREG(14), `FREG(18)};
end
HMMA_SET3_STEP1_1: begin
uop = {NEXT, HMMA_SET3_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(7), `FREG(15), `FREG(19)};
end
HMMA_SET3_STEP2_0: begin
uop = {NEXT, HMMA_SET3_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(6), `FREG(14), `FREG(20)};
end
HMMA_SET3_STEP2_1: begin
uop = {NEXT, HMMA_SET3_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(7), `FREG(15), `FREG(21)};
end
HMMA_SET3_STEP3_0: begin
uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)};
end
HMMA_SET3_STEP3_1: begin
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(7), `FREG(15), `FREG(23)};
end

View File

@@ -0,0 +1,164 @@
`include "VX_define.vh"
`define FREG(x) {1'b1, `NRI_BITS'(x)}
module VX_uop_sequencer import VX_gpu_pkg::*; (
input clk,
input reset,
VX_ibuffer_if.slave uop_sequencer_if,
VX_ibuffer_if.master ibuffer_if
);
`ifdef EXT_T_ENABLE
localparam UOP_TABLE_SIZE = 64;
localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE);
localparam NEXT = 2'b00;
localparam FINISH = 2'b01;
localparam UBR_BITS = 2;
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4);
localparam IBUFFER_IF_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `XLEN + 1 + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + `XLEN + (`NR_BITS * 4);
logic [UOP_TABLE_WIDTH-1:0] uop;
// reserve space at start of table for more uop sequences
localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0);
localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8);
localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9);
localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10);
localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11);
localparam HMMA_SET0_STEP2_1 = UPC_BITS'(12);
localparam HMMA_SET0_STEP3_0 = UPC_BITS'(13);
localparam HMMA_SET0_STEP3_1 = UPC_BITS'(14);
localparam HMMA_SET1_STEP0_0 = UPC_BITS'(15);
localparam HMMA_SET1_STEP0_1 = UPC_BITS'(16);
localparam HMMA_SET1_STEP1_0 = UPC_BITS'(17);
localparam HMMA_SET1_STEP1_1 = UPC_BITS'(18);
localparam HMMA_SET1_STEP2_0 = UPC_BITS'(19);
localparam HMMA_SET1_STEP2_1 = UPC_BITS'(20);
localparam HMMA_SET1_STEP3_0 = UPC_BITS'(21);
localparam HMMA_SET1_STEP3_1 = UPC_BITS'(22);
localparam HMMA_SET2_STEP0_0 = UPC_BITS'(23);
localparam HMMA_SET2_STEP0_1 = UPC_BITS'(24);
localparam HMMA_SET2_STEP1_0 = UPC_BITS'(25);
localparam HMMA_SET2_STEP1_1 = UPC_BITS'(26);
localparam HMMA_SET2_STEP2_0 = UPC_BITS'(27);
localparam HMMA_SET2_STEP2_1 = UPC_BITS'(28);
localparam HMMA_SET2_STEP3_0 = UPC_BITS'(29);
localparam HMMA_SET2_STEP3_1 = UPC_BITS'(30);
localparam HMMA_SET3_STEP0_0 = UPC_BITS'(31);
localparam HMMA_SET3_STEP0_1 = UPC_BITS'(32);
localparam HMMA_SET3_STEP1_0 = UPC_BITS'(33);
localparam HMMA_SET3_STEP1_1 = UPC_BITS'(34);
localparam HMMA_SET3_STEP2_0 = UPC_BITS'(35);
localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36);
localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37);
localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38);
// register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C
logic [UPC_BITS-1:0] upc, upc_r, upc_n;
always @(*) begin
case (upc)
`include "VX_tensor_ucode.vh"
default: begin
uop = '0;
end
endcase
end
wire [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS];
wire [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS];
logic use_uop, use_uop_1d;
wire uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready;
wire uop_start = ~use_uop_1d && use_uop;
wire uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready;
// merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint,
// but conceptually they should be linked
always @(*) begin
use_uop = uop_sequencer_if.valid && uop_sequencer_if.data.ex_type == `EX_BITS'(`EX_TENSOR);
if (uop_start) begin
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
upc_n = UPC_BITS'(uop_sequencer_if.data.op_type);
end
else begin
upc_n = upc;
end
if (uop_fire) begin
upc_n = next_upc;
end
end
always @(*) begin
if (uop_start) begin
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
upc = UPC_BITS'(uop_sequencer_if.data.op_type);
end
else begin
upc = upc_r;
end
end
// copy UUID, wis, tmask from microcoded instruction
wire [IBUFFER_IF_DATAW-1:0] ibuffer_output = {
uop_sequencer_if.data.uuid,
uop_sequencer_if.data.wis,
uop_sequencer_if.data.tmask,
uop[UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS-1:0]
};
// passthrough when !use_uop
assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid;
assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready;
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
always @(posedge clk) begin
if (uop_start) begin
// $display("UOP start @ %t", $time);
// $display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready);
end
if (uop_fire) begin
// $display("UOP fire @ %t", $time);
end
if (uop_finish) begin
// $display("UOP finish @ %t", $time);
end
if (reset) begin
upc_r <= '0;
use_uop_1d <= '0;
end
else begin
upc_r <= upc_n;
if (uop_finish) begin
use_uop_1d <= 1'b0; // allow microcoded instructions to start immediately after eachother
end
else begin
use_uop_1d <= use_uop;
end
end
end
`else
`UNUSED_VAR(clk);
`UNUSED_VAR(reset);
assign ibuffer_if.valid = uop_sequencer_if.valid;
assign uop_sequencer_if.ready = ibuffer_if.ready;
assign ibuffer_if.data = uop_sequencer_if.data;
`endif
endmodule

View File

@@ -0,0 +1,85 @@
num_sets = 4
num_steps = 4
num_substeps = 2
def set_step_substep(sequence_number):
set_num, step = divmod(sequence_number, num_steps * num_substeps)
step //= num_substeps
substep = sequence_number % 2
return set_num, step, substep
# set + substep -> rs1, rs2
rs1 = {
(0, 0): 0,
(0, 1): 1,
(1, 0): 2,
(1, 1): 3,
(2, 0): 4,
(2, 1): 5,
(3, 0): 6,
(3, 1): 7
}
rs2 = {
(0, 0): 8,
(0, 1): 9,
(1, 0): 10,
(1, 1): 11,
(2, 0): 12,
(2, 1): 13,
(3, 0): 14,
(3, 1): 15
}
# step + substep -> rs3, rd
rs3_rd = {
(0, 0): 16,
(0, 1): 17,
(1, 0): 18,
(1, 1): 19,
(2, 0): 20,
(2, 1): 21,
(3, 0): 22,
(3, 1): 23
}
with open('VX_tensor_ucode.vh', 'w') as f:
for sequence_number in range(num_sets * num_steps * num_substeps):
set_num, step, substep = set_step_substep(sequence_number)
next_sequence_num = (sequence_number + 1) % (num_sets * num_steps * num_substeps)
next_set_num, next_step, next_substep = set_step_substep(next_sequence_num)
finish = (next_sequence_num == 0)
name = "HMMA_SET{}_STEP{}_{}"
ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b{}, 32'b{}, `FREG({}), `FREG({}), `FREG({}), `FREG({})"
name = name.format(
set_num, step, substep,
)
pc_imm = 1 if finish else 0
ucode = ucode.format(
"FINISH" if finish else "NEXT",
next_set_num, next_step, next_substep,
step,
substep,
pc_imm,
pc_imm,
rs3_rd[(step, substep)],
rs1[(set_num, substep)],
rs2[(set_num, substep)],
rs3_rd[(step, substep)],
)
entry = f"{name}: begin \n"
entry += "\tuop = {"
entry += ucode
entry += "}; \n"
entry += "end \n"
f.write(entry)

View File

@@ -0,0 +1,44 @@
`include "VX_fpu_define.vh"
module VX_tensor_dpu #(
parameter ISW,
parameter OCTET
) (
input clk,
input reset,
input stall,
input valid_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile,
output valid_out,
output [3:0][3:0][31:0] D_tile
);
logic [3:0][3:0][31:0] result_hmma;
always @(*) begin
dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
end
always @(posedge clk) begin
if (~reset && valid_in) begin
dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
end
end
VX_shift_register #(
.DATAW (1 + $bits(D_tile)),
.DEPTH (`LATENCY_HMMA),
.RESETW (1)
) shift_reg (
.clk (clk),
.reset (reset),
.enable (~stall),
.data_in ({valid_in, result_hmma}),
.data_out ({valid_out, D_tile})
);
endmodule

View File

@@ -0,0 +1,30 @@
`include "VX_fpu_define.vh"
module VX_tensor_tb(
input clk,
input reset,
input valid_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile,
output valid_out,
output [3:0][3:0][31:0] D_tile
);
VX_tensor_dpu #() tensor_core (
.clk(clk),
.reset(reset),
.stall(1'b0),
.valid_in(valid_in),
.A_tile(A_tile),
.B_tile(B_tile),
.C_tile(C_tile),
.valid_out(valid_out),
.D_tile(D_tile)
);
endmodule

View File

@@ -0,0 +1,89 @@
DESTDIR ?= .
RTL_DIR = ../../rtl
DPI_DIR = $(abspath ../../dpi)
SIM_DIR = ../../../sim
THIRD_PARTY_DIR = $(abspath ../../../third_party)
CONFIGS +=
PARAMS +=
CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds
CXXFLAGS += -fPIC -Wno-maybe-uninitialized
CXXFLAGS += -fcoroutines
CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common
CXXFLAGS += -I/$(THIRD_PARTY_DIR)/softfloat/source/include
CXXFLAGS += -I/$(DPI_DIR)
CXXFLAGS += $(CONFIGS)
LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a
# control RTL debug tracing states
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_BANK
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_MSHR
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_TAG
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_DATA
DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS)
RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv
RTL_INCLUDE = -I$(RTL_DIR) -I$(DPI_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/fpu
# SRCS = cachesim.cpp testbench.cpp
SRCS += $(DPI_DIR)/util_dpi.cpp
SRCS += $(DPI_DIR)/float_dpi.cpp
SRCS += $(SIM_DIR)/common/rvfloats.cpp
SRCS += ./main.cpp
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_dpu.sv
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv
TOP = VX_tensor_tb
VL_FLAGS = --exe
VL_FLAGS += --language 1800-2009 # -Wall -Wpedantic # --assert
VL_FLAGS += -Wno-DECLFILENAME -Wno-REDEFMACRO
VL_FLAGS += --x-initial unique --x-assign unique
VL_FLAGS += -DSIMULATION -DSV_DPI
VL_FLAGS += $(CONFIGS)
VL_FLAGS += $(PARAMS)
VL_FLAGS += $(RTL_INCLUDE)
VL_FLAGS += $(RTL_PKGS)
VL_FLAGS += --cc $(TOP) --top-module $(TOP)
VL_FLAGS += --timing
# Enable Verilator multithreaded simulation
THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())')
VL_FLAGS += -j $(THREADS)
#VL_FLAGS += --threads $(THREADS)
# Debugigng
ifdef DEBUG
VL_FLAGS += --trace --trace-structs $(DBG_FLAGS)
CXXFLAGS += -g -O0 $(DBG_FLAGS)
else
VL_FLAGS += -DNDEBUG
CXXFLAGS += -O2 -DNDEBUG
endif
# Enable perf counters
ifdef PERF
VL_FLAGS += -DPERF_ENABLE
CXXFLAGS += -DPERF_ENABLE
endif
PROJECT = tensor
all: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS)
verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@
run: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT)
waves: trace.vcd
gtkwave -o trace.vcd
clean:
rm -rf obj_dir $(DESTDIR)/$(PROJECT)

197
hw/unittest/tensor/main.cpp Normal file
View File

@@ -0,0 +1,197 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "vl_simulator.h"
#include "VVX_tensor_tb.h"
#include <iostream>
#include <half.hpp>
#define MAX_TICKS 20
#ifndef TRACE_START_TIME
#define TRACE_START_TIME 0ull
#endif
#ifndef TRACE_STOP_TIME
#define TRACE_STOP_TIME -1ull
#endif
#define CHECK(x) \
do { \
if (x) \
break; \
std::cout << "FAILED: " << #x << std::endl; \
std::abort(); \
} while (false)
static uint64_t timestamp = 0;
static bool trace_enabled = false;
static uint64_t trace_start_time = TRACE_START_TIME;
static uint64_t trace_stop_time = TRACE_STOP_TIME;
double sc_time_stamp() {
return timestamp;
}
bool sim_trace_enabled() {
if (timestamp >= trace_start_time
&& timestamp < trace_stop_time)
return true;
return trace_enabled;
}
void sim_trace_enable(bool enable) {
trace_enabled = enable;
}
using Device = VVX_tensor_tb;
using half_float::half;
static_assert(sizeof(half) == 2);
uint32_t half2bits(half h) {
uint16_t half_bits;
memcpy(&half_bits, &h, sizeof(half));
return half_bits;
}
uint32_t float2bits(float f) {
uint32_t float_bits;
memcpy(&float_bits, &f, sizeof(f));
return float_bits;
}
float bits2float(uint32_t b) {
float f;
memcpy(&f, &b, sizeof(b));
return f;
}
// A is M * K, B is K * K * M, C is M * M, D is M * M
#define M 4
#define K 2
// row, column
float A_tile[M][K];
float B_tile[K][M];
float C_tile[M][M];
float D_tile[M][M];
void initialize_test_data() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
A_tile[i][j] = (float) (i * K + j);
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
B_tile[i][j] = (float) (j * K + i);
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
C_tile[i][j] = (float) (i * j);
}
}
}
void write_test_data(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
int index = (i * K + j);
uint32_t A_bits = float2bits(A_tile[i][j]);
sim->A_tile[index] = A_bits;
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t B_bits = float2bits(B_tile[i][j]);
sim->B_tile[index] = B_bits;
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t C_bits = float2bits(C_tile[i][j]);
sim->C_tile[index] = C_bits;
}
}
}
void read_result(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t D_bits = sim->D_tile[index];
float f = bits2float(D_bits);
D_tile[i][j] = f;
std::cout << f << " ";
}
std::cout << std::endl;
}
}
void expected() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
float accum = C_tile[i][j];
for (int k = 0; k < K; k += 1) {
accum += A_tile[i][k] * B_tile[k][j];
}
std::cout << accum << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char **argv) {
// Initialize Verilators variables
Verilated::commandArgs(argc, argv);
vl_simulator<Device> sim;
initialize_test_data();
// run test
timestamp = sim.reset(0);
// advance clock
timestamp = sim.step(timestamp, 10);
sim->valid_in = 1;
write_test_data(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
sim->valid_in = 0;
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 1);
read_result(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
expected();
std::cout << "PASSED!" << std::endl;
std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl;
return 0;
}

View File

@@ -48,6 +48,7 @@ void vx_wspawn_wait();
void vx_spawn_kernel(context_t * ctx, vx_spawn_kernel_cb callback, void * arg);
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback, void * arg);
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg);
void vx_serial(vx_serial_cb callback, void * arg);

View File

@@ -83,6 +83,38 @@ static void __attribute__ ((noinline)) spawn_tasks_rem_stub() {
(p_wspawn_args->callback)(task_id, p_wspawn_args->arg);
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_stub() {
int NT = vx_num_threads();
int NW = vx_num_warps();
int cid = vx_core_id();
int wid = vx_warp_id();
int tid = vx_thread_id();
wspawn_tasks_args_t* p_wspawn_args = (wspawn_tasks_args_t*)g_wspawn_args[cid];
int waves = p_wspawn_args->NWs + (wid < p_wspawn_args->RWs);
int offset = p_wspawn_args->offset + (NT * wid + tid);
vx_spawn_tasks_cb callback = p_wspawn_args->callback;
void* arg = p_wspawn_args->arg;
for (int wave_id = 0; wave_id < waves; ++wave_id) {
int task_id = offset + (wave_id * NT * NW);
callback(task_id, arg);
}
}
static void __attribute__ ((noinline)) spawn_tasks_contiguous_all_cb() {
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_contiguous_all_stub();
// disable warp
// deadlock here on warps 1, 2, 3
vx_tmc_zero();
}
static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
// activate all threads
vx_tmc(-1);
@@ -94,6 +126,79 @@ static void __attribute__ ((noinline)) spawn_tasks_all_cb() {
vx_tmc_zero();
}
void vx_spawn_tasks_contiguous(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();
int NW = vx_num_warps();
int NT = vx_num_threads();
// current core id
int core_id = vx_core_id();
if (core_id >= NUM_CORES_MAX)
return;
// calculate necessary active cores
int WT = NW * NT;
int nC = (num_tasks > WT) ? (num_tasks / WT) : 1;
int nc = MIN(nC, NC);
if (core_id >= nc)
return; // terminate extra cores
// number of tasks per core
int tasks_per_core = num_tasks / nc;
int tasks_per_core_n1 = tasks_per_core;
if (core_id == (nc-1)) {
int rem = num_tasks - (nc * tasks_per_core);
tasks_per_core_n1 += rem; // last core also executes remaining tasks
}
// number of tasks per warp
int TW = tasks_per_core_n1 / NT; // occupied warps
int rT = tasks_per_core_n1 - TW * NT; // remaining threads
int fW = 1, rW = 0;
if (TW >= NW) {
fW = TW / NW; // full warps iterations
rW = TW - fW * NW; // remaining warps
}
wspawn_tasks_args_t wspawn_args = { callback, arg, core_id * tasks_per_core, fW, rW };
g_wspawn_args[core_id] = &wspawn_args;
if (TW >= 1) {
// execute callback on other warps
int nw = MIN(TW, NW);
vx_wspawn(nw, spawn_tasks_contiguous_all_cb);
// activate all threads
vx_tmc(-1);
// call stub routine
spawn_tasks_contiguous_all_stub();
// back to single-threaded
vx_tmc_one();
// wait for spawn warps to terminate
// deadlock here on warp 0!
vx_wspawn_wait();
}
if (rT != 0) {
// adjust offset
wspawn_args.offset += (tasks_per_core_n1 - rT);
// activate remaining threads
int tmask = (1 << rT) - 1;
vx_tmc(tmask);
// call stub routine
spawn_tasks_rem_stub();
// back to single-threaded
vx_tmc_one();
}
}
void vx_spawn_tasks(int num_tasks, vx_spawn_tasks_cb callback , void * arg) {
// device specs
int NC = vx_num_cores();

View File

@@ -1,19 +1,23 @@
all:
$(MAKE) -C conform
$(MAKE) -C hello
$(MAKE) -C fibonacci
$(MAKE) -C fibonacci
$(MAKE) -C reductions
run-simx:
$(MAKE) -C conform run-simx
$(MAKE) -C hello run-simx
$(MAKE) -C fibonacci run-simx
$(MAKE) -C reductions run-simx
run-rtlsim:
$(MAKE) -C conform run-rtlsim
$(MAKE) -C hello run-rtlsim
$(MAKE) -C fibonacci run-rtlsim
$(MAKE) -C reductions run-rtlsim
clean:
$(MAKE) -C conform clean
$(MAKE) -C hello clean
$(MAKE) -C fibonacci clean
$(MAKE) -C reductions clean

View File

@@ -33,7 +33,7 @@ $(PROJECT).dump: $(PROJECT).elf
$(PROJECT).bin: $(PROJECT).elf
$(CP) -O binary $(PROJECT).elf $(PROJECT).bin
$(PROJECT).elf: $(SRCS)
$(PROJECT).elf: $(SRCS) $(DEPS)
$(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf
run-rtlsim: $(PROJECT).bin

View File

@@ -0,0 +1,5 @@
PROJECT = reductions
SRCS = main.cpp
include ../common.mk

View File

@@ -0,0 +1,216 @@
#define RISCV_CUSTOM2 0x5B
#define ADD_FUNC7 0b0000000
#define ADDU_FUNC7 0b1000000
#define MIN_FUNC7 0b0000001
#define MINU_FUNC7 0b1000001
#define MAX_FUNC7 0b0000010
#define MAXU_FUNC7 0b1000010
#define AND_FUNC7 0b0000011
#define OR_FUNC7 0b0000100
#define XOR_FUNC7 0b0000101
/*
6'h0: begin
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
end
6'h1: begin
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
end
6'h2: begin
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
end
6'h3: begin
op_type = `INST_RED_AND;
end
6'h4: begin
op_type = `INST_RED_OR;
end
6'h5: begin
op_type = `INST_RED_XOR;
end
*/
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
int x[4] = {3, 7, 2, 5};
int y = -1;
inline int vx_add_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(ADD_FUNC7));
return ret;
}
inline int vx_min_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MIN_FUNC7));
return ret;
}
inline unsigned vx_minu_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MINU_FUNC7));
return ret;
}
inline int vx_max_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAX_FUNC7));
return ret;
}
inline unsigned vx_maxu_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAXU_FUNC7));
return ret;
}
inline unsigned vx_and_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(AND_FUNC7));
return ret;
}
inline unsigned vx_or_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(OR_FUNC7));
return ret;
}
inline unsigned vx_xor_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(XOR_FUNC7));
return ret;
}
void test_add_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = x[tid];
int reduced = vx_add_reduce(v);
vx_tmc(1);
y = reduced;
}
unsigned unsigned_vector[4] = {(unsigned)-1, 0, (unsigned)-2, 5};
void test_min_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = unsigned_vector[tid];
int reduced = vx_min_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_max_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = unsigned_vector[tid];
int reduced = vx_max_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_minu_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = unsigned_vector[tid];
unsigned reduced = vx_minu_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_maxu_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = unsigned_vector[tid];
unsigned reduced = vx_maxu_reduce(v);
vx_tmc(1);
y = reduced;
}
unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111};
void test_and_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_and_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_or_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_or_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_xor_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_xor_reduce(v);
vx_tmc(1);
y = reduced;
}
int main()
{
int expected;
test_add_reduce();
vx_printf("add reduce result: %d\n", y);
vx_printf("expected: %d\n", x[0] + x[1] + x[2] + x[3]);
test_min_reduce();
vx_printf("min reduce result: %d\n", y);
expected = MIN((int)unsigned_vector[0], MIN((int)unsigned_vector[1], MIN((int)unsigned_vector[2], (int)unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_max_reduce();
vx_printf("max reduce result: %d\n", y);
expected = MAX((int)unsigned_vector[0], MAX((int)unsigned_vector[1], MAX((int)unsigned_vector[2], (int)unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_minu_reduce();
vx_printf("minu reduce result: %d\n", y);
expected = MIN(unsigned_vector[0], MIN(unsigned_vector[1], MIN(unsigned_vector[2], unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_maxu_reduce();
vx_printf("maxu reduce result: %d\n", y);
expected = MAX(unsigned_vector[0], MAX(unsigned_vector[1], MAX(unsigned_vector[2], unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_and_reduce();
vx_printf("and reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] & bit_vectors[1] & bit_vectors[2] & bit_vectors[3]);
test_or_reduce();
vx_printf("or reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] | bit_vectors[1] | bit_vectors[2] | bit_vectors[3]);
test_xor_reduce();
vx_printf("xor reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] ^ bit_vectors[1] ^ bit_vectors[2] ^ bit_vectors[3]);
return 0;
}

View File

@@ -0,0 +1,8 @@
PROJECT = tensor
SRCS = main.cpp
DEPS = a_matrix.h
DEPS += b_matrix.h
DEPS += c_matrix.h
include ../common.mk

View File

@@ -0,0 +1,95 @@
import numpy as np
import struct
A_array = np.zeros((16, 8))
B_array = np.zeros((8, 16))
C_array = np.zeros((16, 16))
file = input("simulator output filename: ")
def hex2float(float_hex_str):
# print(float_hex_str.strip())
return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0]
def C_index(threadgroup, thread, register):
"""
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
"""
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
if C_array[row, col] != 0:
print("bad")
return (row, col)
with open(file) as f:
for line in f.readlines():
line = line.strip()
if "warp" in line:
a, b, c = line.split(',')
_, a = a.split(' ')
_, b = b.strip().split(' ')
c, d = c.strip().split(':')
_, c = c.split(' ')
warp = int(a)
thread = int(b)
register = int(c)
value = d.strip()
if warp != 0:
continue
if not (32 <= register < 32+24):
continue
register = register - 32
# threadgroups 0, 4, 1, 5 have all elements of A
threadgroup = thread // 4
if threadgroup in [0, 4, 1, 5]:
row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4
if 0 <= register < 8:
A_array[row, register] = hex2float(value)
if threadgroup in [0, 4, 2, 6]:
col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4
if 8 <= register < 16:
B_array[register-8, col] = hex2float(value)
if 16 <= register < 24:
# print(value)
C_array[C_index(threadgroup, thread, register)] = hex2float(value)
expected = np.load("abc.npz")
expected_A = expected['A_array']
expected_B = expected['B_array']
expected_C = expected['C_array']
expected_C = expected_C + expected_A @ expected_B
print(expected_C[0:8, 0:8])
print(C_array[0:8, 0:8])
print((expected_C - C_array)[0:8, 0:8])
assert np.allclose(expected_A, A_array)
assert np.allclose(expected_B, B_array)
assert np.allclose(expected_C, C_array)

View File

@@ -0,0 +1,32 @@
import numpy as np
A_array = np.random.rand(16, 8)
B_array = np.random.rand(8, 16)
C_array = np.random.rand(16, 16)
# A_array = np.zeros((16, 8))
# B_array = np.zeros((8, 16))
# A_array[0,:] = 1.0
# B_array[:,4] = 1.0
# C_array = np.zeros((16, 16))
# for i in range(16):
# for j in range(16):
# C_array[i,j] = i * 16 + j
with open('a_matrix.h', 'w') as f:
for i in range(A_array.shape[0]):
for j in range(A_array.shape[1]):
f.write(f'{A_array[i,j]}f, ')
f.write('\n')
with open('b_matrix.h', 'w') as f:
for i in range(B_array.shape[0]):
for j in range(B_array.shape[1]):
f.write(f'{B_array[i,j]}f, ')
f.write('\n')
with open('c_matrix.h', 'w') as f:
for i in range(C_array.shape[0]):
for j in range(C_array.shape[1]):
f.write(f'{C_array[i,j]}f, ')
f.write('\n')
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)

View File

@@ -0,0 +1,96 @@
#define RISCV_CUSTOM3 0x7B
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
#include "test_data.h"
void vx_wmma_load() {
int tid = vx_thread_id();
int tg = tid / 4;
// load A
int row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f0, %0" :: "m"(A[row][0]));
asm volatile ("flw f1, %0" :: "m"(A[row][1]));
asm volatile ("flw f2, %0" :: "m"(A[row][2]));
asm volatile ("flw f3, %0" :: "m"(A[row][3]));
asm volatile ("flw f4, %0" :: "m"(A[row][4]));
asm volatile ("flw f5, %0" :: "m"(A[row][5]));
asm volatile ("flw f6, %0" :: "m"(A[row][6]));
asm volatile ("flw f7, %0" :: "m"(A[row][7]));
// load B
int col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
asm volatile ("flw f10, %0" :: "m"(B[2][col]));
asm volatile ("flw f11, %0" :: "m"(B[3][col]));
asm volatile ("flw f12, %0" :: "m"(B[4][col]));
asm volatile ("flw f13, %0" :: "m"(B[5][col]));
asm volatile ("flw f14, %0" :: "m"(B[6][col]));
asm volatile ("flw f15, %0" :: "m"(B[7][col]));
// load C
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
}
float results[32*8];
void store_wmma_result() {
int tid = vx_thread_id();
asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
}
void print_wmma_result() {
for (int tid = 0; tid < 32; tid += 1) {
for (int reg = 0; reg < 8; reg += 1) {
vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
}
}
}
int main()
{
vx_tmc(-1);
vx_wmma_load();
vx_wmma();
store_wmma_result();
vx_tmc(1);
// print_wmma_result();
return 0;
}

View File

@@ -0,0 +1,11 @@
float A[16][8] = {
#include "a_matrix.h"
};
float B[8][16] = {
#include "b_matrix.h"
};
float C[16][16] = {
#include "c_matrix.h"
};

View File

@@ -0,0 +1,9 @@
PROJECT = sgemm_tcore
SRCS = main.cpp common.h
VX_SRCS = kernel.cpp
OPTS ?= -n16
include ../common.mk

View File

@@ -0,0 +1,18 @@
#ifndef _COMMON_H_
#define _COMMON_H_
#include <cstdint>
#define KERNEL_ARG_DEV_MEM_ADDR 0x7fff0000
#define DEV_SMEM_START_ADDR 0xff000000
typedef struct {
uint32_t dim_m;
uint32_t dim_n;
uint32_t dim_k;
uint64_t addr_a;
uint64_t addr_b;
uint64_t addr_c;
} kernel_arg_t;
#endif

View File

@@ -0,0 +1,285 @@
#define RISCV_CUSTOM3 0x7B
#include <stdint.h>
#include <vx_intrinsics.h>
#include <vx_print.h>
#include <vx_spawn.h>
#include "common.h"
#define BM 16
#define BN 16
#define BK 8
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
void vx_wmma_load(volatile float *smem_A, volatile float *smem_B, int warp_x, int warp_y, int thread_in_warp) {
int tid = thread_in_warp;
int tg = tid / 4;
// load A
int row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
int smem_A_m = 32;
int smem_A_n = 8;
int smem_B_m = 8;
int smem_B_n = 32;
int A_offset = (row + BM * warp_y) * smem_A_n;
asm volatile ("flw f0, %0" :: "m"(smem_A[A_offset + 0]));
asm volatile ("flw f1, %0" :: "m"(smem_A[A_offset + 1]));
asm volatile ("flw f2, %0" :: "m"(smem_A[A_offset + 2]));
asm volatile ("flw f3, %0" :: "m"(smem_A[A_offset + 3]));
asm volatile ("flw f4, %0" :: "m"(smem_A[A_offset + 4]));
asm volatile ("flw f5, %0" :: "m"(smem_A[A_offset + 5]));
asm volatile ("flw f6, %0" :: "m"(smem_A[A_offset + 6]));
asm volatile ("flw f7, %0" :: "m"(smem_A[A_offset + 7]));
// load B
int col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
asm volatile ("flw f8 , %0" :: "m"(smem_B[(0 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f9 , %0" :: "m"(smem_B[(1 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f10, %0" :: "m"(smem_B[(2 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f11, %0" :: "m"(smem_B[(3 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f12, %0" :: "m"(smem_B[(4 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f13, %0" :: "m"(smem_B[(5 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f14, %0" :: "m"(smem_B[(6 * smem_B_n) + warp_x * BN + col]));
asm volatile ("flw f15, %0" :: "m"(smem_B[(7 * smem_B_n) + warp_x * BN + col]));
}
inline void initialize_C() {
// initialize C to zeros
asm volatile ("fmv.w.x f16, x0");
asm volatile ("fmv.w.x f17, x0");
asm volatile ("fmv.w.x f18, x0");
asm volatile ("fmv.w.x f19, x0");
asm volatile ("fmv.w.x f20, x0");
asm volatile ("fmv.w.x f21, x0");
asm volatile ("fmv.w.x f22, x0");
asm volatile ("fmv.w.x f23, x0");
}
inline void write_results(
volatile float *local_warp_results,
int thread_in_warp,
int warp_x,
int warp_y,
int dim_m,
int dim_n,
float *C,
int threadblock_id_x,
int threadblock_id_y
) {
int tid = thread_in_warp;
int tg = tid / 4;
asm volatile ("fsw f16, %0" :: "m"(local_warp_results[tid*8+0]));
asm volatile ("fsw f17, %0" :: "m"(local_warp_results[tid*8+1]));
asm volatile ("fsw f18, %0" :: "m"(local_warp_results[tid*8+2]));
asm volatile ("fsw f19, %0" :: "m"(local_warp_results[tid*8+3]));
asm volatile ("fsw f20, %0" :: "m"(local_warp_results[tid*8+4]));
asm volatile ("fsw f21, %0" :: "m"(local_warp_results[tid*8+5]));
asm volatile ("fsw f22, %0" :: "m"(local_warp_results[tid*8+6]));
asm volatile ("fsw f23, %0" :: "m"(local_warp_results[tid*8+7]));
/*
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
return (row, col)
*/
int local_col = ((tg % 4) / 2) * 8;
int local_row = (tg * 8) % 16;
local_row += (tg / 4) * 4;
// int row_offsets[8] = {0, 0, 2, 2, 0, 0, 2, 2};
// int col_offsets[8] = {0, 1, 0, 1, 4, 5, 4, 5};
// int thread_row_offsets[4] = {0, 1, 0, 1};
// int thread_col_offsets[4] = {0, 0, 2, 2};
int thread_row_offset = (tid % 4) % 2;
int thread_col_offset = ((tid % 4) / 2) * 2;
float *global_offset_C = C + (threadblock_id_y * BM * 2 + warp_y * BM) * dim_n + threadblock_id_x * BN * 2 + warp_x * BM;
for (int i = 0; i < 8; i += 1) {
int row_offset = ((i / 2) % 2) * 2;
int col_offset = (i / 4) * 4 + i % 2;
int adjusted_local_row = local_row + thread_row_offset + row_offset;
int adjusted_local_col = local_col + thread_col_offset + col_offset;
float v = local_warp_results[tid*8+i];
global_offset_C[adjusted_local_row * dim_n + adjusted_local_col] = v;
}
}
void threadblock_barrier(unsigned int tid_in_threadblock, unsigned int barrier_id, unsigned int count) {
vx_fence();
vx_barrier(barrier_id, count);
}
void thread_block_gemm(kernel_arg_t *__UNIFORM__ arg,
const uint32_t tid_in_threadblock,
const uint32_t threadblock_dim_x,
const uint32_t threadblock_dim_y,
const uint32_t threadblock_id_x,
const uint32_t threadblock_id_y,
const uint32_t threadblock_id,
float *sharedmem_per_threadblock) {
const float *A = (const float *)arg->addr_a;
const float *B = (const float *)arg->addr_b;
float *C = (float *)arg->addr_c;
const uint32_t dim_m = arg->dim_m;
const uint32_t dim_n = arg->dim_n;
const uint32_t dim_k = arg->dim_k;
// FIXME: Output block size is assumed to be square, i.e. BM == BN
// const uint32_t BM = threadblock_dim_y;
// const uint32_t BN = threadblock_dim_y;
// const uint32_t BK = threadblock_dim_x;
// constexpr uint32_t BM = 8;
// constexpr uint32_t BN = 8;
// constexpr uint32_t BK = 2;
const uint32_t warp_in_threadblock = tid_in_threadblock / 32;
const uint32_t tid_in_warp = tid_in_threadblock % 32;
const uint32_t warp_x = warp_in_threadblock % 2;
const uint32_t warp_y = warp_in_threadblock / 2;
const uint32_t global_a_row = threadblock_dim_y * threadblock_id_y;
// 32 * 8 block of A, being loaded by 4 warps
const uint32_t local_a_row = warp_y * BM + warp_x * (BM / 2) + (tid_in_warp / BK);
const uint32_t local_a_col = tid_in_warp % BK;
// 8 * 32 block of B, being loaded by 4 warps
// do a fat coalesced load
const uint32_t global_b_col = threadblock_dim_x * threadblock_id_x;
const uint32_t local_b_row = warp_in_threadblock;
const uint32_t local_b_col = tid_in_warp;
volatile float *local_a = sharedmem_per_threadblock;
const size_t local_a_elems = (threadblock_dim_y * BK);
volatile float *local_b = sharedmem_per_threadblock + local_a_elems;
const size_t local_b_elems = (threadblock_dim_x * BK);
volatile float *local_warp_results = local_b + local_b_elems + (warp_in_threadblock * BM * BN);
// clear out C
initialize_C();
for (uint32_t k = 0; k < dim_k; k += BK) {
// Data move from GMEM to SMEM
//
// Make sure global offset values for A and B are contiguous between
// neighboring threads to ensure GMEM coalescing. (not possible for A here, but for B it's doable)
// coalesced load from global memory -> unit-stride store into shared memory
uint32_t global_a_offset =
dim_k * (global_a_row + local_a_row) + (k + local_a_col);
uint32_t shared_a_offset =
BK * local_a_row + local_a_col;
local_a[shared_a_offset] = A[global_a_offset];
global_a_offset += dim_k * (BM / 4);
shared_a_offset += BK * (BM / 4);
local_a[shared_a_offset] = A[global_a_offset];
uint32_t global_b_offset =
dim_n * (k + local_b_row) + (global_b_col + local_b_col);
uint32_t shared_b_offset =
(BN * 2) * (local_b_row) + local_b_col;
local_b[shared_b_offset] = B[global_b_offset];
global_b_offset += dim_n * (BK / 2);
shared_b_offset += (BN * 2) * (BK / 2);
local_b[shared_b_offset] = B[global_b_offset];
// want all 4 warps to reach barrier before moving on (just use barrier 0 for now)
threadblock_barrier(tid_in_threadblock, 0, 4);
// perform wmma
vx_wmma_load(local_a, local_b, warp_x, warp_y, tid_in_warp);
vx_wmma();
// same as above
threadblock_barrier(tid_in_threadblock, 0, 4);
}
write_results(
local_warp_results,
tid_in_warp,
warp_x,
warp_y,
dim_m,
dim_n,
C,
threadblock_id_x,
threadblock_id_y
);
}
void kernel_body(int task_id, kernel_arg_t *__UNIFORM__ arg) {
// @perf: All threads are running these compute whose result is mostly same
// across the threadblock
const int NT = 32; // vx_num_threads();
const int NW = 4; // vx_num_warps();
const uint32_t threads_per_threadblock = NT * NW;
// matches 4 warp capacity
const uint32_t threadblock_dim_x = 2 * BN;
const uint32_t threadblock_dim_y = 2 * BM;
const int threadblock_id = task_id / threads_per_threadblock;
const int tid_in_threadblock = task_id % threads_per_threadblock;
const uint32_t dim_m = arg->dim_m;
const uint32_t dim_n = arg->dim_n;
const uint32_t dim_n_in_blocks = dim_n / threadblock_dim_x;
const int threadblock_id_x = threadblock_id % dim_n_in_blocks;
const int threadblock_id_y = threadblock_id / dim_n_in_blocks;
// "static" shared memory allocation. This would determine threadblock
// occupancy of a single cluster
// only 1 threadblock running at a time, so this is ok
float *sharedmem_per_threadblock =
(float *)DEV_SMEM_START_ADDR; // + (2 * BM * BK) + (2 * BN * BK) * threadblock_id;
thread_block_gemm(arg, tid_in_threadblock, threadblock_dim_x,
threadblock_dim_y, threadblock_id_x, threadblock_id_y,
threadblock_id, sharedmem_per_threadblock);
}
int main() {
kernel_arg_t *arg = (kernel_arg_t *)KERNEL_ARG_DEV_MEM_ADDR;
int NT = vx_num_threads();
// TODO: add support for edge-case (m, n not divisible by 16)
const uint32_t grid_size = arg->dim_m * arg->dim_n * NT / (BM * BN);
// for now, simplifying assumption of just 1 core
// vx_spawn_tasks_contiguous first runs warps 1 through NW, then NW+1 through 2*NW, etc.
// we can thus treat 1 through NW as a single threadblock for the purposes of the optimization.
vx_spawn_tasks_contiguous(grid_size, (vx_spawn_tasks_cb)kernel_body, arg);
return 0;
}

View File

@@ -0,0 +1,270 @@
#include <iostream>
#include <fstream>
#include <unistd.h>
#include <string.h>
#include <vortex.h>
#include <vector>
#include "common.h"
#define RT_CHECK(_expr) \
do { \
int _ret = _expr; \
if (0 == _ret) \
break; \
printf("Error: '%s' returned %d!\n", #_expr, (int)_ret); \
cleanup(); \
exit(-1); \
} while (false)
///////////////////////////////////////////////////////////////////////////////
const char* kernel_file = "kernel.bin";
uint32_t count = 0;
std::vector<float> src_a_data;
std::vector<float> src_b_data;
std::vector<float> ref_data;
vx_device_h device = nullptr;
std::vector<uint8_t> staging_buf;
kernel_arg_t kernel_arg = {};
static void show_usage() {
std::cout << "Vortex Test." << std::endl;
std::cout << "Usage: [-k: kernel] [-n words] [-h: help]" << std::endl;
}
static void parse_args(int argc, char **argv) {
int c;
while ((c = getopt(argc, argv, "n:k:h?")) != -1) {
switch (c) {
case 'n':
count = atoi(optarg);
break;
case 'k':
kernel_file = optarg;
break;
case 'h':
case '?': {
show_usage();
exit(0);
} break;
default:
show_usage();
exit(-1);
}
}
}
void cleanup() {
if (device) {
vx_mem_free(device, kernel_arg.addr_a);
vx_mem_free(device, kernel_arg.addr_b);
vx_mem_free(device, kernel_arg.addr_c);
vx_dev_close(device);
}
}
void generate_source_matrix(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
src_a_data.resize(dim_m * dim_k);
src_b_data.resize(dim_k * dim_n);
for (uint32_t i = 0; i < src_a_data.size(); ++i) {
src_a_data[i] = static_cast<float>(i);
std::cout << "A: " << i << ": value=" << src_a_data[i] << std::endl;
}
for (uint32_t i = 0; i < src_b_data.size(); ++i) {
src_b_data[i] = static_cast<float>(i);
std::cout << "B: " << i << ": value=" << src_b_data[i] << std::endl;
}
}
void generate_reference_matmul(uint32_t dim_m, uint32_t dim_n, uint32_t dim_k) {
ref_data.resize(dim_m * dim_n);
for (uint32_t i = 0; i < dim_m; ++i) {
for (uint32_t j = 0; j < dim_n; ++j) {
float ref = 0.0f;
for (uint32_t k = 0; k < dim_k; ++k) {
ref += src_a_data[dim_k * i + k] * src_b_data[dim_n * k + j];
}
ref_data.at(dim_n * i + j) = ref;
}
}
}
int run_test(const kernel_arg_t& kernel_arg,
uint32_t buf_size,
uint32_t dim_m, uint32_t dim_n) {
// start device
std::cout << "start device" << std::endl;
RT_CHECK(vx_start(device));
// wait for completion
std::cout << "wait for completion" << std::endl;
RT_CHECK(vx_ready_wait(device, VX_MAX_TIMEOUT));
// download destination buffer
std::cout << "download destination buffer" << std::endl;
RT_CHECK(vx_copy_from_dev(device, staging_buf.data(), kernel_arg.addr_c, buf_size));
// verify result
std::cout << "verify result" << std::endl;
{
int errors = 0;
auto buf_ptr = (float*)staging_buf.data();
for (uint32_t i = 0; i < dim_m * dim_n; ++i) {
float ref = ref_data.at(i);
float cur = buf_ptr[i];
if (std::abs((cur - ref) / ref) > 1e-6) {
std::cout << "error at result #" << std::dec << i
<< std::hex << ": actual=" << cur << ", expected=" << ref << std::endl;
++errors;
}
}
if (errors != 0) {
std::cout << "Found " << std::dec << errors << " errors!" << std::endl;
std::cout << "FAILED!" << std::endl;
return 1;
}
}
return 0;
}
int main(int argc, char *argv[]) {
// parse command arguments
parse_args(argc, argv);
if (count == 0) {
count = 1;
}
std::srand(50);
// open device connection
std::cout << "open device connection" << std::endl;
RT_CHECK(vx_dev_open(&device));
// FIXME: hardcoded
uint32_t dim_m = 64;
uint32_t dim_n = 64;
uint32_t dim_k = 64;
generate_source_matrix(dim_m, dim_n, dim_k);
generate_reference_matmul(dim_m, dim_n, dim_k);
uint32_t src_a_buf_size = src_a_data.size() * sizeof(src_a_data[0]);
uint32_t src_b_buf_size = src_b_data.size() * sizeof(src_b_data[0]);
uint32_t dst_buf_size = ref_data.size() * sizeof(src_a_data[0]);
std::cout << "buffer size: " << dst_buf_size << " bytes" << std::endl;
// upload program
std::cout << "upload program" << std::endl;
RT_CHECK(vx_upload_kernel_file(device, kernel_file));
// allocate device memory
std::cout << "allocate device memory" << std::endl;
RT_CHECK(vx_mem_alloc(device, src_a_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_a));
RT_CHECK(vx_mem_alloc(device, src_b_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_b));
RT_CHECK(vx_mem_alloc(device, dst_buf_size, VX_MEM_TYPE_GLOBAL, &kernel_arg.addr_c));
kernel_arg.dim_m = dim_m;
kernel_arg.dim_n = dim_n;
kernel_arg.dim_k = dim_k;
std::cout << "dev_addr_a=0x" << std::hex << kernel_arg.addr_a << std::endl;
std::cout << "dev_addr_b=0x" << std::hex << kernel_arg.addr_b << std::endl;
std::cout << "dev_addr_c=0x" << std::hex << kernel_arg.addr_c << std::endl;
// allocate staging buffer
{
std::cout << "allocate staging buffer" << std::endl;
uint32_t staging_buf_size = std::max<uint32_t>(
src_a_buf_size,
std::max<uint32_t>(
src_b_buf_size,
std::max<uint32_t>(dst_buf_size, sizeof(kernel_arg_t))));
staging_buf.resize(staging_buf_size);
}
// upload kernel argument
{
std::cout << "upload kernel argument" << std::endl;
auto buf_ptr = staging_buf.data();
memcpy(buf_ptr, &kernel_arg, sizeof(kernel_arg_t));
RT_CHECK(vx_copy_to_dev(device, KERNEL_ARG_DEV_MEM_ADDR, staging_buf.data(), sizeof(kernel_arg_t)));
std::cout << "uploading argument buffer to device, device mem address="
<< std::hex << KERNEL_ARG_DEV_MEM_ADDR << ", size=" << std::dec
<< sizeof(kernel_arg_t) << " bytes\n";
std::ofstream file("args.bin", std::ios::binary | std::ios::out);
if (!file) {
std::cerr << "error: failed to open args.bin for writing\n";
exit(EXIT_FAILURE);
}
file.write(reinterpret_cast<char *>(staging_buf.data()),
sizeof(kernel_arg_t));
file.close();
}
// upload source buffer
{
{
auto buf_ptr = staging_buf.data();
memcpy(buf_ptr, src_a_data.data(), src_a_data.size() * sizeof(float));
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_a, staging_buf.data(),
src_a_buf_size));
std::cout << "uploading source A matrix to device, device mem address="
<< std::hex << kernel_arg.addr_a << ", size=" << std::dec
<< src_a_buf_size << " bytes\n";
std::ofstream file("input.a.bin", std::ios::binary | std::ios::out);
if (!file) {
std::cerr << "error: failed to open args.bin for writing\n";
exit(EXIT_FAILURE);
}
file.write(reinterpret_cast<char *>(buf_ptr), src_a_buf_size);
file.close();
}
{
auto buf_ptr = staging_buf.data();
memcpy(buf_ptr, src_b_data.data(), src_b_data.size() * sizeof(float));
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_b, staging_buf.data(),
src_b_buf_size));
std::cout << "uploading source B matrix to device, device mem address="
<< std::hex << kernel_arg.addr_b << ", size=" << std::dec
<< src_b_buf_size << " bytes\n";
std::ofstream file("input.b.bin", std::ios::binary | std::ios::out);
if (!file) {
std::cerr << "error: failed to open args.bin for writing\n";
exit(EXIT_FAILURE);
}
file.write(reinterpret_cast<char *>(buf_ptr), src_b_buf_size);
file.close();
}
}
// clear destination buffer
{
std::cout << "clear destination buffer" << std::endl;
auto buf_ptr = (int32_t*)staging_buf.data();
for (uint32_t i = 0; i < ref_data.size(); ++i) {
buf_ptr[i] = 0xdeadbeef;
}
RT_CHECK(vx_copy_to_dev(device, kernel_arg.addr_c, staging_buf.data(), dst_buf_size));
}
// run tests
std::cout << "run tests" << std::endl;
RT_CHECK(run_test(kernel_arg, dst_buf_size, kernel_arg.dim_m, kernel_arg.dim_n));
std::cout << "PASSED!" << std::endl;
// cleanup
std::cout << "cleanup" << std::endl;
cleanup();
return 0;
}