Merge branch 'tensor_core' into rtl

This commit is contained in:
Hansung Kim
2024-05-01 16:18:14 -07:00
32 changed files with 6097 additions and 20 deletions

View File

@@ -23,6 +23,9 @@
// #include "verilated_vpi.h"
#include "VX_config.h"
#include <bit>
#include "half.hpp"
extern "C" {
void dpi_fadd(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
void dpi_fsub(bool enable, int dst_fmt, int64_t a, int64_t b, const svBitVecVal* frm, int64_t* result, svBitVecVal* fflags);
@@ -51,6 +54,8 @@ extern "C" {
void dpi_feq(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmin(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, svBitVecVal* fflags);
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile);
}
inline uint64_t nan_box(uint32_t value) {
@@ -338,3 +343,74 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
*result = nan_box(rv_fmax_s(check_boxing(a), check_boxing(b), fflags));
}
}
// A is M * K, B is K * K * M, C is M * M, D is M * M
#define M 4
#define K 2
// all row major
float c_A_tile[M][K];
float c_B_tile[K][M];
float c_C_tile[M][M];
float c_D_tile[M][M];
// code assumes that svBitVecVal is basically a uint32_t
static_assert(sizeof(svBitVecVal) == 4);
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
for (int j = 0; j < cols; j += 1) {
int index = i * cols + j;
svBitVecVal sv_val = sv_tile[index];
uint32_t c_val = sv_val;
float c_float;
memcpy(&c_float, &c_val, sizeof(c_float));
c_tile[index] = c_float;
// std::cout << c_float << " ";
}
// std::cout << std::endl;
}
}
void write_float_array(svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
for (int j = 0; j < cols; j += 1) {
int index = i * cols + j;
svBitVecVal* sv_val = &sv_tile[index];
float c_float = c_tile[index];
memcpy(sv_val, &c_float, sizeof(c_float));
// std::cout << c_float << " ";
}
// std::cout << std::endl;
}
}
void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, const svBitVecVal* C_tile, svBitVecVal* D_tile) {
if (!enable) {
return;
}
// std::cout << "A: " << std::endl;
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
// std::cout << "B: " << std::endl;
fill_float_array(B_tile, &c_B_tile[0][0], K, M);
// std::cout << "C: " << std::endl;
fill_float_array(C_tile, &c_C_tile[0][0], M, M);
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
float accum = c_C_tile[i][j];
for (int k = 0; k < K; k += 1) {
accum += c_A_tile[i][k] * c_B_tile[k][j];
}
c_D_tile[i][j] = accum;
}
}
write_float_array(D_tile, &c_D_tile[0][0], M, M);
}

View File

@@ -44,4 +44,6 @@ import "DPI-C" function void dpi_feq(input logic enable, input int dst_fmt, inpu
import "DPI-C" function void dpi_fmin(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
import "DPI-C" function void dpi_fmax(input logic enable, input int dst_fmt, input longint a, input longint b, output longint result, output bit[4:0] fflags);
import "DPI-C" function void dpi_hmma(input logic enable, input bit[3:0][1:0][31:0] A_tile, input bit[1:0][3:0][31:0] B_tile, input bit[3:0][3:0][31:0] C_tile, output bit[3:0][3:0][31:0] D_tile);
`endif

4018
hw/dpi/half.hpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -40,6 +40,10 @@
`define EXT_F_ENABLE
`endif
`ifndef EXT_T_DISABLE
`define EXT_T_ENABLE
`endif
`ifndef XLEN_32
`ifndef XLEN_64
`define XLEN_32
@@ -385,6 +389,11 @@
`define LATENCY_FCVT 5
`endif
// Tensor Core Latency
`ifndef LATENCY_HMMA
`define LATENCY_HMMA 8
`endif
// Icache Configurable Knobs //////////////////////////////////////////////////
// Cache Enable
@@ -613,6 +622,12 @@
`define EXT_F_ENABLED 0
`endif
`ifdef EXT_T_ENABLE
`define EXT_T_ENABLED 1
`else
`define EXT_T_ENABLED 0
`endif
`ifdef EXT_M_ENABLE
`define EXT_M_ENABLED 1
`else

View File

@@ -58,8 +58,9 @@
`define EX_LSU 1
`define EX_SFU 2
`define EX_FPU (`EX_SFU + `EXT_F_ENABLED)
`define EX_TENSOR (`EX_FPU + `EXT_T_ENABLED)
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED)
`define NUM_EX_UNITS (3 + `EXT_F_ENABLED + `EXT_T_ENABLED)
`define EX_BITS `CLOG2(`NUM_EX_UNITS)
`define EX_WIDTH `UP(`EX_BITS)
@@ -115,7 +116,7 @@
///////////////////////////////////////////////////////////////////////////////
`define INST_OP_BITS 4
`define INST_MOD_BITS 3
`define INST_MOD_BITS 4
`define INST_FMT_BITS 2
///////////////////////////////////////////////////////////////////////////////
@@ -140,6 +141,7 @@
`define INST_ALU_IS_BR(mod) mod[0]
`define INST_ALU_IS_M(mod) mod[1]
`define INST_ALU_IS_W(mod) mod[2]
`define INST_ALU_IS_RED(mod) mod[3]
`define INST_BR_EQ 4'b0000
`define INST_BR_NE 4'b0010
@@ -176,6 +178,17 @@
`define INST_M_SIGNED_A(op) (op[1:0] != 1)
`define INST_M_IS_REM(op) op[1]
`define INST_RED_ADD 4'b0000
`define INST_RED_ADDU 4'b1000
`define INST_RED_MIN 4'b0001
`define INST_RED_MINU 4'b1001
`define INST_RED_MAX 4'b0010
`define INST_RED_MAXU 4'b1010
`define INST_RED_AND 4'b0011
`define INST_RED_OR 4'b0100
`define INST_RED_XOR 4'b0101
`define INST_RED_BITS 4
`define INST_FMT_B 3'b000
`define INST_FMT_H 3'b001
`define INST_FMT_W 3'b010
@@ -241,6 +254,8 @@
`define INST_SFU_IS_WCTL(op) (op <= 5)
`define INST_SFU_IS_CSR(op) (op >= 6 && op <= 8)
`define INST_TENSOR_HMMA 4'b0000
///////////////////////////////////////////////////////////////////////////////
// non-cacheable tag bits

View File

@@ -33,7 +33,7 @@ module VX_alu_unit #(
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
localparam PID_WIDTH = `UP(PID_BITS);
localparam RSP_ARB_DATAW= `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + `NR_BITS + 1 + NUM_LANES * `XLEN + PID_WIDTH + 1 + 1;
localparam RSP_ARB_SIZE = 1 + `EXT_M_ENABLED;
localparam RSP_ARB_SIZE = 2 + `EXT_M_ENABLED;
localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
VX_execute_if #(
@@ -60,12 +60,13 @@ module VX_alu_unit #(
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
wire is_muldiv_op;
wire is_reduce_op;
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) int_execute_if();
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op;
assign int_execute_if.valid = execute_if[block_idx].valid && ~is_muldiv_op && ~is_reduce_op;
assign int_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
@@ -86,6 +87,31 @@ module VX_alu_unit #(
.commit_if (int_commit_if)
);
assign is_reduce_op = `INST_ALU_IS_RED(execute_if[block_idx].data.op_mod);
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) red_execute_if();
assign red_execute_if.valid = execute_if[block_idx].valid && is_reduce_op;
assign red_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
.NUM_LANES (NUM_LANES)
) red_commit_if();
`RESET_RELAY(red_reset, reset);
VX_reduce_unit #(
.CORE_ID(CORE_ID),
.NUM_LANES(NUM_LANES)
) reduce_unit (
.clk(clk),
.reset(red_reset),
.execute_if(red_execute_if),
.commit_if(red_commit_if)
);
`ifdef EXT_M_ENABLE
assign is_muldiv_op = `INST_ALU_IS_M(execute_if[block_idx].data.op_mod);
@@ -96,7 +122,7 @@ module VX_alu_unit #(
.NUM_LANES (NUM_LANES)
) mdv_execute_if();
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op;
assign mdv_execute_if.valid = execute_if[block_idx].valid && is_muldiv_op && ~is_reduce_op;
assign mdv_execute_if.data = execute_if[block_idx].data;
VX_commit_if #(
@@ -113,12 +139,12 @@ module VX_alu_unit #(
.commit_if (mdv_commit_if)
);
assign execute_if[block_idx].ready = is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready;
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : (is_muldiv_op ? mdv_execute_if.ready : int_execute_if.ready);
`else
assign is_muldiv_op = 0;
assign execute_if[block_idx].ready = int_execute_if.ready;
assign execute_if[block_idx].ready = is_reduce_op ? red_execute_if.ready : int_execute_if.ready;
`endif
@@ -135,19 +161,22 @@ module VX_alu_unit #(
`ifdef EXT_M_ENABLE
mdv_commit_if.valid,
`endif
int_commit_if.valid
int_commit_if.valid,
red_commit_if.valid
}),
.ready_in ({
`ifdef EXT_M_ENABLE
mdv_commit_if.ready,
`endif
int_commit_if.ready
int_commit_if.ready,
red_commit_if.ready
}),
.data_in ({
`ifdef EXT_M_ENABLE
mdv_commit_if.data,
`endif
int_commit_if.data
int_commit_if.data,
red_commit_if.data
}),
.data_out (commit_block_if[block_idx].data),
.valid_out (commit_block_if[block_idx].valid),

View File

@@ -28,6 +28,10 @@ module VX_commit import VX_gpu_pkg::*; #(
`endif
VX_commit_if.slave sfu_commit_if [`ISSUE_WIDTH],
`ifdef EXT_T_ENABLE
VX_commit_if.slave tensor_commit_if [`ISSUE_WIDTH],
`endif
// outputs
VX_writeback_if.master writeback_if [`ISSUE_WIDTH],
VX_commit_csr_if.master commit_csr_if,
@@ -49,6 +53,7 @@ module VX_commit import VX_gpu_pkg::*; #(
wire [`ISSUE_WIDTH-1:0][`NW_WIDTH-1:0] commit_wid;
wire [`ISSUE_WIDTH-1:0][`NUM_THREADS-1:0] commit_tmask;
wire [`ISSUE_WIDTH-1:0] commit_eop;
wire [`ISSUE_WIDTH-1:0][`EX_BITS-1:0] commit_sel;
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
@@ -66,6 +71,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].valid,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].valid,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].valid,
`endif
alu_commit_if[i].valid,
lsu_commit_if[i].valid
@@ -74,6 +82,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].ready,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].ready,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].ready,
`endif
alu_commit_if[i].ready,
lsu_commit_if[i].ready
@@ -82,6 +93,9 @@ module VX_commit import VX_gpu_pkg::*; #(
sfu_commit_if[i].data,
`ifdef EXT_F_ENABLE
fpu_commit_if[i].data,
`endif
`ifdef EXT_T_ENABLE
tensor_commit_if[i].data,
`endif
alu_commit_if[i].data,
lsu_commit_if[i].data
@@ -89,7 +103,7 @@ module VX_commit import VX_gpu_pkg::*; #(
.data_out (commit_if[i].data),
.valid_out (commit_if[i].valid),
.ready_out (commit_if[i].ready),
`UNUSED_PIN (sel_out)
.sel_out (commit_sel[i])
);
assign commit_fire[i] = commit_if[i].valid && commit_if[i].ready;
@@ -158,7 +172,32 @@ module VX_commit import VX_gpu_pkg::*; #(
// Committed instructions
wire [`ISSUE_WIDTH-1:0] committed = commit_fire & commit_eop;
// temporary hack to not underflow the pending instructions buffer
// relies on 1 cycle delay of arbiter and continuous issuing of tensor instructions,
// so probably want to change this at some point
// (i.e. pass a "don't count this towards pending instructions" signal down the pipeline)
logic [`ISSUE_WIDTH-1:0][4:0] hmma_ctr, hmma_ctr_n;
wire [`ISSUE_WIDTH-1:0] final_hmma;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
assign hmma_ctr_n[i] = (tensor_commit_if[i].valid && tensor_commit_if[i].ready) ? hmma_ctr[i] + 5'b1 : hmma_ctr[i];
assign final_hmma[i] = (commit_sel[i] != `EX_BITS'(2) || hmma_ctr == '0);
end
always @(posedge clk) begin
if (reset) begin
hmma_ctr <= '0;
end
else begin
hmma_ctr <= hmma_ctr_n;
end
end
`else
assign final_hmma = '1;
`endif
wire [`ISSUE_WIDTH-1:0] committed = (commit_fire & commit_eop) & final_hmma;
VX_pipe_register #(
.DATAW (`ISSUE_WIDTH * (1 + `NW_WIDTH)),

View File

@@ -67,6 +67,10 @@ module VX_core import VX_gpu_pkg::*; #(
`ifdef EXT_F_ENABLE
VX_dispatch_if fpu_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if fpu_commit_if[`ISSUE_WIDTH]();
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if tensor_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if tensor_commit_if[`ISSUE_WIDTH]();
`endif
VX_dispatch_if sfu_dispatch_if[`ISSUE_WIDTH]();
VX_commit_if sfu_commit_if[`ISSUE_WIDTH]();
@@ -174,6 +178,9 @@ module VX_core import VX_gpu_pkg::*; #(
.lsu_dispatch_if(lsu_dispatch_if),
`ifdef EXT_F_ENABLE
.fpu_dispatch_if(fpu_dispatch_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if(tensor_dispatch_if),
`endif
.sfu_dispatch_if(sfu_dispatch_if)
);
@@ -199,6 +206,10 @@ module VX_core import VX_gpu_pkg::*; #(
.fpu_dispatch_if(fpu_dispatch_if),
.fpu_commit_if (fpu_commit_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if (tensor_dispatch_if),
.tensor_commit_if (tensor_commit_if),
`endif
.commit_csr_if (commit_csr_if),
.sched_csr_if (sched_csr_if),
@@ -229,6 +240,9 @@ module VX_core import VX_gpu_pkg::*; #(
.fpu_commit_if (fpu_commit_if),
`endif
.sfu_commit_if (sfu_commit_if),
`ifdef EXT_T_ENABLE
.tensor_commit_if (tensor_commit_if),
`endif
.writeback_if (writeback_if),

View File

@@ -513,6 +513,40 @@ module VX_decode #(
default:;
endcase
end
`INST_EXT3: begin
ex_type = `EX_ALU;
op_mod[3] = 1;
`USED_IREG(rs1);
`USED_IREG(rd);
case (func7[5:0])
6'h0: begin
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
end
6'h1: begin
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
end
6'h2: begin
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
end
6'h3: begin
op_type = `INST_RED_AND;
end
6'h4: begin
op_type = `INST_RED_OR;
end
6'h5: begin
op_type = `INST_RED_XOR;
end
default:;
endcase
end
`ifdef EXT_T_ENABLE
`INST_EXT4: begin
ex_type = `EX_TENSOR;
op_type = `INST_TENSOR_HMMA;
end
`endif
default:;
endcase
end

View File

@@ -34,6 +34,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
`ifdef EXT_F_ENABLE
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
);
@@ -142,6 +145,35 @@ module VX_dispatch import VX_gpu_pkg::*; #(
end
`endif
// Tensor Core dispatch
`ifdef EXT_T_ENABLE
VX_operands_if tensor_operands_if[`ISSUE_WIDTH]();
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
assign tensor_operands_if[i].valid = operands_if[i].valid && (operands_if[i].data.ex_type == `EX_TENSOR);
assign tensor_operands_if[i].data = operands_if[i].data;
`RESET_RELAY (tensor_reset, reset);
VX_elastic_buffer #(
.DATAW (DATAW),
.SIZE (2),
.OUT_REG (2)
) tensor_buffer (
.clk (clk),
.reset (tensor_reset),
.valid_in (tensor_operands_if[i].valid),
.ready_in (tensor_operands_if[i].ready),
.data_in (`TO_DISPATCH_DATA(tensor_operands_if[i].data, last_active_tid[i])),
.data_out (tensor_dispatch_if[i].data),
.valid_out (tensor_dispatch_if[i].valid),
.ready_out (tensor_dispatch_if[i].ready)
);
end
`endif
// SFU dispatch
VX_operands_if sfu_operands_if[`ISSUE_WIDTH]();
@@ -174,6 +206,9 @@ module VX_dispatch import VX_gpu_pkg::*; #(
|| (lsu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_LSU))
`ifdef EXT_F_ENABLE
|| (fpu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_FPU))
`endif
`ifdef EXT_T_ENABLE
|| (tensor_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_TENSOR))
`endif
|| (sfu_operands_if[i].ready && (operands_if[i].data.ex_type == `EX_SFU));
end

View File

@@ -41,7 +41,7 @@ module VX_execute import VX_gpu_pkg::*; #(
VX_dispatch_if.slave fpu_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master fpu_commit_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.slave alu_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master alu_commit_if [`ISSUE_WIDTH],
VX_branch_ctl_if.master branch_ctl_if [`NUM_ALU_BLOCKS],
@@ -53,6 +53,11 @@ module VX_execute import VX_gpu_pkg::*; #(
VX_commit_if.master sfu_commit_if [`ISSUE_WIDTH],
VX_warp_ctl_if.master warp_ctl_if,
`ifdef EXT_T_ENABLE
VX_dispatch_if.slave tensor_dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master tensor_commit_if [`ISSUE_WIDTH],
`endif
// simulation helper signals
output wire sim_ebreak
);
@@ -127,6 +132,18 @@ module VX_execute import VX_gpu_pkg::*; #(
.commit_if (sfu_commit_if)
);
`ifdef EXT_T_ENABLE
VX_tensor_core #(
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(tensor_dispatch_if),
.commit_if(tensor_commit_if)
);
`endif
// simulation helper signal to get RISC-V tests Pass/Fail status
assign sim_ebreak = alu_dispatch_if[0].valid && alu_dispatch_if[0].ready
&& alu_dispatch_if[0].data.wis == 0

View File

@@ -36,6 +36,8 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
assign decode_if.ready = ibuf_ready_in[decode_isw];
VX_ibuffer_if uop_sequencer_if [`ISSUE_WIDTH]();
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_elastic_buffer #(
.DATAW (DATAW),
@@ -62,13 +64,24 @@ module VX_ibuffer import VX_gpu_pkg::*; #(
decode_if.data.rs1,
decode_if.data.rs2,
decode_if.data.rs3}),
.data_out(ibuffer_if[i].data),
.valid_out (ibuffer_if[i].valid),
.ready_out(ibuffer_if[i].ready)
);
.data_out (uop_sequencer_if[i].data),
.valid_out (uop_sequencer_if[i].valid),
.ready_out (uop_sequencer_if[i].ready)
);
`ifndef L1_ENABLE
assign decode_if.ibuf_pop[i] = ibuffer_if[i].valid && ibuffer_if[i].ready;
assign decode_if.ibuf_pop[i] = uop_sequencer_if[i].valid && uop_sequencer_if[i].ready;
`endif
VX_uop_sequencer uop_sequencer (
.clk(clk),
.reset(reset),
.uop_sequencer_if(uop_sequencer_if[i]),
.ibuffer_if(ibuffer_if[i])
);
end
endmodule

View File

@@ -33,6 +33,9 @@ module VX_issue #(
VX_dispatch_if.master lsu_dispatch_if [`ISSUE_WIDTH],
`ifdef EXT_F_ENABLE
VX_dispatch_if.master fpu_dispatch_if [`ISSUE_WIDTH],
`endif
`ifdef EXT_T_ENABLE
VX_dispatch_if.master tensor_dispatch_if [`ISSUE_WIDTH],
`endif
VX_dispatch_if.master sfu_dispatch_if [`ISSUE_WIDTH]
);
@@ -104,6 +107,9 @@ module VX_issue #(
.lsu_dispatch_if(lsu_dispatch_if),
`ifdef EXT_F_ENABLE
.fpu_dispatch_if(fpu_dispatch_if),
`endif
`ifdef EXT_T_ENABLE
.tensor_dispatch_if(tensor_dispatch_if),
`endif
.sfu_dispatch_if(sfu_dispatch_if)
);

View File

@@ -294,6 +294,27 @@ module VX_operands import VX_gpu_pkg::*; #(
.raddr (gpr_rd_addr),
.rdata (gpr_rd_data[j])
);
// blast read register file because printf is slowge
logic [31:0] cycle, cycle_n;
assign cycle_n = cycle + 32'd1;
always @(posedge clk) begin
if (reset) begin
cycle <= '0;
end
else begin
cycle <= cycle_n;
end
if (cycle == 32'd25000) begin
for (integer k = 0; k < `NUM_REGS * ISSUE_RATIO; ++k) begin
integer warp = i * ISSUE_RATIO + (k / `NUM_REGS);
integer thread = j;
integer register = k % `NUM_REGS;
$display("warp %0d, thread %0d, register %0d: %0x", warp, thread, register, gpr_ram.ram[k]);
end
end
end
end
end

View File

@@ -0,0 +1,283 @@
`include "VX_define.vh"
`include "VX_platform.vh"
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
`include "VX_platform.vh"
module VX_reduce_ext #(
parameter DATAW_IN = 1,
parameter DATAW_OUT = DATAW_IN,
parameter N = 1
) (
input wire [N-1:0][DATAW_IN-1:0] data_in,
input wire [N-1:0] mask,
input wire [`INST_RED_BITS-1:0] op_type,
output wire [DATAW_OUT-1:0] data_out
);
if (N == 1) begin
`UNUSED_VAR(op_type)
`UNUSED_VAR(mask)
assign data_out = DATAW_OUT'(data_in[0]);
end else begin
localparam int N_A = N / 2;
localparam int N_B = N - N_A;
wire [N_A-1:0][DATAW_IN-1:0] in_A;
wire [N_B-1:0][DATAW_IN-1:0] in_B;
wire [DATAW_OUT-1:0] out_A, out_B;
wire [N_A-1:0] mask_A;
wire [N_B-1:0] mask_B;
wire any_A, any_B;
for (genvar i = 0; i < N_A; i++) begin
assign in_A[i] = data_in[i];
end
for (genvar i = 0; i < N_B; i++) begin
assign in_B[i] = data_in[N_A + i];
end
assign mask_A = mask[N_A-1:0];
assign mask_B = mask[N-1:N_A];
assign any_A = |mask_A;
assign any_B = |mask_B;
VX_reduce_ext #(
.DATAW_IN (DATAW_IN),
.DATAW_OUT (DATAW_OUT),
.N (N_A)
) reduce_A (
.data_in (in_A),
.mask(mask_A),
.op_type(op_type),
.data_out (out_A)
);
VX_reduce_ext #(
.DATAW_IN (DATAW_IN),
.DATAW_OUT (DATAW_OUT),
.N (N_B)
) reduce_B (
.data_in (in_B),
.mask(mask_B),
.op_type(op_type),
.data_out (out_B)
);
logic [DATAW_OUT-1:0] _data_out;
always @(*) begin
case (op_type)
`INST_RED_ADD: _data_out = out_A + out_B;
`INST_RED_ADDU: _data_out = out_A + out_B;
`INST_RED_MIN: _data_out = ($signed(out_A) < $signed(out_B)) ? out_A : out_B;
`INST_RED_MINU: _data_out = (out_A < out_B) ? out_A : out_B;
`INST_RED_MAX: _data_out = ($signed(out_A) < $signed(out_B)) ? out_B : out_A;
`INST_RED_MAXU: _data_out = (out_A < out_B) ? out_B : out_A;
`INST_RED_AND: _data_out = out_A & out_B;
`INST_RED_OR: _data_out = out_A | out_B;
`INST_RED_XOR: _data_out = out_A ^ out_B;
default: _data_out = out_A;
endcase
end
// if both sides are masked out, then it doesn't matter what we output
assign data_out = (any_A && any_B) ? _data_out : (any_A ? out_A : out_B);
end
endmodule
module VX_reduce_unit #(
parameter CORE_ID = 0,
parameter NUM_LANES = 1
) (
input wire clk,
input wire reset,
VX_execute_if.slave execute_if,
VX_commit_if.master commit_if
);
`UNUSED_PARAM(CORE_ID)
localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES;
localparam PID_BITS = `CLOG2(`NUM_THREADS / NUM_LANES);
localparam PID_WIDTH = `UP(PID_BITS);
logic [`XLEN-1:0] accumulator, accumulator_n, reduced_accumulator;
wire [(NUM_LANES * `XLEN)-1:0] broadcasted_accumulator;
assign broadcasted_accumulator = {NUM_LANES{accumulator}};
wire eop;
wire [NUM_LANES-1:0][`XLEN-1:0] data_in;
wire [`XLEN-1:0] data_out;
assign eop = execute_if.data.eop;
assign data_in = execute_if.data.rs1_data;
logic execute_if_valid;
logic execute_if_ready;
logic commit_if_valid;
logic commit_if_ready;
wire execute_if_fire;
wire commit_if_fire;
assign execute_if_valid = execute_if.valid;
assign execute_if.ready = execute_if_ready;
assign execute_if_fire = execute_if.ready && execute_if.valid;
assign commit_if_fire = commit_if_ready && commit_if_valid;
logic store_tmask_pid;
logic read_tmask_pid;
wire [PID_WIDTH-1:0] stored_pid;
wire [NUM_LANES-1:0] stored_tmask;
wire stored_sop;
wire stored_eop;
logic [PID_BITS:0] size, size_n;
// 1. idle state - wait for execute_if to be valid
// 2. accumulate - continue accumulating until eop, store packet id + thread mask for broadcast phase
// 3. broadcast - broadcast to rds
localparam IDLE = 2'b00;
localparam ACCUMULATE = 2'b01;
localparam BROADCAST = 2'b10;
localparam FINISH = 2'b11;
logic [1:0] state, state_n;
always @(*) begin
state_n = state;
accumulator_n = accumulator;
execute_if_ready = '0;
commit_if_valid = '0;
store_tmask_pid = '0;
read_tmask_pid = '0;
size_n = store_tmask_pid ? size + 1 : (read_tmask_pid ? size - 1 : size);
case (state)
IDLE: begin
if (execute_if_valid) begin
accumulator_n = data_out;
store_tmask_pid = '1;
if (eop) begin
state_n = BROADCAST;
end
else begin
execute_if_ready = '1;
state_n = ACCUMULATE;
end
end
end
ACCUMULATE: begin
execute_if_ready = '1;
if (eop) begin
execute_if_ready = '0;
state_n = BROADCAST;
end
if (eop || execute_if_fire) begin
accumulator_n = reduced_accumulator;
store_tmask_pid = '1;
end
end
BROADCAST: begin
execute_if_ready = '0;
commit_if_valid = '1;
if (commit_if_fire) begin
read_tmask_pid = '1;
end
if (size_n == '0) begin
state_n = FINISH;
end
end
FINISH: begin
execute_if_ready = '1;
if (execute_if_fire) begin
state_n = IDLE;
end
end
endcase
end
always @(posedge clk) begin
if (reset) begin
accumulator <= '0;
state <= IDLE;
size <= '0;
end
else begin
accumulator <= accumulator_n;
state <= state_n;
size <= size_n;
end
end
VX_reduce_ext #(
.DATAW_IN(`XLEN),
.N(NUM_LANES)
) reducer (
.data_in(data_in),
.mask(execute_if.data.tmask),
.op_type(execute_if.data.op_type),
.data_out(data_out)
);
VX_reduce_ext #(
.DATAW_IN(`XLEN),
.N(2)
) accumulator_reducer (
.data_in({accumulator, data_out}),
.mask(2'b11),
.op_type(execute_if.data.op_type),
.data_out(reduced_accumulator)
);
VX_elastic_buffer #(
.DATAW(NUM_LANES + PID_WIDTH + 1 + 1),
.SIZE(NUM_PACKETS),
) tmask_pid_store (
.clk(clk),
.reset(reset),
.valid_in(store_tmask_pid),
`UNUSED_PIN(ready_in),
.data_in({execute_if.data.tmask, execute_if.data.pid, execute_if.data.sop, execute_if.data.eop}),
.data_out({stored_tmask, stored_pid, stored_sop, stored_eop}),
.ready_out(read_tmask_pid),
`UNUSED_PIN(valid_out)
);
VX_elastic_buffer #(
.DATAW(`UUID_WIDTH + `NW_WIDTH + NUM_LANES + `XLEN + 1 + `NR_BITS + (`XLEN * NUM_LANES) + PID_WIDTH + 1 + 1)
) output_buffer (
.clk(clk),
.reset(reset),
.valid_in(commit_if_valid),
.ready_in(commit_if_ready),
.data_in({execute_if.data.uuid, execute_if.data.wid, stored_tmask, execute_if.data.PC, execute_if.data.wb, execute_if.data.rd, broadcasted_accumulator, stored_pid, stored_sop, stored_eop}),
.data_out({commit_if.data.uuid, commit_if.data.wid, commit_if.data.tmask, commit_if.data.PC, commit_if.data.wb, commit_if.data.rd, commit_if.data.data, commit_if.data.pid, commit_if.data.sop, commit_if.data.eop}),
.ready_out(commit_if.ready),
.valid_out(commit_if.valid)
);
endmodule

View File

@@ -0,0 +1,300 @@
`include "VX_fpu_define.vh"
module VX_tensor_core #(
) (
input clk,
input reset,
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_tensor_core_warp #(
.ISW(i)
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(dispatch_if[i]),
.commit_if(commit_if[i])
);
end
endmodule
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
parameter ISW
) (
input clk,
input reset,
VX_dispatch_if.slave dispatch_if,
VX_commit_if.master commit_if
);
wire [1:0] step = 2'(dispatch_if.data.op_type);
logic [3:0] octet_results_valid;
logic [3:0] octet_results_ready;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
for (genvar i = 0; i < 4; ++i) begin
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
};
wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
};
wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
logic result_valid;
logic result_ready;
VX_tensor_octet #(
) octet (
.clk(clk),
.reset(reset),
.A_in(octet_A),
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(dispatch_if.valid),
.operands_ready(dispatch_if.ready),
.step(step),
.D_out(octet_D),
.result_valid(result_valid),
.result_ready(result_ready)
);
// these should always be in lockstep
assign octet_results_valid[i] = result_valid;
assign result_ready = octet_results_ready[i];
assign wb_data_0[4*i+0] = octet_D[0][0];
assign wb_data_0[4*i+1] = octet_D[1][0];
assign wb_data_0[4*i+2] = octet_D[0][2];
assign wb_data_0[4*i+3] = octet_D[1][2];
assign wb_data_1[4*i+0] = octet_D[0][1];
assign wb_data_1[4*i+1] = octet_D[1][1];
assign wb_data_1[4*i+2] = octet_D[0][3];
assign wb_data_1[4*i+3] = octet_D[1][3];
assign wb_data_0[4*i+16+0] = octet_D[2][0];
assign wb_data_0[4*i+16+1] = octet_D[3][0];
assign wb_data_0[4*i+16+2] = octet_D[2][2];
assign wb_data_0[4*i+16+3] = octet_D[3][2];
assign wb_data_1[4*i+16+0] = octet_D[2][1];
assign wb_data_1[4*i+16+1] = octet_D[3][1];
assign wb_data_1[4*i+16+2] = octet_D[2][3];
assign wb_data_1[4*i+16+3] = octet_D[3][3];
end
/* commit_if.data_t parts that we need to keep around:
- uuid
- wid
- tmask
- PC
- wb
- rd
*/
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [DATAW-1:0] dispatch_if_data_enq = {
dispatch_if.data.uuid,
wis_to_wid(dispatch_if.data.wis, ISW),
dispatch_if.data.tmask,
dispatch_if.data.PC,
dispatch_if.data.wb,
dispatch_if.data.rd
};
wire [DATAW-1:0] dispatch_if_data_deq;
// this is probably a little oversized
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(16)
) pending_uops (
.clk(clk),
.reset(reset),
.push(dispatch_if_fire),
.pop(commit_if_fire),
.data_in(dispatch_if_data_enq),
.data_out(dispatch_if_data_deq),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
logic subcommit, subcommit_n;
wire all_valid = (& octet_results_valid);
assign commit_if.valid = all_valid;
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq,
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
1'b0,
1'b1,
1'b1
};
assign commit_if.data = commit_if_data;
always @(*) begin
subcommit_n = commit_if_fire ? ~subcommit : subcommit;
if (commit_if_fire && subcommit == 1'b1) begin
octet_results_ready = '1;
end
else begin
octet_results_ready = '0;
end
end
always @(posedge clk) begin
if (reset) begin
subcommit <= '0;
end
else begin
subcommit <= subcommit_n;
end
end
endmodule
module VX_tensor_octet #(
) (
input clk,
input reset,
input [7:0][31:0] A_in,
input [7:0][31:0] B_in,
input [7:0][31:0] C_in,
input operands_valid, // we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [1:0] step,
output [3:0][3:0][31:0] D_out,
output result_valid,
input result_ready
);
// 512 bits/octet * 4 octets per warp
logic [3:0][31:0] A_buffer, A_buffer_n;
logic [3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight from operand bus)
// unlike the real tensor core, the banks are only 32 bit rather than 64 bit
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
always @(*) begin
case (step)
2'b00: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0];
end
2'b01: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[3:0];
end
2'b10: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[7:4];
end
2'b11: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[7:4];
end
endcase
C_half = C_in;
end
logic substep;
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
always @(*) begin
A_buffer_n = A_buffer;
B_buffer_n = B_buffer;
C_buffer_n = C_buffer;
if (substep == 1'b0) begin
A_buffer_n = A_half;
B_buffer_n = B_half;
C_buffer_n = C_half;
end
end
always @(posedge clk) begin
if (reset) begin
A_buffer <= '0;
B_buffer <= '0;
C_buffer <= '0;
substep <= '0;
end
else begin
A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substep <= substep_n;
end
end
wire stall = result_valid && ~result_ready;
assign operands_ready = ~stall;
wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] },
{ A_half[2], A_buffer[2] },
{ A_half[1], A_buffer[1] },
{ A_half[0], A_buffer[0] }
};
wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer
};
logic [3:0][3:0][31:0] C_tile;
always @(*) begin
C_tile = {
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
};
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
VX_tensor_dpu #(
) dpu (
.clk(clk),
.reset(reset),
.stall(stall),
.valid_in(do_hmma),
.A_tile(A_tile),
.B_tile(B_tile),
.C_tile(C_tile),
.valid_out(result_valid),
.D_tile(D_out)
);
endmodule

View File

@@ -0,0 +1,96 @@
HMMA_SET0_STEP0_0: begin
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
end
HMMA_SET0_STEP0_1: begin
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
end
HMMA_SET0_STEP1_0: begin
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
end
HMMA_SET0_STEP1_1: begin
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
end
HMMA_SET0_STEP2_0: begin
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
end
HMMA_SET0_STEP2_1: begin
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
end
HMMA_SET0_STEP3_0: begin
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
end
HMMA_SET0_STEP3_1: begin
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
end
HMMA_SET1_STEP0_0: begin
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
end
HMMA_SET1_STEP0_1: begin
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
end
HMMA_SET1_STEP1_0: begin
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
end
HMMA_SET1_STEP1_1: begin
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
end
HMMA_SET1_STEP2_0: begin
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
end
HMMA_SET1_STEP2_1: begin
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
end
HMMA_SET1_STEP3_0: begin
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
end
HMMA_SET1_STEP3_1: begin
uop = {NEXT, HMMA_SET2_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
end
HMMA_SET2_STEP0_0: begin
uop = {NEXT, HMMA_SET2_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(4), `FREG(12), `FREG(16)};
end
HMMA_SET2_STEP0_1: begin
uop = {NEXT, HMMA_SET2_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(5), `FREG(13), `FREG(17)};
end
HMMA_SET2_STEP1_0: begin
uop = {NEXT, HMMA_SET2_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(4), `FREG(12), `FREG(18)};
end
HMMA_SET2_STEP1_1: begin
uop = {NEXT, HMMA_SET2_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(5), `FREG(13), `FREG(19)};
end
HMMA_SET2_STEP2_0: begin
uop = {NEXT, HMMA_SET2_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(4), `FREG(12), `FREG(20)};
end
HMMA_SET2_STEP2_1: begin
uop = {NEXT, HMMA_SET2_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(5), `FREG(13), `FREG(21)};
end
HMMA_SET2_STEP3_0: begin
uop = {NEXT, HMMA_SET2_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(4), `FREG(12), `FREG(22)};
end
HMMA_SET2_STEP3_1: begin
uop = {NEXT, HMMA_SET3_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(5), `FREG(13), `FREG(23)};
end
HMMA_SET3_STEP0_0: begin
uop = {NEXT, HMMA_SET3_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(6), `FREG(14), `FREG(16)};
end
HMMA_SET3_STEP0_1: begin
uop = {NEXT, HMMA_SET3_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(7), `FREG(15), `FREG(17)};
end
HMMA_SET3_STEP1_0: begin
uop = {NEXT, HMMA_SET3_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(6), `FREG(14), `FREG(18)};
end
HMMA_SET3_STEP1_1: begin
uop = {NEXT, HMMA_SET3_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(7), `FREG(15), `FREG(19)};
end
HMMA_SET3_STEP2_0: begin
uop = {NEXT, HMMA_SET3_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(6), `FREG(14), `FREG(20)};
end
HMMA_SET3_STEP2_1: begin
uop = {NEXT, HMMA_SET3_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(7), `FREG(15), `FREG(21)};
end
HMMA_SET3_STEP3_0: begin
uop = {NEXT, HMMA_SET3_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(6), `FREG(14), `FREG(22)};
end
HMMA_SET3_STEP3_1: begin
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(7), `FREG(15), `FREG(23)};
end

View File

@@ -0,0 +1,163 @@
`include "VX_define.vh"
`define FREG(x) {1'b1, `NRI_BITS'(x)}
module VX_uop_sequencer import VX_gpu_pkg::*; (
input clk,
input reset,
VX_ibuffer_if.slave uop_sequencer_if,
VX_ibuffer_if.master ibuffer_if
);
`ifdef EXT_T_ENABLE
localparam UOP_TABLE_SIZE = 64;
localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE);
localparam NEXT = 2'b00;
localparam FINISH = 2'b01;
localparam UBR_BITS = 2;
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4);
localparam IBUFFER_IF_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `XLEN + 1 + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + `XLEN + (`NR_BITS * 4);
logic [UOP_TABLE_WIDTH-1:0] uop;
// reserve space at start of table for more uop sequences
localparam HMMA_SET0_STEP0_0 = UPC_BITS'(0);
localparam HMMA_SET0_STEP0_1 = UPC_BITS'(8);
localparam HMMA_SET0_STEP1_0 = UPC_BITS'(9);
localparam HMMA_SET0_STEP1_1 = UPC_BITS'(10);
localparam HMMA_SET0_STEP2_0 = UPC_BITS'(11);
localparam HMMA_SET0_STEP2_1 = UPC_BITS'(12);
localparam HMMA_SET0_STEP3_0 = UPC_BITS'(13);
localparam HMMA_SET0_STEP3_1 = UPC_BITS'(14);
localparam HMMA_SET1_STEP0_0 = UPC_BITS'(15);
localparam HMMA_SET1_STEP0_1 = UPC_BITS'(16);
localparam HMMA_SET1_STEP1_0 = UPC_BITS'(17);
localparam HMMA_SET1_STEP1_1 = UPC_BITS'(18);
localparam HMMA_SET1_STEP2_0 = UPC_BITS'(19);
localparam HMMA_SET1_STEP2_1 = UPC_BITS'(20);
localparam HMMA_SET1_STEP3_0 = UPC_BITS'(21);
localparam HMMA_SET1_STEP3_1 = UPC_BITS'(22);
localparam HMMA_SET2_STEP0_0 = UPC_BITS'(23);
localparam HMMA_SET2_STEP0_1 = UPC_BITS'(24);
localparam HMMA_SET2_STEP1_0 = UPC_BITS'(25);
localparam HMMA_SET2_STEP1_1 = UPC_BITS'(26);
localparam HMMA_SET2_STEP2_0 = UPC_BITS'(27);
localparam HMMA_SET2_STEP2_1 = UPC_BITS'(28);
localparam HMMA_SET2_STEP3_0 = UPC_BITS'(29);
localparam HMMA_SET2_STEP3_1 = UPC_BITS'(30);
localparam HMMA_SET3_STEP0_0 = UPC_BITS'(31);
localparam HMMA_SET3_STEP0_1 = UPC_BITS'(32);
localparam HMMA_SET3_STEP1_0 = UPC_BITS'(33);
localparam HMMA_SET3_STEP1_1 = UPC_BITS'(34);
localparam HMMA_SET3_STEP2_0 = UPC_BITS'(35);
localparam HMMA_SET3_STEP2_1 = UPC_BITS'(36);
localparam HMMA_SET3_STEP3_0 = UPC_BITS'(37);
localparam HMMA_SET3_STEP3_1 = UPC_BITS'(38);
// register layout: f0-f7 used for A, f8-f15 used for B, f16-f23 used for C
always @(*) begin
case (upc)
`include "VX_tensor_ucode.vh"
default: begin
uop = '0;
end
endcase
end
logic [UPC_BITS-1:0] upc, upc_r, upc_n;
wire [UBR_BITS-1:0] ubr = uop[UOP_TABLE_WIDTH-1:UOP_TABLE_WIDTH-UBR_BITS];
wire [UPC_BITS-1:0] next_upc = uop[UOP_TABLE_WIDTH-UBR_BITS-1:UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS];
logic use_uop, use_uop_1d;
wire uop_fire = use_uop && ibuffer_if.valid && ibuffer_if.ready;
wire uop_start = ~use_uop_1d && use_uop;
wire uop_finish = use_uop && uop_sequencer_if.valid && uop_sequencer_if.ready;
// merging the 2 always blocks leads to spurious UNOPTFLAT verilator lint, but conceptually they should be linked
always @(*) begin
use_uop = uop_sequencer_if.valid && uop_sequencer_if.data.ex_type == `EX_TENSOR;
if (uop_start) begin
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
upc_n = UPC_BITS'(uop_sequencer_if.data.op_type);
end
else begin
upc_n = upc;
end
if (uop_fire) begin
upc_n = next_upc;
end
end
always @(*) begin
if (uop_start) begin
// 1st cycle of microcoded operation, use op_type to determine entry point into microcode table
upc = UPC_BITS'(uop_sequencer_if.data.op_type);
end
else begin
upc = upc_r;
end
end
// copy UUID, wis, tmask from microcoded instruction
wire [IBUFFER_IF_DATAW-1:0] ibuffer_output = {
uop_sequencer_if.data.uuid,
uop_sequencer_if.data.wis,
uop_sequencer_if.data.tmask,
uop[UOP_TABLE_WIDTH-UBR_BITS-UPC_BITS-1:0]
};
assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid;
assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready;
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
always @(posedge clk) begin
if (uop_start) begin
$display("UOP start @ %t", $time);
$display("use_uop=%0d, use_uop_1d=%0d, uop_start=%0d, ibuffer_if.valid=%0d, ibuffer_if.ready=%0d", use_uop, use_uop_1d, uop_start, ibuffer_if.valid, ibuffer_if.ready);
end
if (uop_fire) begin
$display("UOP fire @ %t", $time);
end
if (uop_finish) begin
$display("UOP finish @ %t", $time);
end
if (reset) begin
upc_r <= '0;
use_uop_1d <= '0;
end
else begin
upc_r <= upc_n;
if (uop_finish) begin
use_uop_1d <= 1'b0; // allow microcoded instructions to start immediately after eachother
end
else begin
use_uop_1d <= use_uop;
end
end
end
`else
`UNUSED_VAR(clk);
`UNUSED_VAR(reset);
assign ibuffer_if.valid = uop_sequencer_if.valid;
assign uop_sequencer_if.ready = ibuffer_if.ready;
assign ibuffer_if.data = uop_sequencer_if.data;
`endif
endmodule

View File

@@ -0,0 +1,81 @@
num_sets = 4
num_steps = 4
num_substeps = 2
def set_step_substep(sequence_number):
set_num, step = divmod(sequence_number, num_steps * num_substeps)
step //= num_substeps
substep = sequence_number % 2
return set_num, step, substep
# set + substep -> rs1, rs2
rs1 = {
(0, 0): 0,
(0, 1): 1,
(1, 0): 2,
(1, 1): 3,
(2, 0): 4,
(2, 1): 5,
(3, 0): 6,
(3, 1): 7
}
rs2 = {
(0, 0): 8,
(0, 1): 9,
(1, 0): 10,
(1, 1): 11,
(2, 0): 12,
(2, 1): 13,
(3, 0): 14,
(3, 1): 15
}
# step + substep -> rs3, rd
rs3_rd = {
(0, 0): 16,
(0, 1): 17,
(1, 0): 18,
(1, 1): 19,
(2, 0): 20,
(2, 1): 21,
(3, 0): 22,
(3, 1): 23
}
with open('VX_tensor_ucode.vh', 'w') as f:
for sequence_number in range(num_sets * num_steps * num_substeps):
set_num, step, substep = set_step_substep(sequence_number)
next_sequence_num = (sequence_number + 1) % (num_sets * num_steps * num_substeps)
next_set_num, next_step, next_substep = set_step_substep(next_sequence_num)
finish = (next_sequence_num == 0)
name = "HMMA_SET{}_STEP{}_{}"
ucode = "{}, HMMA_SET{}_STEP{}_{}, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'({}), `INST_MOD_BITS'({}), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG({}), `FREG({}), `FREG({}), `FREG({})"
name = name.format(
set_num, step, substep,
)
ucode = ucode.format(
"FINISH" if finish else "NEXT",
next_set_num, next_step, next_substep,
step,
substep,
rs3_rd[(step, substep)],
rs1[(set_num, substep)],
rs2[(set_num, substep)],
rs3_rd[(step, substep)],
)
entry = f"{name}: begin \n"
entry += "\tuop = {"
entry += ucode
entry += "}; \n"
entry += "end \n"
f.write(entry)

View File

@@ -0,0 +1,37 @@
`include "VX_fpu_define.vh"
module VX_tensor_dpu #(
) (
input clk,
input reset,
input stall,
input valid_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile,
output valid_out,
output [3:0][3:0][31:0] D_tile
);
logic [3:0][3:0][31:0] result_hmma;
always @(*) begin
dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
end
VX_shift_register #(
.DATAW (1 + $bits(D_tile)),
.DEPTH (`LATENCY_HMMA),
.RESETW (1)
) shift_reg (
.clk (clk),
.reset (reset),
.enable (~stall),
.data_in ({valid_in, result_hmma}),
.data_out ({valid_out, D_tile})
);
endmodule

View File

@@ -0,0 +1,30 @@
`include "VX_fpu_define.vh"
module VX_tensor_tb(
input clk,
input reset,
input valid_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile,
output valid_out,
output [3:0][3:0][31:0] D_tile
);
VX_tensor_dpu #() tensor_core (
.clk(clk),
.reset(reset),
.stall(1'b0),
.valid_in(valid_in),
.A_tile(A_tile),
.B_tile(B_tile),
.C_tile(C_tile),
.valid_out(valid_out),
.D_tile(D_tile)
);
endmodule

View File

@@ -0,0 +1,89 @@
DESTDIR ?= .
RTL_DIR = ../../rtl
DPI_DIR = $(abspath ../../dpi)
SIM_DIR = ../../../sim
THIRD_PARTY_DIR = $(abspath ../../../third_party)
CONFIGS +=
PARAMS +=
CXXFLAGS += -std=c++17 -Wall -Wextra -Wfatal-errors -Wno-array-bounds
CXXFLAGS += -fPIC -Wno-maybe-uninitialized
CXXFLAGS += -fcoroutines
CXXFLAGS += -I../../.. -I../../common -I../../../../sim/common
CXXFLAGS += -I/$(THIRD_PARTY_DIR)/softfloat/source/include
CXXFLAGS += -I/$(DPI_DIR)
CXXFLAGS += $(CONFIGS)
LDFLAGS += $(THIRD_PARTY_DIR)/softfloat/build/Linux-x86_64-GCC/softfloat.a
# control RTL debug tracing states
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_BANK
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_MSHR
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_TAG
DBG_TRACE_FLAGS += -DDBG_TRACE_CACHE_DATA
DBG_FLAGS += -DDEBUG_LEVEL=$(DEBUG) -DVCD_OUTPUT $(DBG_TRACE_FLAGS)
RTL_PKGS = $(RTL_DIR)/VX_gpu_pkg.sv
RTL_INCLUDE = -I$(RTL_DIR) -I$(DPI_DIR) -I$(RTL_DIR)/libs -I$(RTL_DIR)/fpu
# SRCS = cachesim.cpp testbench.cpp
SRCS += $(DPI_DIR)/util_dpi.cpp
SRCS += $(DPI_DIR)/float_dpi.cpp
SRCS += $(SIM_DIR)/common/rvfloats.cpp
SRCS += ./main.cpp
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_dpu.sv
RTL_SRCS += $(RTL_DIR)/fpu/VX_tensor_tb.sv
TOP = VX_tensor_tb
VL_FLAGS = --exe
VL_FLAGS += --language 1800-2009 # -Wall -Wpedantic # --assert
VL_FLAGS += -Wno-DECLFILENAME -Wno-REDEFMACRO
VL_FLAGS += --x-initial unique --x-assign unique
VL_FLAGS += -DSIMULATION -DSV_DPI
VL_FLAGS += $(CONFIGS)
VL_FLAGS += $(PARAMS)
VL_FLAGS += $(RTL_INCLUDE)
VL_FLAGS += $(RTL_PKGS)
VL_FLAGS += --cc $(TOP) --top-module $(TOP)
VL_FLAGS += --timing
# Enable Verilator multithreaded simulation
THREADS ?= $(shell python -c 'import multiprocessing as mp; print(mp.cpu_count())')
VL_FLAGS += -j $(THREADS)
#VL_FLAGS += --threads $(THREADS)
# Debugigng
ifdef DEBUG
VL_FLAGS += --trace --trace-structs $(DBG_FLAGS)
CXXFLAGS += -g -O0 $(DBG_FLAGS)
else
VL_FLAGS += -DNDEBUG
CXXFLAGS += -O2 -DNDEBUG
endif
# Enable perf counters
ifdef PERF
VL_FLAGS += -DPERF_ENABLE
CXXFLAGS += -DPERF_ENABLE
endif
PROJECT = tensor
all: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT): $(SRCS) $(RTL_SRCS)
verilator --build $(VL_FLAGS) $(SRCS) -CFLAGS '$(CXXFLAGS)' -LDFLAGS '$(LDFLAGS)' -o ../$@
run: $(DESTDIR)/$(PROJECT)
$(DESTDIR)/$(PROJECT)
waves: trace.vcd
gtkwave -o trace.vcd
clean:
rm -rf obj_dir $(DESTDIR)/$(PROJECT)

197
hw/unittest/tensor/main.cpp Normal file
View File

@@ -0,0 +1,197 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "vl_simulator.h"
#include "VVX_tensor_tb.h"
#include <iostream>
#include <half.hpp>
#define MAX_TICKS 20
#ifndef TRACE_START_TIME
#define TRACE_START_TIME 0ull
#endif
#ifndef TRACE_STOP_TIME
#define TRACE_STOP_TIME -1ull
#endif
#define CHECK(x) \
do { \
if (x) \
break; \
std::cout << "FAILED: " << #x << std::endl; \
std::abort(); \
} while (false)
static uint64_t timestamp = 0;
static bool trace_enabled = false;
static uint64_t trace_start_time = TRACE_START_TIME;
static uint64_t trace_stop_time = TRACE_STOP_TIME;
double sc_time_stamp() {
return timestamp;
}
bool sim_trace_enabled() {
if (timestamp >= trace_start_time
&& timestamp < trace_stop_time)
return true;
return trace_enabled;
}
void sim_trace_enable(bool enable) {
trace_enabled = enable;
}
using Device = VVX_tensor_tb;
using half_float::half;
static_assert(sizeof(half) == 2);
uint32_t half2bits(half h) {
uint16_t half_bits;
memcpy(&half_bits, &h, sizeof(half));
return half_bits;
}
uint32_t float2bits(float f) {
uint32_t float_bits;
memcpy(&float_bits, &f, sizeof(f));
return float_bits;
}
float bits2float(uint32_t b) {
float f;
memcpy(&f, &b, sizeof(b));
return f;
}
// A is M * K, B is K * K * M, C is M * M, D is M * M
#define M 4
#define K 2
// row, column
float A_tile[M][K];
float B_tile[K][M];
float C_tile[M][M];
float D_tile[M][M];
void initialize_test_data() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
A_tile[i][j] = (float) (i * K + j);
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
B_tile[i][j] = (float) (j * K + i);
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
C_tile[i][j] = (float) (i * j);
}
}
}
void write_test_data(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < K; j += 1) {
int index = (i * K + j);
uint32_t A_bits = float2bits(A_tile[i][j]);
sim->A_tile[index] = A_bits;
}
}
for (int i = 0; i < K; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t B_bits = float2bits(B_tile[i][j]);
sim->B_tile[index] = B_bits;
}
}
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t C_bits = float2bits(C_tile[i][j]);
sim->C_tile[index] = C_bits;
}
}
}
void read_result(vl_simulator<Device>& sim) {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
int index = (i * M + j);
uint32_t D_bits = sim->D_tile[index];
float f = bits2float(D_bits);
D_tile[i][j] = f;
std::cout << f << " ";
}
std::cout << std::endl;
}
}
void expected() {
for (int i = 0; i < M; i += 1) {
for (int j = 0; j < M; j += 1) {
float accum = C_tile[i][j];
for (int k = 0; k < K; k += 1) {
accum += A_tile[i][k] * B_tile[k][j];
}
std::cout << accum << " ";
}
std::cout << std::endl;
}
}
int main(int argc, char **argv) {
// Initialize Verilators variables
Verilated::commandArgs(argc, argv);
vl_simulator<Device> sim;
initialize_test_data();
// run test
timestamp = sim.reset(0);
// advance clock
timestamp = sim.step(timestamp, 10);
sim->valid_in = 1;
write_test_data(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
sim->valid_in = 0;
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 1);
read_result(sim);
timestamp = sim.step(timestamp, 2);
CHECK(sim->valid_out == 0);
expected();
std::cout << "PASSED!" << std::endl;
std::cout << "Simulation time: " << std::dec << timestamp/2 << " cycles" << std::endl;
return 0;
}

View File

@@ -1,19 +1,23 @@
all:
$(MAKE) -C conform
$(MAKE) -C hello
$(MAKE) -C fibonacci
$(MAKE) -C fibonacci
$(MAKE) -C reductions
run-simx:
$(MAKE) -C conform run-simx
$(MAKE) -C hello run-simx
$(MAKE) -C fibonacci run-simx
$(MAKE) -C reductions run-simx
run-rtlsim:
$(MAKE) -C conform run-rtlsim
$(MAKE) -C hello run-rtlsim
$(MAKE) -C fibonacci run-rtlsim
$(MAKE) -C reductions run-rtlsim
clean:
$(MAKE) -C conform clean
$(MAKE) -C hello clean
$(MAKE) -C fibonacci clean
$(MAKE) -C reductions clean

View File

@@ -33,7 +33,7 @@ $(PROJECT).dump: $(PROJECT).elf
$(PROJECT).bin: $(PROJECT).elf
$(CP) -O binary $(PROJECT).elf $(PROJECT).bin
$(PROJECT).elf: $(SRCS)
$(PROJECT).elf: $(SRCS) $(DEPS)
$(CC) $(CFLAGS) $(SRCS) $(LDFLAGS) -o $(PROJECT).elf
run-rtlsim: $(PROJECT).bin

View File

@@ -0,0 +1,5 @@
PROJECT = reductions
SRCS = main.cpp
include ../common.mk

View File

@@ -0,0 +1,216 @@
#define RISCV_CUSTOM2 0x5B
#define ADD_FUNC7 0b0000000
#define ADDU_FUNC7 0b1000000
#define MIN_FUNC7 0b0000001
#define MINU_FUNC7 0b1000001
#define MAX_FUNC7 0b0000010
#define MAXU_FUNC7 0b1000010
#define AND_FUNC7 0b0000011
#define OR_FUNC7 0b0000100
#define XOR_FUNC7 0b0000101
/*
6'h0: begin
op_type = func7[6] ? `INST_RED_ADDU : `INST_RED_ADD;
end
6'h1: begin
op_type = func7[6] ? `INST_RED_MINU : `INST_RED_MIN;
end
6'h2: begin
op_type = func7[6] ? `INST_RED_MAXU : `INST_RED_MAX;
end
6'h3: begin
op_type = `INST_RED_AND;
end
6'h4: begin
op_type = `INST_RED_OR;
end
6'h5: begin
op_type = `INST_RED_XOR;
end
*/
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
int x[4] = {3, 7, 2, 5};
int y = -1;
inline int vx_add_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(ADD_FUNC7));
return ret;
}
inline int vx_min_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MIN_FUNC7));
return ret;
}
inline unsigned vx_minu_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MINU_FUNC7));
return ret;
}
inline int vx_max_reduce(int v) {
int ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAX_FUNC7));
return ret;
}
inline unsigned vx_maxu_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(MAXU_FUNC7));
return ret;
}
inline unsigned vx_and_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(AND_FUNC7));
return ret;
}
inline unsigned vx_or_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(OR_FUNC7));
return ret;
}
inline unsigned vx_xor_reduce(unsigned v) {
unsigned ret;
asm volatile (".insn r %2, 0, %3, %0, %1, x0" : "=r"(ret) : "r"(v), "i"(RISCV_CUSTOM2), "i"(XOR_FUNC7));
return ret;
}
void test_add_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = x[tid];
int reduced = vx_add_reduce(v);
vx_tmc(1);
y = reduced;
}
unsigned unsigned_vector[4] = {(unsigned)-1, 0, (unsigned)-2, 5};
void test_min_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = unsigned_vector[tid];
int reduced = vx_min_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_max_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
int v = unsigned_vector[tid];
int reduced = vx_max_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_minu_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = unsigned_vector[tid];
unsigned reduced = vx_minu_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_maxu_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = unsigned_vector[tid];
unsigned reduced = vx_maxu_reduce(v);
vx_tmc(1);
y = reduced;
}
unsigned bit_vectors[4] = {0b11010110000111001100010100100110, 0b10010100011010001010000000001110, 0b10001001010111110001110000000010, 0b00010011010100101101110111001111};
void test_and_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_and_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_or_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_or_reduce(v);
vx_tmc(1);
y = reduced;
}
void test_xor_reduce() {
vx_tmc(-1);
int tid = vx_thread_id();
unsigned v = bit_vectors[tid];
unsigned reduced = vx_xor_reduce(v);
vx_tmc(1);
y = reduced;
}
int main()
{
int expected;
test_add_reduce();
vx_printf("add reduce result: %d\n", y);
vx_printf("expected: %d\n", x[0] + x[1] + x[2] + x[3]);
test_min_reduce();
vx_printf("min reduce result: %d\n", y);
expected = MIN((int)unsigned_vector[0], MIN((int)unsigned_vector[1], MIN((int)unsigned_vector[2], (int)unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_max_reduce();
vx_printf("max reduce result: %d\n", y);
expected = MAX((int)unsigned_vector[0], MAX((int)unsigned_vector[1], MAX((int)unsigned_vector[2], (int)unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_minu_reduce();
vx_printf("minu reduce result: %d\n", y);
expected = MIN(unsigned_vector[0], MIN(unsigned_vector[1], MIN(unsigned_vector[2], unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_maxu_reduce();
vx_printf("maxu reduce result: %d\n", y);
expected = MAX(unsigned_vector[0], MAX(unsigned_vector[1], MAX(unsigned_vector[2], unsigned_vector[3])));
vx_printf("expected: %d\n", expected);
test_and_reduce();
vx_printf("and reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] & bit_vectors[1] & bit_vectors[2] & bit_vectors[3]);
test_or_reduce();
vx_printf("or reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] | bit_vectors[1] | bit_vectors[2] | bit_vectors[3]);
test_xor_reduce();
vx_printf("xor reduce result: %d\n", y);
vx_printf("expected: %d\n", bit_vectors[0] ^ bit_vectors[1] ^ bit_vectors[2] ^ bit_vectors[3]);
return 0;
}

View File

@@ -0,0 +1,8 @@
PROJECT = tensor
SRCS = main.cpp
DEPS = a_matrix.h
DEPS += b_matrix.h
DEPS += c_matrix.h
include ../common.mk

View File

@@ -0,0 +1,95 @@
import numpy as np
import struct
A_array = np.zeros((16, 8))
B_array = np.zeros((8, 16))
C_array = np.zeros((16, 16))
file = input("simulator output filename: ")
def hex2float(float_hex_str):
# print(float_hex_str.strip())
return struct.unpack(">f",struct.pack(">i",int(float_hex_str,16)))[0]
def C_index(threadgroup, thread, register):
"""
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
"""
col = ((threadgroup % 4) // 2) * 8
row = (threadgroup * 8) % 16
row += (threadgroup // 4) * 4
offsets = [(0, 0), (0, 1), (2, 0), (2, 1), (0, 4), (0, 5), (2, 4), (2, 5)]
offset = offsets[register-16]
row += offset[0]
col += offset[1]
thread_offsets = [(0, 0), (1, 0), (0, 2), (1, 2)]
thread_offset = thread_offsets[thread % 4]
row += thread_offset[0]
col += thread_offset[1]
if C_array[row, col] != 0:
print("bad")
return (row, col)
with open(file) as f:
for line in f.readlines():
line = line.strip()
if "warp" in line:
a, b, c = line.split(',')
_, a = a.split(' ')
_, b = b.strip().split(' ')
c, d = c.strip().split(':')
_, c = c.split(' ')
warp = int(a)
thread = int(b)
register = int(c)
value = d.strip()
if warp != 0:
continue
if not (32 <= register < 32+24):
continue
register = register - 32
# threadgroups 0, 4, 1, 5 have all elements of A
threadgroup = thread // 4
if threadgroup in [0, 4, 1, 5]:
row = [0, 4, 1, 5].index(threadgroup) * 4 + thread % 4
if 0 <= register < 8:
A_array[row, register] = hex2float(value)
if threadgroup in [0, 4, 2, 6]:
col = [0, 4, 2, 6].index(threadgroup) * 4 + thread % 4
if 8 <= register < 16:
B_array[register-8, col] = hex2float(value)
if 16 <= register < 24:
# print(value)
C_array[C_index(threadgroup, thread, register)] = hex2float(value)
expected = np.load("abc.npz")
expected_A = expected['A_array']
expected_B = expected['B_array']
expected_C = expected['C_array']
expected_C = expected_C + expected_A @ expected_B
print(expected_C[0:8, 0:8])
print(C_array[0:8, 0:8])
print((expected_C - C_array)[0:8, 0:8])
assert np.allclose(expected_A, A_array)
assert np.allclose(expected_B, B_array)
assert np.allclose(expected_C, C_array)

View File

@@ -0,0 +1,32 @@
import numpy as np
A_array = np.random.rand(16, 8)
B_array = np.random.rand(8, 16)
C_array = np.random.rand(16, 16)
# A_array = np.zeros((16, 8))
# B_array = np.zeros((8, 16))
# A_array[0,:] = 1.0
# B_array[:,4] = 1.0
# C_array = np.zeros((16, 16))
# for i in range(16):
# for j in range(16):
# C_array[i,j] = i * 16 + j
with open('a_matrix.h', 'w') as f:
for i in range(A_array.shape[0]):
for j in range(A_array.shape[1]):
f.write(f'{A_array[i,j]}f, ')
f.write('\n')
with open('b_matrix.h', 'w') as f:
for i in range(B_array.shape[0]):
for j in range(B_array.shape[1]):
f.write(f'{B_array[i,j]}f, ')
f.write('\n')
with open('c_matrix.h', 'w') as f:
for i in range(C_array.shape[0]):
for j in range(C_array.shape[1]):
f.write(f'{C_array[i,j]}f, ')
f.write('\n')
np.savez("abc", A_array=A_array, B_array=B_array, C_array=C_array)

View File

@@ -0,0 +1,96 @@
#define RISCV_CUSTOM3 0x7B
#include <vx_intrinsics.h>
#include <stdio.h>
#include <vx_print.h>
inline void vx_wmma() {
asm volatile (".insn r %0, 0, 0, x0, x0, x0" :: "i"(RISCV_CUSTOM3));
}
#include "test_data.h"
void vx_wmma_load() {
int tid = vx_thread_id();
int tg = tid / 4;
// load A
int row = tid % 4;
row += (tg * 8) % 16;
row += (tg / 4) * 4;
asm volatile ("flw f0, %0" :: "m"(A[row][0]));
asm volatile ("flw f1, %0" :: "m"(A[row][1]));
asm volatile ("flw f2, %0" :: "m"(A[row][2]));
asm volatile ("flw f3, %0" :: "m"(A[row][3]));
asm volatile ("flw f4, %0" :: "m"(A[row][4]));
asm volatile ("flw f5, %0" :: "m"(A[row][5]));
asm volatile ("flw f6, %0" :: "m"(A[row][6]));
asm volatile ("flw f7, %0" :: "m"(A[row][7]));
// load B
int col = tid % 4;
col += ((tg % 4) / 2) * 8;
col += (tg / 4) * 4;
asm volatile ("flw f8 , %0" :: "m"(B[0][col]));
asm volatile ("flw f9 , %0" :: "m"(B[1][col]));
asm volatile ("flw f10, %0" :: "m"(B[2][col]));
asm volatile ("flw f11, %0" :: "m"(B[3][col]));
asm volatile ("flw f12, %0" :: "m"(B[4][col]));
asm volatile ("flw f13, %0" :: "m"(B[5][col]));
asm volatile ("flw f14, %0" :: "m"(B[6][col]));
asm volatile ("flw f15, %0" :: "m"(B[7][col]));
// load C
col = ((tg % 4) / 2) * 8;
row = (tg * 8) % 16;
row += (tg / 4) * 4;
row += (tid % 4) % 2;
col += ((tid % 4) / 2) * 2;
asm volatile ("flw f16, %0" :: "m"(C[row+0][col+0]));
asm volatile ("flw f17, %0" :: "m"(C[row+0][col+1]));
asm volatile ("flw f18, %0" :: "m"(C[row+2][col+0]));
asm volatile ("flw f19, %0" :: "m"(C[row+2][col+1]));
asm volatile ("flw f20, %0" :: "m"(C[row+0][col+4]));
asm volatile ("flw f21, %0" :: "m"(C[row+0][col+5]));
asm volatile ("flw f22, %0" :: "m"(C[row+2][col+4]));
asm volatile ("flw f23, %0" :: "m"(C[row+2][col+5]));
}
float results[32*8];
void store_wmma_result() {
int tid = vx_thread_id();
asm volatile ("fsw f16, %0" :: "m"(results[tid*8+0]));
asm volatile ("fsw f17, %0" :: "m"(results[tid*8+1]));
asm volatile ("fsw f18, %0" :: "m"(results[tid*8+2]));
asm volatile ("fsw f19, %0" :: "m"(results[tid*8+3]));
asm volatile ("fsw f20, %0" :: "m"(results[tid*8+4]));
asm volatile ("fsw f21, %0" :: "m"(results[tid*8+5]));
asm volatile ("fsw f22, %0" :: "m"(results[tid*8+6]));
asm volatile ("fsw f23, %0" :: "m"(results[tid*8+7]));
}
void print_wmma_result() {
for (int tid = 0; tid < 32; tid += 1) {
for (int reg = 0; reg < 8; reg += 1) {
vx_printf("thread %d, f%d: %x\n", tid, 16+reg, *((int*) &results[tid*8+reg]));
}
}
}
int main()
{
vx_tmc(-1);
vx_wmma_load();
vx_wmma();
store_wmma_result();
vx_tmc(1);
print_wmma_result();
return 0;
}

View File

@@ -0,0 +1,11 @@
float A[16][8] = {
#include "a_matrix.h"
};
float B[8][16] = {
#include "b_matrix.h"
};
float C[16][16] = {
#include "c_matrix.h"
};