tensor: Split packed fp16 and wire correctly to DPU
This commit is contained in:
@@ -170,6 +170,7 @@ module VX_tensor_threadgroup #(
|
||||
input valid_in,
|
||||
output ready_in,
|
||||
input stall,
|
||||
// all *_frag are row-major
|
||||
input [1:0][1:0][31:0] A_frag,
|
||||
input [1:0][3:0][31:0] B_frag,
|
||||
input [1:0][3:0][31:0] C_frag,
|
||||
@@ -191,16 +192,18 @@ module VX_tensor_threadgroup #(
|
||||
assign fedp_ready_in = fedp_ready_out;
|
||||
|
||||
// The dot product units take 2 cycles to finish computing A_frag * B_frag
|
||||
// + C_frag. step_in and step_out keeps track of which cycle they're at
|
||||
// + C_frag. substep_in and substep_out keep track of which cycle they're at
|
||||
// & when they have to pop from input queue and push to result queue.
|
||||
// Note that substep is different from the "step" defined in the HMMA
|
||||
// instruction set; it is a purely microarchitectural construct.
|
||||
//
|
||||
// step_in == 0: FEDP uses first half from operand buffer
|
||||
// step_in == 1: FEDP uses last half and pops from operand buffer
|
||||
wire step_in;
|
||||
// step_out == 0: FEDP produces first half of D_frag
|
||||
// step_out == 1: FEDP produces last half of D_frag and asserts valid_out
|
||||
wire step_out;
|
||||
assign ready_in = fedp_fire_in && (step_in == 1'b1);
|
||||
// substep_in == 0: FEDP uses first half from operand buffer
|
||||
// substep_in == 1: FEDP uses last half and pops from operand buffer
|
||||
wire substep_in;
|
||||
// substep_out == 0: FEDP produces first half of D_frag
|
||||
// substep_out == 1: FEDP produces last half of D_frag and asserts valid_out
|
||||
wire substep_out;
|
||||
assign ready_in = fedp_fire_in && (substep_in == 1'b1);
|
||||
|
||||
wire [3:0][31:0] D_reg;
|
||||
logic [3:0][31:0] D_reg_n;
|
||||
@@ -221,30 +224,30 @@ module VX_tensor_threadgroup #(
|
||||
always @(*) begin
|
||||
D_reg_n = D_reg;
|
||||
if (fedp_fire_out) begin
|
||||
if (step_out == 1'b0) begin
|
||||
if (substep_out == 1'b0) begin
|
||||
D_reg_n = D_half;
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
// flip step_in/step_out on FEDP in/out fire, respectively
|
||||
// flip substep_in/substep_out on FEDP in/out fire, respectively
|
||||
VX_tensor_reg #(
|
||||
.DATAW(1)
|
||||
) staging_step_in (
|
||||
) staging_substep_in (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(~step_in),
|
||||
.d(~substep_in),
|
||||
.en(fedp_fire_in),
|
||||
.q(step_in)
|
||||
.q(substep_in)
|
||||
);
|
||||
VX_tensor_reg #(
|
||||
.DATAW(1)
|
||||
) staging_step_out (
|
||||
) staging_substep_out (
|
||||
.clk(clk),
|
||||
.reset(reset),
|
||||
.d(~step_out),
|
||||
.d(~substep_out),
|
||||
.en(fedp_fire_out),
|
||||
.q(step_out)
|
||||
.q(substep_out)
|
||||
);
|
||||
|
||||
// TODO: Instead of latching half-result and constructing a full D tile,
|
||||
@@ -259,31 +262,44 @@ module VX_tensor_threadgroup #(
|
||||
assign D_frag[1][1] = D_half[2];
|
||||
assign D_frag[1][3] = D_half[3];
|
||||
|
||||
wire [1:0][3:0][15:0] A_frag_fp16 = A_frag;
|
||||
wire [3:0][3:0][15:0] B_frag_fp16 = B_frag;
|
||||
|
||||
// 4 FEDPs per threadgroup
|
||||
for (genvar i = 0; i < 4; ++i) begin
|
||||
// at substep == 0, the 0th and 2nd columns of D begin computing;
|
||||
// at substep == 1, the 1st and 3rd columns of D begin computing.
|
||||
// there are two row elements for each column, rounding out
|
||||
// 4 elements being computed by 4 FEDPs at every cycle
|
||||
// (see Figure 10(b)).
|
||||
|
||||
// d_row: 0, 0, 1, 1
|
||||
// d_col: 0, 2, 0, 2
|
||||
localparam int d_row = i / 2;
|
||||
localparam int d_col = (i % 2) * 2;
|
||||
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
||||
|
||||
// Dot product (FEDP) unit generated from Chisel
|
||||
TensorDotProductUnit fedp (
|
||||
.clock (clk),
|
||||
.reset (reset),
|
||||
.io_in_valid (fedp_fire_in),
|
||||
.io_in_bits_a_0 (A_frag[d_row][0]),
|
||||
.io_in_bits_a_1 (A_frag[d_row][1]),
|
||||
.io_in_bits_a_2 (32'h0),
|
||||
.io_in_bits_a_3 (32'h0),
|
||||
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]),
|
||||
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]),
|
||||
.io_in_bits_b_2 (32'h0),
|
||||
.io_in_bits_b_3 (32'h0),
|
||||
.io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]),
|
||||
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
|
||||
.io_in_bits_a_1 (A_frag[d_row][0][31:16]),
|
||||
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
|
||||
.io_in_bits_a_3 (A_frag[d_row][1][31:16]),
|
||||
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
|
||||
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
|
||||
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
|
||||
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
|
||||
.io_in_bits_c (C_frag[d_row][d_col_sel]),
|
||||
.io_stall (stall),
|
||||
.io_out_valid (fedp_valids[i]),
|
||||
.io_out_bits_data (D_half[i])
|
||||
);
|
||||
end
|
||||
|
||||
assign valid_out = fedp_valid_out && (step_out == 1'b1);
|
||||
assign valid_out = fedp_valid_out && (substep_out == 1'b1);
|
||||
endmodule
|
||||
|
||||
`endif
|
||||
|
||||
Reference in New Issue
Block a user