tensor: Split packed fp16 and wire correctly to DPU
This commit is contained in:
@@ -170,6 +170,7 @@ module VX_tensor_threadgroup #(
|
|||||||
input valid_in,
|
input valid_in,
|
||||||
output ready_in,
|
output ready_in,
|
||||||
input stall,
|
input stall,
|
||||||
|
// all *_frag are row-major
|
||||||
input [1:0][1:0][31:0] A_frag,
|
input [1:0][1:0][31:0] A_frag,
|
||||||
input [1:0][3:0][31:0] B_frag,
|
input [1:0][3:0][31:0] B_frag,
|
||||||
input [1:0][3:0][31:0] C_frag,
|
input [1:0][3:0][31:0] C_frag,
|
||||||
@@ -191,16 +192,18 @@ module VX_tensor_threadgroup #(
|
|||||||
assign fedp_ready_in = fedp_ready_out;
|
assign fedp_ready_in = fedp_ready_out;
|
||||||
|
|
||||||
// The dot product units take 2 cycles to finish computing A_frag * B_frag
|
// The dot product units take 2 cycles to finish computing A_frag * B_frag
|
||||||
// + C_frag. step_in and step_out keeps track of which cycle they're at
|
// + C_frag. substep_in and substep_out keeps track of which cycle they're at
|
||||||
// & when they have to pop from input queue and push to result queue.
|
// & when they have to pop from input queue and push to result queue.
|
||||||
|
// Note that substep is different from the "step" defined in the HMMA
|
||||||
|
// instruction set; it is a purely microarchitectural construct.
|
||||||
//
|
//
|
||||||
// step_in == 0: FEDP uses first half from operand buffer
|
// substep_in == 0: FEDP uses first half from operand buffer
|
||||||
// step_in == 1: FEDP uses last half and pops from operand buffer
|
// substep_in == 1: FEDP uses last half and pops from operand buffer
|
||||||
wire step_in;
|
wire substep_in;
|
||||||
// step_out == 0: FEDP produces first half of D_frag
|
// substep_out == 0: FEDP produces first half of D_frag
|
||||||
// step_out == 1: FEDP produces last half of D_frag and asserts valid_out
|
// substep_out == 1: FEDP produces last half of D_frag and asserts valid_out
|
||||||
wire step_out;
|
wire substep_out;
|
||||||
assign ready_in = fedp_fire_in && (step_in == 1'b1);
|
assign ready_in = fedp_fire_in && (substep_in == 1'b1);
|
||||||
|
|
||||||
wire [3:0][31:0] D_reg;
|
wire [3:0][31:0] D_reg;
|
||||||
logic [3:0][31:0] D_reg_n;
|
logic [3:0][31:0] D_reg_n;
|
||||||
@@ -221,30 +224,30 @@ module VX_tensor_threadgroup #(
|
|||||||
always @(*) begin
|
always @(*) begin
|
||||||
D_reg_n = D_reg;
|
D_reg_n = D_reg;
|
||||||
if (fedp_fire_out) begin
|
if (fedp_fire_out) begin
|
||||||
if (step_out == 1'b0) begin
|
if (substep_out == 1'b0) begin
|
||||||
D_reg_n = D_half;
|
D_reg_n = D_half;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
// flip step_in/step_out on FEDP in/out fire, respectively
|
// flip substep_in/substep_out on FEDP in/out fire, respectively
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.DATAW(1)
|
.DATAW(1)
|
||||||
) staging_step_in (
|
) staging_substep_in (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
.d(~step_in),
|
.d(~substep_in),
|
||||||
.en(fedp_fire_in),
|
.en(fedp_fire_in),
|
||||||
.q(step_in)
|
.q(substep_in)
|
||||||
);
|
);
|
||||||
VX_tensor_reg #(
|
VX_tensor_reg #(
|
||||||
.DATAW(1)
|
.DATAW(1)
|
||||||
) staging_step_out (
|
) staging_substep_out (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
.d(~step_out),
|
.d(~substep_out),
|
||||||
.en(fedp_fire_out),
|
.en(fedp_fire_out),
|
||||||
.q(step_out)
|
.q(substep_out)
|
||||||
);
|
);
|
||||||
|
|
||||||
// TODO: Instead of latching half-result and constructing a full D tile,
|
// TODO: Instead of latching half-result and constructing a full D tile,
|
||||||
@@ -259,31 +262,44 @@ module VX_tensor_threadgroup #(
|
|||||||
assign D_frag[1][1] = D_half[2];
|
assign D_frag[1][1] = D_half[2];
|
||||||
assign D_frag[1][3] = D_half[3];
|
assign D_frag[1][3] = D_half[3];
|
||||||
|
|
||||||
|
wire [1:0][3:0][15:0] A_frag_fp16 = A_frag;
|
||||||
|
wire [3:0][3:0][15:0] B_frag_fp16 = B_frag;
|
||||||
|
|
||||||
// 4 FEDPs per threadgroup
|
// 4 FEDPs per threadgroup
|
||||||
for (genvar i = 0; i < 4; ++i) begin
|
for (genvar i = 0; i < 4; ++i) begin
|
||||||
|
// at substep == 0, the 0th and 2nd columns of D begins compute;
|
||||||
|
// at substep == 1, the 1st and 3rd columns of D begins compute.
|
||||||
|
// there are two row elements for each column, rounding out
|
||||||
|
// 4 elements being computed by 4 FEDPs at every cycle
|
||||||
|
// (see Figure 10(b)).
|
||||||
|
|
||||||
|
// d_row: 0, 0, 1, 1
|
||||||
|
// d_col: 0, 2, 0, 2
|
||||||
localparam int d_row = i / 2;
|
localparam int d_row = i / 2;
|
||||||
localparam int d_col = (i % 2) * 2;
|
localparam int d_col = (i % 2) * 2;
|
||||||
|
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
||||||
|
|
||||||
// Dot product (FEDP) unit generated from Chisel
|
// Dot product (FEDP) unit generated from Chisel
|
||||||
TensorDotProductUnit fedp (
|
TensorDotProductUnit fedp (
|
||||||
.clock (clk),
|
.clock (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.io_in_valid (fedp_fire_in),
|
.io_in_valid (fedp_fire_in),
|
||||||
.io_in_bits_a_0 (A_frag[d_row][0]),
|
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
|
||||||
.io_in_bits_a_1 (A_frag[d_row][1]),
|
.io_in_bits_a_1 (A_frag[d_row][0][31:16]),
|
||||||
.io_in_bits_a_2 (32'h0),
|
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
|
||||||
.io_in_bits_a_3 (32'h0),
|
.io_in_bits_a_3 (A_frag[d_row][1][31:16]),
|
||||||
.io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]),
|
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
|
||||||
.io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]),
|
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
|
||||||
.io_in_bits_b_2 (32'h0),
|
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
|
||||||
.io_in_bits_b_3 (32'h0),
|
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
|
||||||
.io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]),
|
.io_in_bits_c (C_frag[d_row][d_col_sel]),
|
||||||
.io_stall (stall),
|
.io_stall (stall),
|
||||||
.io_out_valid (fedp_valids[i]),
|
.io_out_valid (fedp_valids[i]),
|
||||||
.io_out_bits_data (D_half[i])
|
.io_out_bits_data (D_half[i])
|
||||||
);
|
);
|
||||||
end
|
end
|
||||||
|
|
||||||
assign valid_out = fedp_valid_out && (step_out == 1'b1);
|
assign valid_out = fedp_valid_out && (substep_out == 1'b1);
|
||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
`endif
|
`endif
|
||||||
|
|||||||
Reference in New Issue
Block a user