tensor: Split flops into structural module

to get separate area/power numbers in hierarchical
This commit is contained in:
Hansung Kim
2024-07-26 16:26:48 -07:00
parent 7f43bab0aa
commit 01f6024a76
2 changed files with 103 additions and 56 deletions

View File

@@ -62,11 +62,11 @@ module VX_tensor_dpu #(
// stalling
// assign ready_in = ready_out;
logic synced_fire;
wire synced_fire;
assign synced_fire = valid_in && ready_in;
logic [1:0] threadgroup_valids;
logic [1:0] threadgroup_readys;
wire [1:0] threadgroup_valids;
wire [1:0] threadgroup_readys;
// B_tile is shared across the two threadgroups; see Figure 13
VX_tensor_threadgroup #(
.ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH)
@@ -187,7 +187,7 @@ module VX_tensor_threadgroup #(
`UNUSED_PIN(size)
);
logic [3:0] fedp_valids;
wire [3:0] fedp_valids;
wire fedp_valid_out = &(fedp_valids);
wire fedp_ready_out = !stall;
wire fedp_fire_out = fedp_valid_out && fedp_ready_out;
@@ -198,14 +198,27 @@ module VX_tensor_threadgroup #(
// 0: FEDP uses first half from input_buffer
// 1: FEDP uses last half and pops input_buffer
logic step_in;
wire step_in;
// 0: FEDP produces first half of D_frag
// 1: FEDP produces last half of D_frag and asserts valid_out
logic step_out;
wire step_out;
assign ready_buf = fedp_fire_in && (step_in == 1'b1);
wire [3:0][31:0] D_reg;
logic [3:0][31:0] D_reg_n;
// Staging buffer that latches the D half-tile.
VX_tensor_reg #(
.N($bits(D_reg))
) staging_d (
.clk(clk),
.reset(reset),
.d(D_reg_n),
.en(1'b1),
.q(D_reg)
);
// latch the first-half result of D_frag
logic [3:0][31:0] D_reg, D_reg_n;
wire [3:0][31:0] D_half;
always @(*) begin
D_reg_n = D_reg;
@@ -216,23 +229,25 @@ module VX_tensor_threadgroup #(
end
end
always @(posedge clk) begin
if (reset) begin
step_in <= '0;
step_out <= '0;
D_reg <= '0;
end else begin
if (fedp_fire_in) begin
step_in <= ~step_in;
end
if (fedp_fire_out) begin
step_out <= ~step_out;
end
D_reg <= D_reg_n;
end
end
// flip step_in/step_out on FEDP in/out fire, respectively
VX_tensor_reg #(
.N(1)
) staging_step_in (
.clk(clk),
.reset(reset),
.d(~step_in),
.en(fedp_fire_in),
.q(step_in)
);
VX_tensor_reg #(
.N(1)
) staging_step_out (
.clk(clk),
.reset(reset),
.d(~step_out),
.en(fedp_fire_out),
.q(step_out)
);
// TODO: Instead of latching half-result and constructing a full D tile,
// we should be able to send these half fragments down to commit stage