From d39e24643d807f0b106136ca91c35eda987d1ba9 Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Mon, 12 Aug 2024 20:01:56 -0700 Subject: [PATCH] tensor: Parameterize fedp for fp16/fp32 --- hw/rtl/fpu/VX_tensor_dpu.sv | 55 +++++++++++++++++++++++++------------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index eb974633..0504f457 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -163,6 +163,7 @@ endmodule // does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // see Figure 10(b) of the paper. module VX_tensor_threadgroup #( + parameter HALF_PRECISION = 1 ) ( input clk, input reset, @@ -280,23 +281,43 @@ module VX_tensor_threadgroup #( wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1); // Dot product (FEDP) unit generated from Chisel - TensorDotProductUnit fedp ( - .clock (clk), - .reset (reset), - .io_in_valid (fedp_fire_in), - .io_in_bits_a_0 (A_frag[d_row][0][15: 0]), - .io_in_bits_a_1 (A_frag[d_row][0][31:16]), - .io_in_bits_a_2 (A_frag[d_row][1][15: 0]), - .io_in_bits_a_3 (A_frag[d_row][1][31:16]), - .io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]), - .io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]), - .io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]), - .io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]), - .io_in_bits_c (C_frag[d_row][d_col_sel]), - .io_stall (stall), - .io_out_valid (fedp_valids[i]), - .io_out_bits_data (D_half[i]) - ); + if (HALF_PRECISION != 0) begin + TensorDotProductUnit fedp ( + .clock (clk), + .reset (reset), + .io_in_valid (fedp_fire_in), + .io_in_bits_a_0 (A_frag[d_row][0][15: 0]), + .io_in_bits_a_1 (A_frag[d_row][0][31:16]), + .io_in_bits_a_2 (A_frag[d_row][1][15: 0]), + .io_in_bits_a_3 (A_frag[d_row][1][31:16]), + .io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]), + .io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]), + .io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]), + .io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]), + .io_in_bits_c (C_frag[d_row][d_col_sel]), + .io_stall (stall), + .io_out_valid (fedp_valids[i]), + .io_out_bits_data (D_half[i]) + ); + end else begin + TensorDotProductUnit fedp ( + .clock (clk), + .reset (reset), + .io_in_valid (fedp_fire_in), + .io_in_bits_a_0 (A_frag[d_row][0]), + .io_in_bits_a_1 (A_frag[d_row][1]), + .io_in_bits_a_2 (32'h0), + .io_in_bits_a_3 (32'h0), + .io_in_bits_b_0 (B_frag[0][d_col_sel]), + .io_in_bits_b_1 (B_frag[1][d_col_sel]), + .io_in_bits_b_2 (32'h0), + .io_in_bits_b_3 (32'h0), + .io_in_bits_c (C_frag[d_row][d_col_sel]), + .io_stall (stall), + .io_out_valid (fedp_valids[i]), + .io_out_bits_data (D_half[i]) + ); + end end assign valid_out = fedp_valid_out && (substep_out == 1'b1);