tensor: Split flops into structural module
to get separate area/power numbers in hierarchical
This commit is contained in:
@@ -51,7 +51,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||
);
|
||||
|
||||
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
|
||||
VX_tensor_core_warp #(
|
||||
VX_tensor_core_block #(
|
||||
.ISW(1) // FIXME: not block_idx
|
||||
) tensor_core (
|
||||
.clk(clk),
|
||||
@@ -64,7 +64,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
|
||||
|
||||
endmodule
|
||||
|
||||
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
||||
parameter ISW
|
||||
) (
|
||||
input clk,
|
||||
@@ -82,15 +82,16 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
||||
|
||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||
// op_mod is reused to indicate instruction's id in pair
|
||||
// op_mod is reused to indicate if instruction is the last substep inside
|
||||
// a step (pair of substeps)
|
||||
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
||||
|
||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
||||
wire [NUM_OCTETS-1:0] octet_results_valid;
|
||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
wire [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
// FIXME: should be NUM_LANES?
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
|
||||
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
|
||||
wire [`NW_WIDTH-1:0] wb_wid;
|
||||
|
||||
// valid signal synced between the functional units (octet) and the
|
||||
@@ -113,9 +114,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
|
||||
};
|
||||
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
logic result_valid;
|
||||
logic result_ready;
|
||||
wire [3:0][3:0][31:0] octet_D;
|
||||
wire result_valid;
|
||||
wire result_ready;
|
||||
|
||||
VX_tensor_octet #(
|
||||
.ISW(ISW),
|
||||
@@ -285,15 +286,39 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
end
|
||||
end
|
||||
|
||||
VX_tensor_reg #(
|
||||
.N(1)
|
||||
) staging_subcommit (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(subcommit_n),
|
||||
.en(1'b1),
|
||||
.q(subcommit)
|
||||
);
|
||||
endmodule
|
||||
|
||||
module VX_tensor_reg #(
|
||||
parameter N
|
||||
) (
|
||||
input clk,
|
||||
input reset,
|
||||
input [N-1:0] d,
|
||||
input en,
|
||||
output [N-1:0] q
|
||||
);
|
||||
logic [N-1:0] data;
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
subcommit <= '0;
|
||||
end
|
||||
else begin
|
||||
subcommit <= subcommit_n;
|
||||
data <= '0;
|
||||
end else begin
|
||||
if (en) begin
|
||||
data <= d;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
assign q = data;
|
||||
endmodule
|
||||
|
||||
module VX_tensor_octet #(
|
||||
@@ -337,7 +362,6 @@ module VX_tensor_octet #(
|
||||
logic [3:0][31:0] B_half_buf;
|
||||
logic [7:0][31:0] C_half_buf;
|
||||
|
||||
|
||||
logic [`NUM_WARPS-1:0] substeps;
|
||||
logic [`NUM_WARPS-1:0] substeps_n;
|
||||
|
||||
@@ -353,6 +377,7 @@ module VX_tensor_octet #(
|
||||
assign A_in_buf = A_in;
|
||||
assign B_in_buf = B_in;
|
||||
assign C_in_buf = C_in;
|
||||
// TODO: merge *_buf/*
|
||||
assign operands_step_buf = operands_step;
|
||||
assign operands_wid_buf = operands_wid;
|
||||
assign operands_last_in_pair_buf = operands_last_in_pair;
|
||||
@@ -408,6 +433,18 @@ module VX_tensor_octet #(
|
||||
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
||||
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
|
||||
|
||||
// Staging buffer for the A/B/C half-tiles that will later be assembled
|
||||
// with the other half tiles coming in on the input ports.
|
||||
VX_tensor_reg #(
|
||||
.N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
|
||||
) staging_abc (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d({A_buffer_n, B_buffer_n, C_buffer_n}),
|
||||
.en(1'b1),
|
||||
.q({A_buffer, B_buffer, C_buffer})
|
||||
);
|
||||
|
||||
always @(*) begin
|
||||
A_buffer_n = A_buffer;
|
||||
B_buffer_n = B_buffer;
|
||||
@@ -426,20 +463,15 @@ module VX_tensor_octet #(
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
A_buffer <= '0;
|
||||
B_buffer <= '0;
|
||||
C_buffer <= '0;
|
||||
substeps <= '0;
|
||||
end
|
||||
else begin
|
||||
A_buffer <= A_buffer_n;
|
||||
B_buffer <= B_buffer_n;
|
||||
C_buffer <= C_buffer_n;
|
||||
substeps <= substeps_n;
|
||||
end
|
||||
end
|
||||
VX_tensor_reg #(
|
||||
.N($bits(substeps))
|
||||
) staging_substeps (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(substeps_n),
|
||||
.en(1'b1),
|
||||
.q(substeps)
|
||||
);
|
||||
|
||||
wire outbuf_ready_in;
|
||||
wire hmma_ready;
|
||||
@@ -458,8 +490,8 @@ module VX_tensor_octet #(
|
||||
};
|
||||
// C is 4x4 fp32 matrix
|
||||
logic [3:0][3:0][31:0] C_tile;
|
||||
logic [3:0][3:0][31:0] D_tile;
|
||||
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
||||
wire [3:0][3:0][31:0] D_tile;
|
||||
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
||||
|
||||
always @(*) begin
|
||||
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||
|
||||
@@ -62,11 +62,11 @@ module VX_tensor_dpu #(
|
||||
// stalling
|
||||
// assign ready_in = ready_out;
|
||||
|
||||
logic synced_fire;
|
||||
wire synced_fire;
|
||||
assign synced_fire = valid_in && ready_in;
|
||||
|
||||
logic [1:0] threadgroup_valids;
|
||||
logic [1:0] threadgroup_readys;
|
||||
wire [1:0] threadgroup_valids;
|
||||
wire [1:0] threadgroup_readys;
|
||||
// B_tile is shared across the two threadgroups; see Figure 13
|
||||
VX_tensor_threadgroup #(
|
||||
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
|
||||
@@ -187,7 +187,7 @@ module VX_tensor_threadgroup #(
|
||||
`UNUSED_PIN(size)
|
||||
);
|
||||
|
||||
logic [3:0] fedp_valids;
|
||||
wire [3:0] fedp_valids;
|
||||
wire fedp_valid_out = &(fedp_valids);
|
||||
wire fedp_ready_out = !stall;
|
||||
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
|
||||
@@ -198,14 +198,27 @@ module VX_tensor_threadgroup #(
|
||||
|
||||
// 0: FEDP uses first half from input_buffer
|
||||
// 1: FEDP uses last half and pops input_buffer
|
||||
logic step_in;
|
||||
wire step_in;
|
||||
// 0: FEDP produces first half of D_frag
|
||||
// 1: FEDP produces last half of D_frag and asserts valid_out
|
||||
logic step_out;
|
||||
wire step_out;
|
||||
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
|
||||
|
||||
wire [3:0][31:0] D_reg;
|
||||
logic [3:0][31:0] D_reg_n;
|
||||
|
||||
// Staging buffer that latches the D half-tile.
|
||||
VX_tensor_reg #(
|
||||
.N($bits(D_reg))
|
||||
) staging_d (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(D_reg_n),
|
||||
.en(1'b1),
|
||||
.q(D_reg)
|
||||
);
|
||||
|
||||
// latch the first-half result of D_frag
|
||||
logic [3:0][31:0] D_reg, D_reg_n;
|
||||
wire [3:0][31:0] D_half;
|
||||
always @(*) begin
|
||||
D_reg_n = D_reg;
|
||||
@@ -216,23 +229,25 @@ module VX_tensor_threadgroup #(
|
||||
end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
if (reset) begin
|
||||
step_in <= '0;
|
||||
step_out <= '0;
|
||||
|
||||
D_reg <= '0;
|
||||
end else begin
|
||||
if (fedp_fire_in) begin
|
||||
step_in <= ~step_in;
|
||||
end
|
||||
if (fedp_fire_out) begin
|
||||
step_out <= ~step_out;
|
||||
end
|
||||
|
||||
D_reg <= D_reg_n;
|
||||
end
|
||||
end
|
||||
// flip step_in/step_out on FEDP in/out fire, respectively
|
||||
VX_tensor_reg #(
|
||||
.N(1)
|
||||
) staging_step_in (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(~step_in),
|
||||
.en(fedp_fire_in),
|
||||
.q(step_in)
|
||||
);
|
||||
VX_tensor_reg #(
|
||||
.N(1)
|
||||
) staging_step_out (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(~step_out),
|
||||
.en(fedp_fire_out),
|
||||
.q(step_out)
|
||||
);
|
||||
|
||||
// TODO: Instead of latching half-result and constructing a full D tile,
|
||||
// we should be able to send these half fragments down to commit stage
|
||||
|
||||
Reference in New Issue
Block a user