tensor: Parameterize fedp for fp16/fp32

This commit is contained in:
Hansung Kim
2024-08-12 20:01:56 -07:00
parent 15e93e01d8
commit d39e24643d

View File

@@ -163,6 +163,7 @@ endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// see Figure 10(b) of the paper. // see Figure 10(b) of the paper.
module VX_tensor_threadgroup #( module VX_tensor_threadgroup #(
parameter HALF_PRECISION = 1
) ( ) (
input clk, input clk,
input reset, input reset,
@@ -280,23 +281,43 @@ module VX_tensor_threadgroup #(
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1); wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
// Dot product (FEDP) unit generated from Chisel // Dot product (FEDP) unit generated from Chisel
TensorDotProductUnit fedp ( if (HALF_PRECISION != 0) begin
.clock (clk), TensorDotProductUnit fedp (
.reset (reset), .clock (clk),
.io_in_valid (fedp_fire_in), .reset (reset),
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]), .io_in_valid (fedp_fire_in),
.io_in_bits_a_1 (A_frag[d_row][0][31:16]), .io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]), .io_in_bits_a_1 (A_frag[d_row][0][31:16]),
.io_in_bits_a_3 (A_frag[d_row][1][31:16]), .io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]), .io_in_bits_a_3 (A_frag[d_row][1][31:16]),
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]), .io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]), .io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]), .io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
.io_in_bits_c (C_frag[d_row][d_col_sel]), .io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
.io_stall (stall), .io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_out_valid (fedp_valids[i]), .io_stall (stall),
.io_out_bits_data (D_half[i]) .io_out_valid (fedp_valids[i]),
); .io_out_bits_data (D_half[i])
);
end else begin
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag[d_row][0]),
.io_in_bits_a_1 (A_frag[d_row][1]),
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
.io_in_bits_b_0 (B_frag[0][d_col_sel]),
.io_in_bits_b_1 (B_frag[1][d_col_sel]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end
end end
assign valid_out = fedp_valid_out && (substep_out == 1'b1); assign valid_out = fedp_valid_out && (substep_out == 1'b1);