tensor: Split packed fp16 and wire correctly to DPU

This commit is contained in:
Hansung Kim
2024-08-07 11:16:38 -07:00
parent d4d18c2823
commit 15e93e01d8

View File

@@ -170,6 +170,7 @@ module VX_tensor_threadgroup #(
input valid_in,
output ready_in,
input stall,
// all *_frag are row-major
input [1:0][1:0][31:0] A_frag,
input [1:0][3:0][31:0] B_frag,
input [1:0][3:0][31:0] C_frag,
@@ -191,16 +192,18 @@ module VX_tensor_threadgroup #(
assign fedp_ready_in = fedp_ready_out;
// The dot product units take 2 cycles to finish computing A_frag * B_frag
// + C_frag. step_in and step_out keeps track of which cycle they're at
// + C_frag. substep_in and substep_out keeps track of which cycle they're at
// & when they have to pop from input queue and push to result queue.
// Note that substep is different from the "step" defined in the HMMA
// instruction set; it is a purely microarchitectural construct.
//
// step_in == 0: FEDP uses first half from operand buffer
// step_in == 1: FEDP uses last half and pops from operand buffer
wire step_in;
// step_out == 0: FEDP produces first half of D_frag
// step_out == 1: FEDP produces last half of D_frag and asserts valid_out
wire step_out;
assign ready_in = fedp_fire_in && (step_in == 1'b1);
// substep_in == 0: FEDP uses first half from operand buffer
// substep_in == 1: FEDP uses last half and pops from operand buffer
wire substep_in;
// substep_out == 0: FEDP produces first half of D_frag
// substep_out == 1: FEDP produces last half of D_frag and asserts valid_out
wire substep_out;
assign ready_in = fedp_fire_in && (substep_in == 1'b1);
wire [3:0][31:0] D_reg;
logic [3:0][31:0] D_reg_n;
@@ -221,30 +224,30 @@ module VX_tensor_threadgroup #(
always @(*) begin
D_reg_n = D_reg;
if (fedp_fire_out) begin
if (step_out == 1'b0) begin
if (substep_out == 1'b0) begin
D_reg_n = D_half;
end
end
end
// flip step_in/step_out on FEDP in/out fire, respectively
// flip substep_in/substep_out on FEDP in/out fire, respectively
VX_tensor_reg #(
.DATAW(1)
) staging_step_in (
) staging_substep_in (
.clk(clk),
.reset(reset),
.d(~step_in),
.d(~substep_in),
.en(fedp_fire_in),
.q(step_in)
.q(substep_in)
);
VX_tensor_reg #(
.DATAW(1)
) staging_step_out (
) staging_substep_out (
.clk(clk),
.reset(reset),
.d(~step_out),
.d(~substep_out),
.en(fedp_fire_out),
.q(step_out)
.q(substep_out)
);
// TODO: Instead of latching half-result and constructing a full D tile,
@@ -259,31 +262,44 @@ module VX_tensor_threadgroup #(
assign D_frag[1][1] = D_half[2];
assign D_frag[1][3] = D_half[3];
wire [1:0][3:0][15:0] A_frag_fp16 = A_frag;
wire [3:0][3:0][15:0] B_frag_fp16 = B_frag;
// 4 FEDPs per threadgroup
for (genvar i = 0; i < 4; ++i) begin
// at substep == 0, the 0th and 2nd columns of D begins compute;
// at substep == 1, the 1st and 3rd columns of D begins compute.
// there are two row elements for each column, rounding out
// 4 elements being computed by 4 FEDPs at every cycle
// (see Figure 10(b)).
// d_row: 0, 0, 1, 1
// d_col: 0, 2, 0, 2
localparam int d_row = i / 2;
localparam int d_col = (i % 2) * 2;
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
// Dot product (FEDP) unit generated from Chisel
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag[d_row][0]),
.io_in_bits_a_1 (A_frag[d_row][1]),
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]),
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
.io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]),
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
.io_in_bits_a_1 (A_frag[d_row][0][31:16]),
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
.io_in_bits_a_3 (A_frag[d_row][1][31:16]),
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end
assign valid_out = fedp_valid_out && (step_out == 1'b1);
assign valid_out = fedp_valid_out && (substep_out == 1'b1);
endmodule
`endif