tensor: Parameterize fedp for fp16/fp32

This commit is contained in:
Hansung Kim
2024-08-12 20:01:56 -07:00
parent 15e93e01d8
commit d39e24643d

View File

@@ -163,6 +163,7 @@ endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles. // does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// see Figure 10(b) of the paper. // see Figure 10(b) of the paper.
module VX_tensor_threadgroup #( module VX_tensor_threadgroup #(
parameter HALF_PRECISION = 1
) ( ) (
input clk, input clk,
input reset, input reset,
@@ -280,6 +281,7 @@ module VX_tensor_threadgroup #(
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1); wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
// Dot product (FEDP) unit generated from Chisel // Dot product (FEDP) unit generated from Chisel
if (HALF_PRECISION != 0) begin
TensorDotProductUnit fedp ( TensorDotProductUnit fedp (
.clock (clk), .clock (clk),
.reset (reset), .reset (reset),
@@ -297,6 +299,25 @@ module VX_tensor_threadgroup #(
.io_out_valid (fedp_valids[i]), .io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i]) .io_out_bits_data (D_half[i])
); );
end else begin
// Dot-product unit instance used when HALF_PRECISION == 0 (the fp32 path).
// Only two operand pairs are driven here (a_0/a_1 against b_0/b_1); the
// a_2/a_3 and b_2/b_3 inputs are tied to zero.
// NOTE(review): this instantiates the same TensorDotProductUnit module name
// as the HALF_PRECISION != 0 branch -- confirm the generated unit handles
// fp32 operands or that a separate fp32 module should be referenced here.
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
// A operands: elements 0 and 1 of the A fragment row selected by d_row.
.io_in_bits_a_0 (A_frag[d_row][0]),
.io_in_bits_a_1 (A_frag[d_row][1]),
// Unused operand lanes are held at zero (presumably so they contribute
// nothing to the accumulated sum -- TODO confirm against the FEDP unit).
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
// B operands: elements 0 and 1 along the first index of the B fragment, in
// the column chosen by d_col_sel (d_col or d_col + 1 depending on substep_in).
.io_in_bits_b_0 (B_frag[0][d_col_sel]),
.io_in_bits_b_1 (B_frag[1][d_col_sel]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
// Addend input taken from the C fragment at (d_row, d_col_sel).
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
// NOTE(review): the result is written to D_half[i], the same signal the
// half-precision branch drives -- confirm the name/width suits fp32 output.
.io_out_bits_data (D_half[i])
);
end
end end
assign valid_out = fedp_valid_out && (substep_out == 1'b1); assign valid_out = fedp_valid_out && (substep_out == 1'b1);