From 15e93e01d82a7c4f5a5889818327f85d9b39a5cf Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Wed, 7 Aug 2024 11:16:38 -0700 Subject: [PATCH] tensor: Split packed fp16 and wire correctly to DPU --- hw/rtl/fpu/VX_tensor_dpu.sv | 68 +++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 26 deletions(-) diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index c196ffec..eb974633 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -170,6 +170,7 @@ module VX_tensor_threadgroup #( input valid_in, output ready_in, input stall, + // all *_frag are row-major input [1:0][1:0][31:0] A_frag, input [1:0][3:0][31:0] B_frag, input [1:0][3:0][31:0] C_frag, @@ -191,16 +192,18 @@ module VX_tensor_threadgroup #( assign fedp_ready_in = fedp_ready_out; // The dot product units take 2 cycles to finish computing A_frag * B_frag - // + C_frag. step_in and step_out keeps track of which cycle they're at + // + C_frag. substep_in and substep_out keeps track of which cycle they're at // & when they have to pop from input queue and push to result queue. + // Note that substep is different from the "step" defined in the HMMA + // instruction set; it is a purely microarchitectural construct. // - // step_in == 0: FEDP uses first half from operand buffer - // step_in == 1: FEDP uses last half and pops from operand buffer - wire step_in; - // step_out == 0: FEDP produces first half of D_frag - // step_out == 1: FEDP produces last half of D_frag and asserts valid_out - wire step_out; - assign ready_in = fedp_fire_in && (step_in == 1'b1); + // substep_in == 0: FEDP uses first half from operand buffer + // substep_in == 1: FEDP uses last half and pops from operand buffer + wire substep_in; + // substep_out == 0: FEDP produces first half of D_frag + // substep_out == 1: FEDP produces last half of D_frag and asserts valid_out + wire substep_out; + assign ready_in = fedp_fire_in && (substep_in == 1'b1); wire [3:0][31:0] D_reg; logic [3:0][31:0] D_reg_n; @@ -221,30 +224,30 @@ module VX_tensor_threadgroup #( always @(*) begin D_reg_n = D_reg; if (fedp_fire_out) begin - if (step_out == 1'b0) begin + if (substep_out == 1'b0) begin D_reg_n = D_half; end end end - // flip step_in/step_out on FEDP in/out fire, respectively + // flip substep_in/substep_out on FEDP in/out fire, respectively VX_tensor_reg #( .DATAW(1) - ) staging_step_in ( + ) staging_substep_in ( .clk(clk), .reset(reset), - .d(~step_in), + .d(~substep_in), .en(fedp_fire_in), - .q(step_in) + .q(substep_in) ); VX_tensor_reg #( .DATAW(1) - ) staging_step_out ( + ) staging_substep_out ( .clk(clk), .reset(reset), - .d(~step_out), + .d(~substep_out), .en(fedp_fire_out), - .q(step_out) + .q(substep_out) ); // TODO: Instead of latching half-result and constructing a full D tile, @@ -259,31 +262,44 @@ module VX_tensor_threadgroup #( assign D_frag[1][1] = D_half[2]; assign D_frag[1][3] = D_half[3]; + wire [1:0][3:0][15:0] A_frag_fp16 = A_frag; + wire [3:0][3:0][15:0] B_frag_fp16 = B_frag; + // 4 FEDPs per threadgroup for (genvar i = 0; i < 4; ++i) begin + // at substep == 0, the 0th and 2nd columns of D begins compute; + // at substep == 1, the 1st and 3rd columns of D begins compute. + // there are two row elements for each column, rounding out + // 4 elements being computed by 4 FEDPs at every cycle + // (see Figure 10(b)). + + // d_row: 0, 0, 1, 1 + // d_col: 0, 2, 0, 2 localparam int d_row = i / 2; localparam int d_col = (i % 2) * 2; + wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1); + // Dot product (FEDP) unit generated from Chisel TensorDotProductUnit fedp ( .clock (clk), .reset (reset), .io_in_valid (fedp_fire_in), - .io_in_bits_a_0 (A_frag[d_row][0]), - .io_in_bits_a_1 (A_frag[d_row][1]), - .io_in_bits_a_2 (32'h0), - .io_in_bits_a_3 (32'h0), - .io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]), - .io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]), - .io_in_bits_b_2 (32'h0), - .io_in_bits_b_3 (32'h0), - .io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]), + .io_in_bits_a_0 (A_frag[d_row][0][15: 0]), + .io_in_bits_a_1 (A_frag[d_row][0][31:16]), + .io_in_bits_a_2 (A_frag[d_row][1][15: 0]), + .io_in_bits_a_3 (A_frag[d_row][1][31:16]), + .io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]), + .io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]), + .io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]), + .io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]), + .io_in_bits_c (C_frag[d_row][d_col_sel]), .io_stall (stall), .io_out_valid (fedp_valids[i]), .io_out_bits_data (D_half[i]) ); end - assign valid_out = fedp_valid_out && (step_out == 1'b1); + assign valid_out = fedp_valid_out && (substep_out == 1'b1); endmodule `endif