tensor: Add dispatch unit to narrow to BLOCK_SIZE=1

This commit is contained in:
Hansung Kim
2024-05-15 15:34:26 -07:00
parent 9f9ec10960
commit 1a1094b2bb

View File

@@ -1,7 +1,7 @@
`ifdef EXT_T_ENABLE `ifdef EXT_T_ENABLE
`include "VX_fpu_define.vh" `include "VX_fpu_define.vh"
module VX_tensor_core #( module VX_tensor_core import VX_gpu_pkg::*; #(
) ( ) (
input clk, input clk,
@@ -10,15 +10,54 @@ module VX_tensor_core #(
VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH],
VX_commit_if.master commit_if [`ISSUE_WIDTH] VX_commit_if.master commit_if [`ISSUE_WIDTH]
); );
for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin localparam BLOCK_SIZE = 1;
localparam NUM_LANES = `NUM_THREADS;
// localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS);
localparam PARTIAL_BW = 1;
VX_execute_if #(
.NUM_LANES (NUM_LANES)
) execute_if[BLOCK_SIZE]();
`RESET_RELAY (dispatch_reset, reset);
VX_dispatch_unit #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 1 : 0)
) dispatch_unit (
.clk (clk),
.reset (dispatch_reset),
.dispatch_if(dispatch_if),
.execute_if (execute_if)
);
VX_commit_if #(
.NUM_LANES (NUM_LANES)
) commit_block_if[BLOCK_SIZE]();
`RESET_RELAY (commit_reset, reset);
VX_gather_unit #(
.BLOCK_SIZE (BLOCK_SIZE),
.NUM_LANES (NUM_LANES),
.OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3?
) gather_unit (
.clk (clk),
.reset (commit_reset),
.commit_in_if (commit_block_if),
.commit_out_if (commit_if)
);
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_warp #( VX_tensor_core_warp #(
.ISW(i) .ISW(1) // FIXME: not block_idx
) tensor_core ( ) tensor_core (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
.dispatch_if(dispatch_if[i]), .execute_if(execute_if[block_idx]),
.commit_if(commit_if[i]) .commit_if(commit_block_if[block_idx])
); );
end end
@@ -30,7 +69,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
input clk, input clk,
input reset, input reset,
VX_dispatch_if.slave dispatch_if, VX_execute_if.slave execute_if,
VX_commit_if.master commit_if VX_commit_if.master commit_if
); );
localparam NUM_OCTETS = (`NUM_THREADS / 8); localparam NUM_OCTETS = (`NUM_THREADS / 8);
@@ -39,14 +78,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// FIXME: not sure this is the right logic. just filling in what works // FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(dispatch_if.data.op_type); wire [1:0] step = 2'(execute_if.data.op_type);
logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready; logic [NUM_OCTETS-1:0] octet_operands_ready;
// FIXME: should wb_data_0/wb_data_1 be sized by NUM_LANES rather than `NUM_THREADS?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
assign dispatch_if.ready = &octet_operands_ready; assign execute_if.ready = &octet_operands_ready;
`ifdef EXT_T_ENABLE `ifdef EXT_T_ENABLE
for (genvar i = 0; i < NUM_OCTETS; ++i) begin for (genvar i = 0; i < NUM_OCTETS; ++i) begin
@@ -55,13 +95,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
`endif `endif
// lane-to-octet mapping; see figure 13 of the paper // lane-to-octet mapping; see figure 13 of the paper
wire [7:0][31:0] octet_A = { wire [7:0][31:0] octet_A = {
dispatch_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4]
}; };
wire [7:0][31:0] octet_B = { wire [7:0][31:0] octet_B = {
dispatch_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4]
}; };
wire [7:0][31:0] octet_C = { wire [7:0][31:0] octet_C = {
dispatch_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
}; };
logic [3:0][3:0][31:0] octet_D; logic [3:0][3:0][31:0] octet_D;
@@ -77,7 +117,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.A_in(octet_A), .A_in(octet_A),
.B_in(octet_B), .B_in(octet_B),
.C_in(octet_C), .C_in(octet_C),
.operands_valid(dispatch_if.valid), .operands_valid(execute_if.valid),
.operands_ready(octet_operands_ready[i]), .operands_ready(octet_operands_ready[i]),
.step(step), .step(step),
@@ -126,18 +166,18 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS;
wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready; wire execute_if_fire = execute_if.valid && execute_if.ready;
wire commit_if_fire = commit_if.valid && commit_if.ready; wire commit_if_fire = commit_if.valid && commit_if.ready;
wire [DATAW-1:0] dispatch_if_data_enq = { wire [DATAW-1:0] execute_if_data_enq = {
dispatch_if.data.uuid, execute_if.data.uuid,
wis_to_wid(dispatch_if.data.wis, ISW), execute_if.data.wid,
dispatch_if.data.tmask, execute_if.data.tmask,
dispatch_if.data.PC, execute_if.data.PC,
dispatch_if.data.wb, execute_if.data.wb,
dispatch_if.data.rd execute_if.data.rd
}; };
wire [DATAW-1:0] dispatch_if_data_deq; wire [DATAW-1:0] execute_if_data_deq;
// this is probably a little oversized // this is probably a little oversized
VX_fifo_queue #( VX_fifo_queue #(
@@ -146,10 +186,10 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
) pending_uops ( ) pending_uops (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
.push(dispatch_if_fire), .push(execute_if_fire),
.pop(commit_if_fire), .pop(commit_if_fire),
.data_in(dispatch_if_data_enq), .data_in(execute_if_data_enq),
.data_out(dispatch_if_data_deq), .data_out(execute_if_data_deq),
`UNUSED_PIN(empty), `UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty), `UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow `UNUSED_PIN(full), // should be impossible to overflow
@@ -163,7 +203,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = { wire [COMMIT_DATAW-1:0] commit_if_data = {
dispatch_if_data_deq, /* uuid ~ rd */ execute_if_data_deq, /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */ 1'b0, /* pid */
1'b1, /* sop */ 1'b1, /* sop */
@@ -227,6 +267,10 @@ module VX_tensor_octet #(
// note that not all lanes participate at every step // note that not all lanes participate at every step
case (step) case (step)
2'b00: begin 2'b00: begin
// The two A_in segments correspond to the two 2x2 subtiles of A read
// by the two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
// Figure 10(b). B_in, on the other hand, is shared by both threadgroups.
// Note that the k-dimension is shrunk from 4 to 2.
A_half = { A_in[5:4], A_in[1:0] }; A_half = { A_in[5:4], A_in[1:0] };
B_half = B_in[3:0]; B_half = B_in[3:0];
end end