tensor: Split flops into structural module

to get separate area/power numbers in hierarchical reports
This commit is contained in:
Hansung Kim
2024-07-26 16:26:48 -07:00
parent 7f43bab0aa
commit 01f6024a76
2 changed files with 103 additions and 56 deletions

View File

@@ -51,7 +51,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
);
for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin
VX_tensor_core_warp #(
VX_tensor_core_block #(
.ISW(1) // FIXME: not block_idx
) tensor_core (
.clk(clk),
@@ -64,7 +64,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
endmodule
module VX_tensor_core_warp import VX_gpu_pkg::*; #(
module VX_tensor_core_block import VX_gpu_pkg::*; #(
parameter ISW
) (
input clk,
@@ -82,15 +82,16 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
wire [1:0] step = 2'(execute_if.data.op_type);
// op_mod is reused to indicate instruction's id in pair
// op_mod is reused to indicate if instruction is the last substep inside
// a step (pair of substeps)
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid;
wire [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready;
wire [NUM_OCTETS-1:0] octet_operands_ready;
// FIXME: should be NUM_LANES?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NW_WIDTH-1:0] wb_wid;
// valid signal synced between the functional units (octet) and the
@@ -113,9 +114,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4]
};
logic [3:0][3:0][31:0] octet_D;
logic result_valid;
logic result_ready;
wire [3:0][3:0][31:0] octet_D;
wire result_valid;
wire result_ready;
VX_tensor_octet #(
.ISW(ISW),
@@ -285,15 +286,39 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
end
end
VX_tensor_reg #(
.N(1)
) staging_subcommit (
.clk(clk),
.reset(reset),
.d(subcommit_n),
.en(1'b1),
.q(subcommit)
);
endmodule
module VX_tensor_reg #(
parameter N
) (
input clk,
input reset,
input [N-1:0] d,
input en,
output [N-1:0] q
);
logic [N-1:0] data;
always @(posedge clk) begin
if (reset) begin
subcommit <= '0;
end
else begin
subcommit <= subcommit_n;
data <= '0;
end else begin
if (en) begin
data <= d;
end
end
end
assign q = data;
endmodule
module VX_tensor_octet #(
@@ -337,7 +362,6 @@ module VX_tensor_octet #(
logic [3:0][31:0] B_half_buf;
logic [7:0][31:0] C_half_buf;
logic [`NUM_WARPS-1:0] substeps;
logic [`NUM_WARPS-1:0] substeps_n;
@@ -353,6 +377,7 @@ module VX_tensor_octet #(
assign A_in_buf = A_in;
assign B_in_buf = B_in;
assign C_in_buf = C_in;
// TODO: merge each *_buf signal with its un-suffixed counterpart
assign operands_step_buf = operands_step;
assign operands_wid_buf = operands_wid;
assign operands_last_in_pair_buf = operands_last_in_pair;
@@ -408,6 +433,18 @@ module VX_tensor_octet #(
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf);
// Staging buffer for the A/B/C half-tiles that will later be assembled
// with the other half tiles coming in on the input ports.
VX_tensor_reg #(
.N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer))
) staging_abc (
.clk(clk),
.reset(reset),
.d({A_buffer_n, B_buffer_n, C_buffer_n}),
.en(1'b1),
.q({A_buffer, B_buffer, C_buffer})
);
always @(*) begin
A_buffer_n = A_buffer;
B_buffer_n = B_buffer;
@@ -426,20 +463,15 @@ module VX_tensor_octet #(
end
end
always @(posedge clk) begin
if (reset) begin
A_buffer <= '0;
B_buffer <= '0;
C_buffer <= '0;
substeps <= '0;
end
else begin
A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n;
substeps <= substeps_n;
end
end
VX_tensor_reg #(
.N($bits(substeps))
) staging_substeps (
.clk(clk),
.reset(reset),
.d(substeps_n),
.en(1'b1),
.q(substeps)
);
wire outbuf_ready_in;
wire hmma_ready;
@@ -458,8 +490,8 @@ module VX_tensor_octet #(
};
// C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile;
logic [`NW_WIDTH-1:0] D_wid_dpu;
wire [3:0][3:0][31:0] D_tile;
wire [`NW_WIDTH-1:0] D_wid_dpu;
always @(*) begin
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };

View File

@@ -62,11 +62,11 @@ module VX_tensor_dpu #(
// stalling
// assign ready_in = ready_out;
logic synced_fire;
wire synced_fire;
assign synced_fire = valid_in && ready_in;
logic [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys;
wire [1:0] threadgroup_valids;
wire [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
@@ -187,7 +187,7 @@ module VX_tensor_threadgroup #(
`UNUSED_PIN(size)
);
logic [3:0] fedp_valids;
wire [3:0] fedp_valids;
wire fedp_valid_out = &(fedp_valids);
wire fedp_ready_out = !stall;
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
@@ -198,14 +198,27 @@ module VX_tensor_threadgroup #(
// 0: FEDP uses first half from input_buffer
// 1: FEDP uses last half and pops input_buffer
logic step_in;
wire step_in;
// 0: FEDP produces first half of D_frag
// 1: FEDP produces last half of D_frag and asserts valid_out
logic step_out;
wire step_out;
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
wire [3:0][31:0] D_reg;
logic [3:0][31:0] D_reg_n;
// Staging buffer that latches the D half-tile.
VX_tensor_reg #(
.N($bits(D_reg))
) staging_d (
.clk(clk),
.reset(reset),
.d(D_reg_n),
.en(1'b1),
.q(D_reg)
);
// latch the first-half result of D_frag
logic [3:0][31:0] D_reg, D_reg_n;
wire [3:0][31:0] D_half;
always @(*) begin
D_reg_n = D_reg;
@@ -216,23 +229,25 @@ module VX_tensor_threadgroup #(
end
end
always @(posedge clk) begin
if (reset) begin
step_in <= '0;
step_out <= '0;
D_reg <= '0;
end else begin
if (fedp_fire_in) begin
step_in <= ~step_in;
end
if (fedp_fire_out) begin
step_out <= ~step_out;
end
D_reg <= D_reg_n;
end
end
// flip step_in/step_out on FEDP in/out fire, respectively
VX_tensor_reg #(
.N(1)
) staging_step_in (
.clk(clk),
.reset(reset),
.d(~step_in),
.en(fedp_fire_in),
.q(step_in)
);
VX_tensor_reg #(
.N(1)
) staging_step_out (
.clk(clk),
.reset(reset),
.d(~step_out),
.en(fedp_fire_out),
.q(step_out)
);
// TODO: Instead of latching half-result and constructing a full D tile,
// we should be able to send these half fragments down to commit stage