tensor: Parameterize fedp for fp16/fp32

This commit is contained in:
Hansung Kim
2024-08-12 20:01:56 -07:00
parent 15e93e01d8
commit d39e24643d

View File

@@ -163,6 +163,7 @@ endmodule
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
// see Figure 10(b) of the paper.
module VX_tensor_threadgroup #(
parameter HALF_PRECISION = 1
) (
input clk,
input reset,
@@ -280,23 +281,43 @@ module VX_tensor_threadgroup #(
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
// Dot product (FEDP) unit generated from Chisel
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
.io_in_bits_a_1 (A_frag[d_row][0][31:16]),
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
.io_in_bits_a_3 (A_frag[d_row][1][31:16]),
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
if (HALF_PRECISION != 0) begin
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag[d_row][0][15: 0]),
.io_in_bits_a_1 (A_frag[d_row][0][31:16]),
.io_in_bits_a_2 (A_frag[d_row][1][15: 0]),
.io_in_bits_a_3 (A_frag[d_row][1][31:16]),
.io_in_bits_b_0 (B_frag[0][d_col_sel][15: 0]),
.io_in_bits_b_1 (B_frag[0][d_col_sel][31:16]),
.io_in_bits_b_2 (B_frag[1][d_col_sel][15: 0]),
.io_in_bits_b_3 (B_frag[1][d_col_sel][31:16]),
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end else begin
TensorDotProductUnit fedp (
.clock (clk),
.reset (reset),
.io_in_valid (fedp_fire_in),
.io_in_bits_a_0 (A_frag[d_row][0]),
.io_in_bits_a_1 (A_frag[d_row][1]),
.io_in_bits_a_2 (32'h0),
.io_in_bits_a_3 (32'h0),
.io_in_bits_b_0 (B_frag[0][d_col_sel]),
.io_in_bits_b_1 (B_frag[1][d_col_sel]),
.io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0),
.io_in_bits_c (C_frag[d_row][d_col_sel]),
.io_stall (stall),
.io_out_valid (fedp_valids[i]),
.io_out_bits_data (D_half[i])
);
end
end
assign valid_out = fedp_valid_out && (substep_out == 1'b1);