This commit is contained in:
Richard Yan
2024-06-09 15:15:31 -07:00
13 changed files with 959 additions and 136 deletions

View File

@@ -347,7 +347,7 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s
// A is M * K, B is K * M, C is M * M, D is M * M
#define M 4
#define K 2
#define K 2 // FIXME: 4x4x1 / cycle / octet!
// all row major
float c_A_tile[M][K];
@@ -358,6 +358,15 @@ float c_D_tile[M][M];
// code assumes that svBitVecVal is basically a uint32_t
static_assert(sizeof(svBitVecVal) == 4);
// Zero every element of a row-major `rows` x `cols` float tile in place.
//   c_tile: pointer to the first element of the tile
//   rows:   number of rows to clear
//   cols:   number of columns per row (also the row stride)
// Nothing is written when either dimension is zero or negative.
void clear_float_array(float* c_tile, int rows, int cols) {
    int row = 0;
    while (row < rows) {
        int col = 0;
        while (col < cols) {
            // row-major addressing: element (row, col) lives at row * cols + col
            c_tile[row * cols + col] = 0.0f;
            ++col;
        }
        ++row;
    }
}
void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) {
for (int i = 0; i < rows; i += 1) {
@@ -396,6 +405,11 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile,
if (!enable) {
return;
}
clear_float_array(&c_A_tile[0][0], M, K);
clear_float_array(&c_B_tile[0][0], K, M);
clear_float_array(&c_C_tile[0][0], M, M);
clear_float_array(&c_D_tile[0][0], M, M);
// std::cout << "A: " << std::endl;
fill_float_array(A_tile, &c_A_tile[0][0], M, K);
// std::cout << "B: " << std::endl;
@@ -551,7 +565,7 @@ void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBi
}
steps[wid] += 1;
if (steps[wid] % 64 == 0) {
if (steps[wid] % 32 == 0) {
steps[wid] = 0;
std::cout << "warp " << wid << " finished wmma\n";
std::cout << "A tile" << "\n";

View File

@@ -391,7 +391,7 @@
// Tensor Core Latency
`ifndef LATENCY_HMMA
`define LATENCY_HMMA 8
`define LATENCY_HMMA 4
`endif
// Icache Configurable Knobs //////////////////////////////////////////////////

View File

@@ -14,7 +14,7 @@
`ifndef VX_PLATFORM_VH
`define VX_PLATFORM_VH
// synthesis only
// enable synthesizable build if SIMULATION not explicitly defined
`ifndef SIMULATION
`define SYNTHESIS
`define NDEBUG

View File

@@ -42,7 +42,7 @@ module VX_alu_unit #(
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)

View File

@@ -545,6 +545,12 @@ module VX_decode #(
`INST_EXT4: begin
ex_type = `EX_TENSOR;
op_type = `INST_TENSOR_HMMA;
// tensor core macroop is encoded as r-type
use_rd = 1;
`USED_IREG (rd);
`USED_IREG (rs1);
`USED_IREG (rs2);
`USED_IREG (rs3);
end
`endif
default:;

View File

@@ -0,0 +1,274 @@
// Copyright © 2019-2023
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
`include "VX_define.vh"
// VX_dispatch_unit_sane
//
// Forwards instructions from the `ISSUE_WIDTH dispatch interfaces to
// BLOCK_SIZE execute interfaces. Issue slots are grouped into batches of
// BLOCK_SIZE slots; one batch is selected at a time via `batch_select`.
// When NUM_LANES is smaller than `NUM_THREADS, each warp's thread mask and
// register operands are split into NUM_PACKETS packets of NUM_LANES lanes,
// and only packets with at least one active thread are sent, tagged with
// pid/sop/eop. Otherwise the warp is forwarded as a single packet.
//
// Parameters:
//   BLOCK_SIZE - number of execute interfaces driven by this unit
//   NUM_LANES  - number of lanes per execute packet
//   OUT_REG    - output buffering selector, passed through
//                `OUT_REG_TO_EB_SIZE / `OUT_REG_TO_EB_REG to the elastic buffer
//   MAX_FANOUT - threshold used to decide whether the packet-select path is
//                registered (see FANOUT_ENABLE)
module VX_dispatch_unit_sane import VX_gpu_pkg::*; #(
parameter BLOCK_SIZE = 1,
parameter NUM_LANES = 1,
parameter OUT_REG = 0,
parameter MAX_FANOUT = `MAX_FANOUT
) (
input wire clk,
input wire reset,
// inputs
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
// outputs
VX_execute_if.master execute_if [BLOCK_SIZE]
);
// `NUM_THREADS must be an exact multiple of NUM_LANES
`STATIC_ASSERT ((`NUM_THREADS == NUM_LANES * (`NUM_THREADS / NUM_LANES)), ("invalid parameter"))
localparam BLOCK_SIZE_W = `LOG2UP(BLOCK_SIZE);
// number of NUM_LANES-wide packets needed to cover a full warp
localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES;
localparam PID_BITS = `CLOG2(NUM_PACKETS);
localparam PID_WIDTH = `UP(PID_BITS);
// number of BLOCK_SIZE-wide groups of issue slots
localparam BATCH_COUNT = `ISSUE_WIDTH / BLOCK_SIZE;
localparam BATCH_COUNT_W= `LOG2UP(BATCH_COUNT);
localparam ISSUE_W = `LOG2UP(`ISSUE_WIDTH);
// flattened widths of the dispatch (input) and execute (output) payloads;
// the output carries NUM_LANES-wide tmask/operands plus pid, sop and eop
localparam IN_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * `NUM_THREADS * `XLEN);
localparam OUT_DATAW = `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * NUM_LANES * `XLEN) + PID_WIDTH + 1 + 1;
// when set, the packet-select path below is registered (pipe_reg DEPTH = 1)
localparam FANOUT_ENABLE= (`NUM_THREADS > (MAX_FANOUT + MAX_FANOUT/2));
// bit offset of the thread mask inside the flattened dispatch payload
localparam DATA_TMASK_OFF = IN_DATAW - (`UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS);
// the three register operand arrays occupy the least-significant bits
localparam DATA_REGS_OFF = 0;
// flatten the dispatch interface array into plain vectors
wire [`ISSUE_WIDTH-1:0] dispatch_valid;
wire [`ISSUE_WIDTH-1:0][IN_DATAW-1:0] dispatch_data;
wire [`ISSUE_WIDTH-1:0] dispatch_ready;
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
assign dispatch_valid[i] = dispatch_if[i].valid;
assign dispatch_data[i] = dispatch_if[i].data;
assign dispatch_if[i].ready = dispatch_ready[i];
end
// per-block state, one entry per execute interface
wire [BLOCK_SIZE-1:0][ISSUE_W-1:0] issue_indices;
wire [BLOCK_SIZE-1:0] block_ready;
wire [BLOCK_SIZE-1:0][NUM_LANES-1:0] block_tmask;
wire [BLOCK_SIZE-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] block_regs;
wire [BLOCK_SIZE-1:0][PID_WIDTH-1:0] block_pid;
wire [BLOCK_SIZE-1:0] block_sop;
wire [BLOCK_SIZE-1:0] block_eop;
wire [BLOCK_SIZE-1:0] block_done;
// asserted when every block of the current batch is done
wire batch_done = (& block_done);
logic [BATCH_COUNT_W-1:0] batch_idx;
// Disabled counter-based batch selection; batch_idx is instead driven by
// the batch_select instance below.
// if (BATCH_COUNT != 1) begin
// always @(posedge clk) begin
// if (reset) begin
// batch_idx <= '0;
// end else begin
// batch_idx <= batch_idx + BATCH_COUNT_W'(batch_done);
// end
// end
// end else begin
// assign batch_idx = 0;
// `UNUSED_VAR(batch_done)
// end
// group dispatch_valid by blocks
wire [BATCH_COUNT-1:0] batch_valids;
for (genvar i = 0; i < BATCH_COUNT; ++i) begin
assign batch_valids[i] = |(dispatch_valid[(BLOCK_SIZE * i) +: BLOCK_SIZE]);
end
// elect the leftmost-valid batch for the dispatch
// NOTE(review): the exact selection policy depends on VX_lzc_rr, which is
// not visible here; the "_rr" suffix suggests round-robin -- confirm.
// NOTE(review): dispatch_any_valid is not read anywhere in this module.
wire dispatch_any_valid;
VX_lzc_rr #(
.N (BATCH_COUNT)
) batch_select (
.clk (clk),
.reset (reset),
.data_in (batch_valids),
.data_out (batch_idx),
.valid_out (dispatch_any_valid)
);
// one datapath per execute interface
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
// issue slot served by this block within the selected batch
wire [ISSUE_W-1:0] issue_idx = ISSUE_W'(batch_idx * BLOCK_SIZE) + ISSUE_W'(block_idx);
assign issue_indices[block_idx] = issue_idx;
// handshake into the output elastic buffer
wire valid_p, ready_p;
if (`NUM_THREADS != NUM_LANES) begin
// ---- partial-width path: split the warp into NUM_PACKETS packets ----
// one bit per packet, set once that packet has been sent
reg [NUM_PACKETS-1:0] sent_mask_p;
// start_p_n: next packet to send (pre-register), start_p: packet being
// sent, end_p: last packet that needs to be sent
wire [PID_WIDTH-1:0] start_p_n, start_p, end_p;
wire dispatch_valid_r;
// high until the first packet of the current instruction has fired
reg is_first_p;
wire fire_p = valid_p && ready_p;
wire is_last_p = (start_p == end_p);
wire fire_eop = fire_p && is_last_p;
always @(posedge clk) begin
if (reset) begin
sent_mask_p <= '0;
is_first_p <= 1;
end else begin
// restart packet tracking when the whole batch is done (multi-batch)
// or when the last packet fires (single batch)
if ((BATCH_COUNT != 1) ? batch_done : fire_eop) begin
sent_mask_p <= '0;
is_first_p <= 1;
end else if (fire_p) begin
sent_mask_p[start_p] <= 1;
is_first_p <= 0;
end
end
end
// slice the warp-wide thread mask and operands into per-packet views
wire [NUM_PACKETS-1:0][NUM_LANES-1:0] per_packet_tmask;
wire [NUM_PACKETS-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] per_packet_regs;
wire [`NUM_THREADS-1:0] dispatch_tmask = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs1_data = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs2_data = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs3_data = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
for (genvar j = 0; j < NUM_LANES; ++j) begin
// thread k of the warp maps to lane j of packet i
localparam k = i * NUM_LANES + j;
assign per_packet_tmask[i][j] = dispatch_tmask[k];
assign per_packet_regs[i][0][j] = dispatch_rs1_data[k];
assign per_packet_regs[i][1][j] = dispatch_rs2_data[k];
assign per_packet_regs[i][2][j] = dispatch_rs3_data[k];
end
end
// a packet needs to be sent only if it has at least one active thread
wire [NUM_PACKETS-1:0] packet_valids;
wire [NUM_PACKETS-1:0][PID_WIDTH-1:0] packet_ids;
for (genvar i = 0; i < NUM_PACKETS; ++i) begin
assign packet_valids[i] = (| per_packet_tmask[i]);
assign packet_ids[i] = PID_WIDTH'(i);
end
// next packet to send: chosen among valid packets not yet sent
// NOTE(review): assumes REVERSE=0 picks the lowest-index candidate and
// REVERSE=1 the highest -- confirm against VX_find_first.
VX_find_first #(
.N (NUM_PACKETS),
.DATAW (PID_WIDTH),
.REVERSE (0)
) find_first (
.valid_in (packet_valids & ~sent_mask_p),
.data_in (packet_ids),
.data_out (start_p_n),
`UNUSED_PIN (valid_out)
);
// final packet of the instruction: chosen among all valid packets
VX_find_first #(
.N (NUM_PACKETS),
.DATAW (PID_WIDTH),
.REVERSE (1)
) find_last (
.valid_in (packet_valids),
.data_in (packet_ids),
.data_out (end_p),
`UNUSED_PIN (valid_out)
);
// optional register stage on {valid, packet id}; present only when
// FANOUT_ENABLE is set, and cleared whenever a packet fires
VX_pipe_register #(
.DATAW (1 + PID_WIDTH),
.RESETW (1),
.DEPTH (FANOUT_ENABLE ? 1 : 0)
) pipe_reg (
.clk (clk),
.reset (reset || fire_p), // should flush on fire
.enable (1'b1),
.data_in ({dispatch_valid[issue_idx], start_p_n}),
.data_out ({dispatch_valid_r, start_p})
);
// payload of the packet currently being sent
wire [NUM_LANES-1:0] tmask_p = per_packet_tmask[start_p];
wire [2:0][NUM_LANES-1:0][`XLEN-1:0] regs_p = per_packet_regs[start_p];
// with multiple batches, stop sending once every packet bit is marked sent
wire block_enable = (BATCH_COUNT == 1 || ~(& sent_mask_p));
assign valid_p = dispatch_valid_r && block_enable;
assign block_tmask[block_idx] = tmask_p;
assign block_regs[block_idx] = regs_p;
assign block_pid[block_idx] = start_p;
assign block_sop[block_idx] = is_first_p;
assign block_eop[block_idx] = is_last_p;
if (FANOUT_ENABLE) begin
assign block_ready[block_idx] = dispatch_valid_r && ready_p && block_enable;
end else begin
assign block_ready[block_idx] = ready_p && block_enable;
end
// a block is done when its slot is idle or its last packet fires
assign block_done[block_idx] = ~dispatch_valid[issue_idx] || fire_eop;
end else begin
// ---- full-width path: the whole warp is forwarded as one packet ----
assign valid_p = dispatch_valid[issue_idx];
assign block_tmask[block_idx] = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS];
assign block_regs[block_idx][0] = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][1] = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_regs[block_idx][2] = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN];
assign block_pid[block_idx] = '0;
assign block_sop[block_idx] = 1'b1;
assign block_eop[block_idx] = 1'b1;
assign block_ready[block_idx] = ready_p;
assign block_done[block_idx] = ~valid_p || ready_p;
end
// issue-slot identifier passed to wis_to_wid below, built from the
// selected batch index and this block's position
wire [ISSUE_ISW_W-1:0] isw;
if (BATCH_COUNT != 1) begin
if (BLOCK_SIZE != 1) begin
assign isw = {batch_idx, BLOCK_SIZE_W'(block_idx)};
end else begin
assign isw = batch_idx;
end
end else begin
assign isw = block_idx;
end
`RESET_RELAY(buf_out_reset, reset);
// convert the wis field (located just above the thread mask) into a warp id
wire [`NW_WIDTH-1:0] block_wid = wis_to_wid(dispatch_data[issue_idx][DATA_TMASK_OFF+`NUM_THREADS +: ISSUE_WIS_W], isw);
// Output buffer towards the execute interface. The payload is the dispatch
// payload with wis replaced by the warp id, the thread mask and operands
// replaced by the per-packet versions, and pid/sop/eop appended.
VX_elastic_buffer #(
.DATAW (OUT_DATAW),
.SIZE (`OUT_REG_TO_EB_SIZE(OUT_REG)),
.OUT_REG (`OUT_REG_TO_EB_REG(OUT_REG))
) buf_out (
.clk (clk),
.reset (buf_out_reset),
.valid_in (valid_p),
.ready_in (ready_p),
.data_in ({
dispatch_data[issue_idx][IN_DATAW-1 : DATA_TMASK_OFF+`NUM_THREADS+ISSUE_WIS_W],
block_wid,
block_tmask[block_idx],
dispatch_data[issue_idx][DATA_TMASK_OFF-1 : DATA_REGS_OFF + 3 * `NUM_THREADS * `XLEN],
block_regs[block_idx][0],
block_regs[block_idx][1],
block_regs[block_idx][2],
block_pid[block_idx],
block_sop[block_idx],
block_eop[block_idx]}),
.data_out (execute_if[block_idx].data),
.valid_out (execute_if[block_idx].valid),
.ready_out (execute_if[block_idx].ready)
);
end
// Dispatch-side ready: a slot is released only when its block accepts the
// end-of-packet beat; slots outside the selected batch stay not-ready.
reg [`ISSUE_WIDTH-1:0] ready_in;
always @(*) begin
ready_in = 0;
for (integer i = 0; i < BLOCK_SIZE; ++i) begin
ready_in[issue_indices[i]] = block_ready[i] && block_eop[i];
end
end
assign dispatch_ready = ready_in;
endmodule

View File

@@ -39,7 +39,7 @@ module VX_fpu_unit import VX_fpu_pkg::*; #(
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)

View File

@@ -49,7 +49,7 @@ module VX_lsu_unit import VX_gpu_pkg::*; #(
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (1)
@@ -596,6 +596,31 @@ module VX_lsu_unit import VX_gpu_pkg::*; #(
.commit_out_if (commit_if)
);
`ifdef PERF_ENABLE
wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_valids_per_cycle;
wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_total_per_cycle;
reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_valids;
reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_total;
reg [`PERF_CTR_BITS-1:0] perf_rsp_fires;
`POP_COUNT(perf_rsp_tmask_valids_per_cycle, rsp_tmask);
assign perf_rsp_tmask_total_per_cycle = NUM_LANES;
always @(posedge clk) begin
if (reset) begin
perf_rsp_tmask_valids <= '0;
perf_rsp_tmask_total <= '0;
perf_rsp_fires <= '0;
end else begin
if (mem_rsp_fire) begin
perf_rsp_tmask_valids <= perf_rsp_tmask_valids + perf_rsp_tmask_valids_per_cycle;
perf_rsp_tmask_total <= perf_rsp_tmask_total + perf_rsp_tmask_total_per_cycle;
perf_rsp_fires <= perf_rsp_fires + 1'b1;
end
end
end
`endif
`ifdef DBG_SCOPE_LSU
if (CORE_ID == 0) begin
`ifdef SCOPE

View File

@@ -66,6 +66,7 @@ module VX_smem_unit import VX_gpu_pkg::*; #(
.req_valid (smem_req_valid),
.req_rw (smem_req_rw),
.req_byteen (smem_req_byteen),
// FIXME: synthesis complains undriven when USE_EXTERNAL_SMEM
.req_addr (smem_req_addr),
.req_data (smem_req_data),
.req_tag (smem_req_tag),

View File

@@ -1,7 +1,7 @@
`ifdef EXT_T_ENABLE
`include "VX_fpu_define.vh"
module VX_tensor_core #(
module VX_tensor_core import VX_gpu_pkg::*; #(
) (
input clk,
@@ -10,17 +10,54 @@ module VX_tensor_core #(
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH]
);
`STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")"));
localparam BLOCK_SIZE = 1;
localparam NUM_LANES = `NUM_THREADS;
// localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
localparam PARTIAL_BW = 1;
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) execute_if[BLOCK_SIZE]();
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit_sane #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)
) dispatch_unit (
.clk (clk),
.reset (dispatch_reset),
.dispatch_if(dispatch_if),
.execute_if (execute_if)
);
VX_commit_if #(
.NUM_LANES (NUM_LANES)
) commit_block_if[BLOCK_SIZE]();
`RESET_RELAY (commit_reset, reset);
VX_gather_unit #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3?
) gather_unit (
.clk (clk),
.reset (commit_reset),
.commit_in_if (commit_block_if),
.commit_out_if (commit_if)
);
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_warp #(
.ISW(i)
.ISW(1) // FIXME: not block_idx
) tensor_core (
.clk(clk),
.reset(reset),
.dispatch_if(dispatch_if[i]),
.commit_if(commit_if[i])
.execute_if(execute_if[block_idx]),
.commit_if(commit_block_if[block_idx])
);
end
@@ -32,37 +69,53 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
input clk,
input reset,
VX_dispatch_if.slave dispatch_if,
VX_execute_if.slave execute_if,
VX_commit_if.master commit_if
);
wire [1:0] step = 2'(dispatch_if.data.op_type);
logic [3:0] octet_results_valid;
logic [3:0] octet_results_ready;
logic [3:0] octet_operands_ready;
localparam NUM_OCTETS = (`NUM_THREADS / 8);
// offset in the lane numbers that get mapped to the two threadgroups in an
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
// FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
// this is only a rule of thumb
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
wire [1:0] step = 2'(execute_if.data.op_type);
// op_mod is reused to indicate instruction's id in pair
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready;
// FIXME: should be NUM_LANES?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NW_WIDTH-1:0] wb_wid;
assign dispatch_if.ready = &octet_operands_ready;
// valid signal synced between the functional units (octet) and the
// metadata queue
wire operands_valid_synced;
`ifdef EXT_T_ENABLE
for (genvar i = 0; i < 4/*octets*/; ++i) begin
for (genvar i = 0; i < NUM_OCTETS; ++i) begin
`else
for (genvar i = 0; i < 0; ++i) begin
`endif
// lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4]
};
wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4]
execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4]
};
wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4]
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
logic result_valid;
logic result_ready;
VX_tensor_octet #(
.ISW(ISW),
.OCTET(i)
@@ -73,12 +126,14 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.A_in(octet_A),
.B_in(octet_B),
.C_in(octet_C),
.operands_valid(dispatch_if.valid),
.operands_valid(operands_valid_synced),
.operands_wid(execute_if.data.wid),
.operands_last_in_pair(last_in_pair),
.operands_step(step),
.operands_ready(octet_operands_ready[i]),
.step(step),
.D_out(octet_D),
.D_wid(wb_wid),
.result_valid(result_valid),
.result_ready(result_ready)
);
@@ -100,15 +155,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
assign wb_data_1[4*i+2] = octet_D[0][3];
assign wb_data_1[4*i+3] = octet_D[1][3];
assign wb_data_0[4*i+16+0] = octet_D[2][0];
assign wb_data_0[4*i+16+1] = octet_D[3][0];
assign wb_data_0[4*i+16+2] = octet_D[2][2];
assign wb_data_0[4*i+16+3] = octet_D[3][2];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2];
assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2];
assign wb_data_1[4*i+16+0] = octet_D[2][1];
assign wb_data_1[4*i+16+1] = octet_D[3][1];
assign wb_data_1[4*i+16+2] = octet_D[2][3];
assign wb_data_1[4*i+16+3] = octet_D[3][3];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3];
assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3];
end
/* commit_if.data_t parts that we need to keep around:
@@ -122,44 +177,95 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [DATAW-1:0] dispatch_if_data_enq = {
dispatch_if.data.uuid,
wis_to_wid(dispatch_if.data.wis, ISW),
dispatch_if.data.tmask,
dispatch_if.data.PC,
dispatch_if.data.wb,
dispatch_if.data.rd
wire operand_enq_fire = operands_valid_synced && execute_if.ready;
wire commit_if_ready_override;
wire commit_if_fire = commit_if.valid && commit_if_ready_override;
wire [DATAW-1:0] execute_if_data_enq = {
execute_if.data.uuid,
execute_if.data.wid,
execute_if.data.tmask,
execute_if.data.PC,
execute_if.data.wb,
execute_if.data.rd
// pid/sop/eop set later
};
wire [DATAW-1:0] dispatch_if_data_deq;
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
wire [`NUM_WARPS-1:0] metadata_queue_fulls;
// OR not AND, we don't want any warp full
wire metadata_queue_full = |(metadata_queue_fulls);
// need to make sure both metadata and octet issue queues are in sync
assign operands_valid_synced = execute_if.valid && !metadata_queue_full;
assign execute_if.ready = &(octet_operands_ready) && !metadata_queue_full;
for (genvar i = 0; i < `NUM_WARPS; i++) begin
// Metadata queue for commit_if. This simply copies execute_if's
// metadata and pops them in conjunction with commit fire.
//
// This has to be separated per-warp, as otherwise requests from
// multiple warps can be enqueued interleaved, which makes it hard to
// ensure two consecutive dequeues are associated with the same warp for
// commit. (FIXME: this is not strictly necessary though.)
wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i));
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
// this is probably a little oversized
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(16)
.DEPTH(METADATA_QUEUE_DEPTH)
) pending_uops (
.clk(clk),
.reset(reset),
.push(dispatch_if_fire),
.pop(commit_if_fire),
.data_in(dispatch_if_data_enq),
.data_out(dispatch_if_data_deq),
.push(enq),
.pop(deq),
.data_in(execute_if_data_enq),
.data_out(execute_if_data_deq[i]),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow
.full(metadata_queue_fulls[i]),
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
end
// this shouldn't really happen unless there's a big contention over
// the commit stage
`RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!"))
// unlike execute which can be interleaved between warps, commit is
// serialized and completed one-warp-by-warp, therefore we only need to
// keep one subcommit state bit unlike for `substeps`
logic subcommit, subcommit_n;
wire all_valid = (& octet_results_valid);
// define this to inject artificial commit backpressure for debugging
// `define TENSOR_INJECT_COMMIT_BACKPRESSURE
`ifndef TENSOR_INJECT_COMMIT_BACKPRESSURE
assign commit_if.valid = all_valid;
assign commit_if_ready_override = commit_if.ready;
`else
logic [1:0] counter;
always @(posedge clk) begin
if (reset) begin
counter <= '0;
end else begin
if (all_valid) begin
counter <= counter + 1'b1;
end
end
end
assign commit_if.valid = all_valid && (counter == 2'b0);
assign commit_if_ready_override = commit_if.ready && (counter == 2'b0);
`endif
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq, /* uuid ~ rd */
execute_if_data_deq[wb_wid], /* uuid ~ rd */
// execute_if_data_deq, /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */
1'b1, /* sop */
@@ -199,19 +305,22 @@ module VX_tensor_octet #(
input [7:0][31:0] A_in,
input [7:0][31:0] B_in,
input [7:0][31:0] C_in,
input operands_valid, // we have to backpressure due to there potentially being contention over commit
input operands_valid,
input [`NW_WIDTH-1:0] operands_wid,
input operands_last_in_pair,
input [1:0] operands_step,
// we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [1:0] step,
output [3:0][3:0][31:0] D_out,
output [`NW_WIDTH-1:0] D_wid,
output result_valid,
input result_ready
);
// 512 bits/octet * 4 octets per warp
logic [3:0][31:0] A_buffer, A_buffer_n;
logic [3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n;
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight
// from operand bus) unlike the real tensor core.
@@ -219,41 +328,95 @@ module VX_tensor_octet #(
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
always @(*) begin
logic [3:0][31:0] A_half_buf;
logic [3:0][31:0] B_half_buf;
logic [7:0][31:0] C_half_buf;
logic [`NUM_WARPS-1:0] substeps;
logic [`NUM_WARPS-1:0] substeps_n;
wire [7:0][31:0] A_in_buf;
wire [7:0][31:0] B_in_buf;
wire [7:0][31:0] C_in_buf;
wire operands_valid_buf;
wire operands_ready_buf;
wire [`NW_WIDTH-1:0] operands_wid_buf;
wire operands_last_in_pair_buf;
wire [1:0] operands_step_buf;
assign A_in_buf = A_in;
assign B_in_buf = B_in;
assign C_in_buf = C_in;
assign operands_step_buf = operands_step;
assign operands_wid_buf = operands_wid;
assign operands_last_in_pair_buf = operands_last_in_pair;
assign operands_valid_buf = operands_valid;
assign operands_ready = operands_ready_buf;
typedef struct {
logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half;
} half_t;
function half_t get_operand_half(
input logic [1:0] step,
input logic [7:0][31:0] A_in,
input logic [7:0][31:0] B_in,
input logic [7:0][31:0] C_in
);
half_t half;
// note that not all lanes participate at every step
case (step)
2'b00: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0];
// Two A_in segments correspond to two 2x2 subtiles of A read
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
// Figure 10(b). B_in OTOH is shared by two threadgroups.
// Note k-dimension is shrunk from 4 to 2.
half.A_half = { A_in[5:4], A_in[1:0] };
half.B_half = B_in[3:0];
end
2'b01: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[3:0];
half.A_half = { A_in[7:6], A_in[3:2] };
half.B_half = B_in[3:0];
end
2'b10: begin
A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[7:4];
half.A_half = { A_in[5:4], A_in[1:0] };
half.B_half = B_in[7:4];
end
2'b11: begin
A_half = { A_in[7:6], A_in[3:2] };
B_half = B_in[7:4];
half.A_half = { A_in[7:6], A_in[3:2] };
half.B_half = B_in[7:4];
end
endcase
C_half = C_in;
end
half.C_half = C_in;
return half;
endfunction
logic substep;
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
half_t halves;
half_t halves_buf;
assign halves = get_operand_half(operands_step, A_in, B_in, C_in);
assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
always @(*) begin
A_buffer_n = A_buffer;
B_buffer_n = B_buffer;
C_buffer_n = C_buffer;
substeps_n = substeps;
if (substep == 1'b0) begin
A_buffer_n = A_half;
B_buffer_n = B_half;
C_buffer_n = C_half;
if (operands_first_in_pair_fire) begin
substeps_n[operands_wid_buf] = 1'b1; // ready for hmma
A_buffer_n[operands_wid_buf] = halves_buf.A_half;
B_buffer_n[operands_wid_buf] = halves_buf.B_half;
C_buffer_n[operands_wid_buf] = halves_buf.C_half;
end
if (do_hmma) begin
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
end
end
@@ -262,61 +425,113 @@ module VX_tensor_octet #(
A_buffer <= '0;
B_buffer <= '0;
C_buffer <= '0;
substep <= '0;
substeps <= '0;
end
else begin
A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substep <= substep_n;
substeps <= substeps_n;
end
end
wire stall = result_valid && ~result_ready;
assign operands_ready = ~stall;
wire outbuf_ready_in;
wire hmma_ready;
assign operands_ready_buf = hmma_ready;
// A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] },
{ A_half[2], A_buffer[2] },
{ A_half[1], A_buffer[1] },
{ A_half[0], A_buffer[0] }
{ halves_buf.A_half[3], A_buffer[operands_wid_buf][3] },
{ halves_buf.A_half[2], A_buffer[operands_wid_buf][2] },
{ halves_buf.A_half[1], A_buffer[operands_wid_buf][1] },
{ halves_buf.A_half[0], A_buffer[operands_wid_buf][0] }
};
// B is 2x4 fp32 matrix
wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer
halves_buf.B_half, B_buffer[operands_wid_buf]
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile;
logic [`NW_WIDTH-1:0] D_wid_dpu;
always @(*) begin
C_tile = {
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
};
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] };
C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] };
C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] };
end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
wire dpu_valid;
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
VX_tensor_dpu #(
.ISW(ISW),
.OCTET(OCTET)
.OCTET(OCTET),
.ISSUE_QUEUE_DEPTH(4 /*@perf: arbtirary*/)
) dpu (
.clk(clk),
.reset(reset),
.stall(stall),
.valid_in(do_hmma),
.ready_in(hmma_ready),
.A_tile(A_tile),
.B_tile(B_tile),
.C_tile(C_tile),
.wid(operands_wid_buf),
.valid_out(result_valid),
.D_tile(D_out)
.valid_out(dpu_valid),
.ready_out(outbuf_ready_in),
.D_tile(D_tile),
.D_wid(D_wid_dpu)
);
wire outbuf_empty;
wire outbuf_full;
// backpressure from commit
assign outbuf_ready_in = ~outbuf_full;
assign result_valid = ~outbuf_empty;
wire outbuf_enq = outbuf_ready_in && dpu_valid;
wire outbuf_deq = result_valid && result_ready;
// buffer to stage the result D tile for 2 cycles until commit/writeback
// is complete. This decouples the irregular dpu output traffic from the
// regular, every-2-cycle commit traffic to ensure the commit pipeline is
// used more efficiently.
// FIXME: unnecessary?
VX_fifo_queue #(
.DATAW ($bits(D_wid) + $bits(D_out)),
.DEPTH (2 /* arbitrary */)
) output_buffer (
.clk (clk),
.reset (reset),
.push (outbuf_enq),
.pop (outbuf_deq),
.data_in ({D_wid_dpu, D_tile}),
.data_out ({D_wid, D_out}),
.empty (outbuf_empty),
`UNUSED_PIN(alm_empty),
.full (outbuf_full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
// FIXME: this shouldn't be necessary
`RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!"))
`ifdef PERF_ENABLE
logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total;
always @(posedge clk) begin
if (reset) begin
perf_tensor_dpu_total <= '0;
end else begin
if (do_hmma) begin
perf_tensor_dpu_total <= perf_tensor_dpu_total + 2'd2;
end
end
end
`endif
endmodule
`endif

View File

@@ -0,0 +1,49 @@
// Micro-op table entries for the HMMA macro-op: 16 uops chained
// SET0_STEP0_0 -> ... -> SET1_STEP3_1.
//
// Field order of each entry (MSB first):
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
//
// In every entry below:
//   - OP_TYPE holds the step number (0..3) and OP_MOD holds 0 for the
//     first uop of a step pair (*_0) and 1 for the second (*_1)
//   - wb is 1, use pc and use imm are 0
//   - rd and rs3 name the same register, FREG(16) .. FREG(23)
//   - SET0 entries read rs1 from FREG(0)/FREG(1) and rs2 from FREG(8)/FREG(9);
//     SET1 entries read rs1 from FREG(2)/FREG(3) and rs2 from FREG(10)/FREG(11)

// ---- set 0 ----
HMMA_SET0_STEP0_0: begin
uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)};
end
HMMA_SET0_STEP0_1: begin
uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)};
end
HMMA_SET0_STEP1_0: begin
uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)};
end
HMMA_SET0_STEP1_1: begin
uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)};
end
HMMA_SET0_STEP2_0: begin
uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)};
end
HMMA_SET0_STEP2_1: begin
uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)};
end
HMMA_SET0_STEP3_0: begin
uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)};
end
// last uop of set 0 chains into the first uop of set 1
HMMA_SET0_STEP3_1: begin
uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)};
end
// ---- set 1: same rd/rs3 registers as set 0, different rs1/rs2 ----
HMMA_SET1_STEP0_0: begin
uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)};
end
HMMA_SET1_STEP0_1: begin
uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)};
end
HMMA_SET1_STEP1_0: begin
uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)};
end
HMMA_SET1_STEP1_1: begin
uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)};
end
HMMA_SET1_STEP2_0: begin
uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)};
end
HMMA_SET1_STEP2_1: begin
uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)};
end
HMMA_SET1_STEP3_0: begin
uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)};
end
// final uop: sequencing field is FINISH and the next state wraps back to
// the first entry.
// NOTE(review): the pc and imm fields are 32'b1 here while every other
// entry uses 32'b0 -- confirm this is intentional.
HMMA_SET1_STEP3_1: begin
uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)};
end

View File

@@ -14,10 +14,9 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
localparam UOP_TABLE_SIZE = 64;
localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE);
localparam NEXT = 2'b00;
localparam FINISH = 2'b01;
localparam UBR_BITS = 2;
localparam NEXT = UBR_BITS'(2'b00);
localparam FINISH = UBR_BITS'(2'b01);
// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3
localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4);
@@ -122,7 +121,17 @@ module VX_uop_sequencer import VX_gpu_pkg::*; (
// passthrough when !use_uop
assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid;
assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready;
assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
// Output data mux for ibuffer_if.
// While use_uop is set the generated micro-op (ibuffer_output) is driven out;
// otherwise the incoming uop_sequencer_if.data is passed through unchanged.
always @(*) begin
ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data;
// When a micro-op is being emitted for a valid incoming instruction whose
// rd field equals 1, the emitted micro-op's rd and rs3 register indices are
// both offset by 8.
// NOTE(review): presumably rd == 1 selects an alternate accumulator
// register group for the expanded sequence -- confirm against the
// instruction encoding and the uop table.
if (uop_sequencer_if.valid && use_uop &&
uop_sequencer_if.data.rd == `NR_BITS'(1)) begin
// The two assignments below read back the value assigned to
// ibuffer_if.data earlier in this same block (read-modify-write).
// a little sketchy? but shouldn't create any loop
ibuffer_if.data.rd = ibuffer_if.data.rd + `NR_BITS'(8); // FIXME: 8 is hardcoded
ibuffer_if.data.rs3 = ibuffer_if.data.rs3 + `NR_BITS'(8);
end
end
always @(posedge clk) begin
if (uop_start) begin

View File

@@ -3,44 +3,274 @@
// VX_tensor_dpu: tensor dot-product unit wrapper.
//
// Splits the 4-row A/C/D tiles between two VX_tensor_threadgroup instances
// (rows [1:0] and rows [3:2]); both instances receive the full B_tile.
// The warp id of every accepted operation is pushed into a FIFO and popped
// when the corresponding result is handed off, so D_wid accompanies D_tile.
//
// Handshake as implemented below:
//   - an input is accepted when valid_in && ready_in (synced_fire / enq);
//   - ready_in is high only when both threadgroups are ready and the warp-id
//     queue is not full;
//   - valid_out is high only when both threadgroups report valid_out;
//   - !ready_out is forwarded to both threadgroups as their stall input;
//   - the warp-id queue is popped when valid_out && ready_out (deq).
//
// NOTE(review): this text appears to contain both sides of a diff. OCTET and
// the D_tile output are each declared on two lines (with and without a
// trailing comma), and result_hmma is declared once live and once commented
// out. Confirm which variant is intended before relying on this listing.
// NOTE(review): the `stall` input is not referenced anywhere in the visible
// module body -- confirm whether it is still needed.
module VX_tensor_dpu #(
parameter ISW,
parameter OCTET
parameter OCTET,
// @perf: has big impact on throughput. A rule of thumb is to set it to
// the pipeline length of FEDPs in order to make sure there are enough
// entries to fully saturate the pipeline, but this is still rough
parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA
) (
input clk,
input reset,
input stall,
input valid_in,
output ready_in,
input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile,
input [`NW_WIDTH-1:0] wid,
output valid_out,
output [3:0][3:0][31:0] D_tile
input ready_out,
output [3:0][3:0][31:0] D_tile,
output [`NW_WIDTH-1:0] D_wid
);
logic [3:0][3:0][31:0] result_hmma;
// ---------------------------------------------------------------------------
// Disabled code: DPI-based functional model (dpi_hmma / dpi_print_results)
// with a fixed-latency VX_shift_register, kept commented out for reference.
// ---------------------------------------------------------------------------
// logic [3:0][3:0][31:0] result_hmma;
// always @(*) begin
// dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
// end
// logic ready_reg;
// always @(posedge clk) begin
// if (reset) begin
// ready_reg <= '1;
// end else if (valid_in && ready_in) begin
// ready_reg <= '0;
// dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
// end else if (valid_out && ready_out) begin
// ready_reg <= '1;
// end
// end
// // fixed-latency queue
// VX_shift_register #(
// .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
// .DEPTH (`LATENCY_HMMA + 1),
// .RESETW (1)
// ) shift_reg (
// .clk (clk),
// .reset (reset),
// .enable (ready_out),
// .data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
// .data_out ({valid_out, D_wid/*, D_tile */})
// );
// ready as soon as valid_out
// assign ready_in = ready_reg || valid_out;
// fully pipelined; ready_in is coupled to ready_out by immediately
// stalling
// assign ready_in = ready_out;
// Both threadgroups are fed with the same fire signal so that they accept
// (and therefore complete) operations in lockstep.
logic synced_fire;
assign synced_fire = valid_in && ready_in;
// Per-threadgroup handshake status; index 0 = rows [1:0], index 1 = rows [3:2].
logic [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
) threadgroup_0 (
.clk (clk),
.reset (reset),
.valid_in (synced_fire),
.ready_in (threadgroup_readys[0]),
.stall (!ready_out),
.A_frag (A_tile[1:0]),
.B_frag (B_tile),
.C_frag (C_tile[1:0]),
.valid_out (threadgroup_valids[0]),
.D_frag (D_tile[1:0])
);
// Second threadgroup: identical wiring, but operates on rows [3:2] of A/C/D.
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
) threadgroup_1 (
.clk (clk),
.reset (reset),
.valid_in (synced_fire),
.ready_in (threadgroup_readys[1]),
.stall (!ready_out),
.A_frag (A_tile[3:2]),
.B_frag (B_tile),
.C_frag (C_tile[3:2]),
.valid_out (threadgroup_valids[1]),
.D_frag (D_tile[3:2])
);
// Warp-id queue status and push/pop strobes.
wire empty;
wire full;
wire enq = valid_in && ready_in;
wire deq = valid_out && ready_out;
// Accept a new operation only if both threadgroups can take it and there is
// room to record its warp id; report a result only when both halves are valid.
assign ready_in = &(threadgroup_readys) && !full;
assign valid_out = &(threadgroup_valids);
// need to pass along warp id's to do multithreading
VX_fifo_queue #(
.DATAW ($bits(wid)),
// @perf: seems to require deeper depth than the FEDP issue queues to
// not cause stalls.
.DEPTH (2 * ISSUE_QUEUE_DEPTH)
) wid_queue (
.clk (clk),
.reset (reset),
.push (enq),
.pop (deq),
.data_in (wid),
.data_out (D_wid),
.empty (empty),
`UNUSED_PIN(alm_empty),
.full (full),
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
// Sanity check: a result must never be handed off while no warp id is queued.
`RUNTIME_ASSERT(reset || !(deq && empty),
("dequeueing from empty warp id queue!"))
endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// matches Figure 10(b) of the paper.
//
// Structure, as implemented below:
//   - accepted {A_frag, B_frag, C_frag} operands are queued in input_buffer;
//   - each queued entry is issued to the four FEDP units twice (step_in 0,
//     then step_in 1) and is popped on the second issue;
//   - FEDP i works on row d_row = i / 2 and on column d_col = (i % 2) * 2 in
//     step 0 or column d_col + 1 in step 1;
//   - first-step results are latched in D_reg, and second-step results are
//     taken directly from the FEDP outputs (D_half); together they form D_frag;
//   - valid_out is asserted when all FEDPs report valid during step_out == 1.
//
// NOTE(review): this text appears to contain both sides of a diff. Several
// lines below reference names that are not declared in this module (A_tile,
// B_tile, C_tile, result_hmma, D_tile) or leave begin/end and instance port
// lists unbalanced; they are flagged individually and should be confirmed as
// removed lines.
module VX_tensor_threadgroup #(
parameter ISSUE_QUEUE_DEPTH
) (
input clk,
input reset,
input valid_in,
output ready_in,
input stall,
input [1:0][1:0][31:0] A_frag,
input [1:0][3:0][31:0] B_frag,
input [1:0][3:0][31:0] C_frag,
output valid_out,
output [1:0][3:0][31:0] D_frag
);
// Operands at the head of the issue queue.
wire [1:0][1:0][31:0] A_frag_buf;
wire [1:0][3:0][31:0] B_frag_buf;
wire [1:0][3:0][31:0] C_frag_buf;
// valid_buf: queue holds at least one entry; ready_buf: head entry is consumed.
wire valid_buf;
wire ready_buf;
wire enq = valid_in && ready_in;
wire deq = valid_buf && ready_buf;
wire empty;
wire full;
assign ready_in = !full;
assign valid_buf = !empty;
// 'Issue queue' for the FEDP units.
// This exists to decouple the execution of the dot-product unit from
// the operand arrival. Operands from execute_if can arrive
// intermittently according to the frontend's behavior, and since the dpu
// can also stall for a fixed initiation latency, we need to decouple the
// two to efficiently feed the dpu.
//
// TODO: better queue design possible; e.g. B_frag is shared by two
// threadgroups, so we need only 1 queue per octet for B
VX_fifo_queue #(
.DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)),
.DEPTH (ISSUE_QUEUE_DEPTH)
) input_buffer (
.clk (clk),
.reset (reset),
.push (enq),
.pop (deq),
.data_in ({A_frag, B_frag, C_frag}),
.data_out ({A_frag_buf, B_frag_buf, C_frag_buf}),
.empty (empty),
`UNUSED_PIN(alm_empty),
.full (full),
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
// FEDP-side handshake. The output side is ready whenever the module is not
// stalled, and the input side is tied directly to the output side, so FEDP
// issue stops in the same cycle that stall is asserted.
logic [3:0] fedp_valids;
wire fedp_valid_out = &(fedp_valids);
wire fedp_ready_out = !stall;
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
wire fedp_valid_in = valid_buf;
wire fedp_ready_in = fedp_ready_out; // coupled
wire fedp_fire_in = fedp_valid_in && fedp_ready_in;
// 0: FEDP uses first half from input_buffer
// 1: FEDP uses last half and pops input_buffer
logic step_in;
// 0: FEDP produces first half of D_frag
// 1: FEDP produces last half of D_frag and asserts valid_out
logic step_out;
// Pop the queue head only when its second step is actually issued.
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
// latch the first-half result of D_frag
logic [3:0][31:0] D_reg, D_reg_n;
wire [3:0][31:0] D_half;
// Next-state logic for D_reg: hold by default, capture D_half when a
// first-step (step_out == 0) result leaves the FEDPs.
always @(*) begin
// NOTE(review): the dpi_hmma call below uses A_tile/B_tile/C_tile/result_hmma,
// none of which are declared in this module -- looks like a removed diff line.
dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma);
D_reg_n = D_reg;
if (fedp_fire_out) begin
if (step_out == 1'b0) begin
D_reg_n = D_half;
end
end
end
// Sequential state: step_in/step_out toggle on each FEDP input/output fire,
// and D_reg takes its next-state value every cycle outside reset.
always @(posedge clk) begin
// NOTE(review): the next two lines (an unclosed `if` and a dpi_print_results
// call on undeclared names) look like removed diff lines -- confirm.
if (~reset && valid_in) begin
dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma);
if (reset) begin
step_in <= '0;
step_out <= '0;
D_reg <= '0;
end else begin
if (fedp_fire_in) begin
step_in <= ~step_in;
end
if (fedp_fire_out) begin
step_out <= ~step_out;
end
D_reg <= D_reg_n;
end
end
// TODO: Instead of latching half-result and constructing a full D tile,
// we should be able to send these half fragments down to commit stage
// immediately, saving flop space
// Output assembly: latched first-step results fill the even columns (0, 2),
// and the live second-step FEDP outputs fill the odd columns (1, 3).
assign D_frag[0][0] = D_reg[0];
assign D_frag[0][2] = D_reg[1];
assign D_frag[1][0] = D_reg[2];
assign D_frag[1][2] = D_reg[3];
assign D_frag[0][1] = D_half[0];
assign D_frag[0][3] = D_half[1];
assign D_frag[1][1] = D_half[2];
assign D_frag[1][3] = D_half[3];
// NOTE(review): the VX_shift_register instantiation header below is never
// completed and references D_tile, which is not declared in this module --
// looks like removed diff lines from the old fixed-latency model; confirm.
VX_shift_register #(
.DATAW (1 + $bits(D_tile)),
.DEPTH (`LATENCY_HMMA),
.RESETW (1)
) shift_reg (
.clk (clk),
// 4 FEDPs per threadgroup
for (genvar i = 0; i < 4; ++i) begin
// Output row and base (even) column handled by FEDP i.
localparam int d_row = i / 2;
localparam int d_col = (i % 2) * 2;
// four-element dot product (FEDP) unit
// Only element slots 0 and 1 of the a/b inputs carry data; slots 2 and 3
// are tied to zero. step_in selects column d_col (step 0) or d_col + 1
// (step 1) of B and C.
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
// NOTE(review): the next three connections (.enable/.data_in/.data_out)
// match the shift-register ports above rather than the io_* naming of the
// other FEDP ports -- look like removed diff lines; confirm.
.enable (~stall),
.data_in ({valid_in, result_hmma}),
.data_out ({valid_out, D_tile})
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag_buf[d_row][0]),
.io_in_bits_a_1 (A_frag_buf[d_row][1]),
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]),
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
.io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end
// A full D_frag is available only while the second-step results are valid.
assign valid_out = fedp_valid_out && (step_out == 1'b1);
endmodule
`endif