tensor: Rename halves_buf to reduce confusion
This commit is contained in:
@@ -425,9 +425,7 @@ module VX_tensor_octet #(
|
|||||||
endfunction
|
endfunction
|
||||||
|
|
||||||
half_t halves;
|
half_t halves;
|
||||||
half_t halves_buf;
|
assign halves = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
|
||||||
assign halves = get_operand_half(operands_step, A_in, B_in, C_in);
|
|
||||||
assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
|
|
||||||
|
|
||||||
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
|
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
|
||||||
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
// wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
||||||
@@ -454,9 +452,9 @@ module VX_tensor_octet #(
|
|||||||
if (operands_first_in_pair_fire) begin
|
if (operands_first_in_pair_fire) begin
|
||||||
// NOTE: substeps is only used for debugging
|
// NOTE: substeps is only used for debugging
|
||||||
substeps_n[operands_wid_buf] = 1'b1; // ready for hmma
|
substeps_n[operands_wid_buf] = 1'b1; // ready for hmma
|
||||||
A_buffer_n[operands_wid_buf] = halves_buf.A_half;
|
A_buffer_n[operands_wid_buf] = halves.A_half;
|
||||||
B_buffer_n[operands_wid_buf] = halves_buf.B_half;
|
B_buffer_n[operands_wid_buf] = halves.B_half;
|
||||||
C_buffer_n[operands_wid_buf] = halves_buf.C_half;
|
C_buffer_n[operands_wid_buf] = halves.C_half;
|
||||||
end
|
end
|
||||||
if (do_hmma) begin
|
if (do_hmma) begin
|
||||||
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
|
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
|
||||||
@@ -478,28 +476,32 @@ module VX_tensor_octet #(
|
|||||||
assign operands_ready_buf = hmma_ready;
|
assign operands_ready_buf = hmma_ready;
|
||||||
|
|
||||||
// all *_tiles below are row-major
|
// all *_tiles below are row-major
|
||||||
// A is a 4x2 fp32 matrix
|
// A is a 4x2 fp32 matrix; rows 0-1 for one threadgroup, rows 2-3 for the
|
||||||
|
// other. The two columns (along k) are shared between the threadgroups.
|
||||||
|
// Buffered data are combined with the current data along the K dimension.
|
||||||
|
// See Figure 10(b) of the paper.
|
||||||
wire [3:0][1:0][31:0] A_tile = {
|
wire [3:0][1:0][31:0] A_tile = {
|
||||||
{ halves_buf.A_half[3], A_buffer[operands_wid_buf][3] },
|
{ halves.A_half[3], A_buffer[operands_wid_buf][3] },
|
||||||
{ halves_buf.A_half[2], A_buffer[operands_wid_buf][2] },
|
{ halves.A_half[2], A_buffer[operands_wid_buf][2] },
|
||||||
{ halves_buf.A_half[1], A_buffer[operands_wid_buf][1] },
|
{ halves.A_half[1], A_buffer[operands_wid_buf][1] },
|
||||||
{ halves_buf.A_half[0], A_buffer[operands_wid_buf][0] }
|
{ halves.A_half[0], A_buffer[operands_wid_buf][0] }
|
||||||
};
|
};
|
||||||
// B is a 2x4 fp32 matrix
|
// B is a 2x4 fp32 matrix, shared between the two threadgroups
|
||||||
wire [1:0][3:0][31:0] B_tile = {
|
wire [1:0][3:0][31:0] B_tile = {
|
||||||
halves_buf.B_half,
|
halves.B_half,
|
||||||
B_buffer[operands_wid_buf]
|
B_buffer[operands_wid_buf]
|
||||||
};
|
};
|
||||||
// C is a 4x4 fp32 matrix
|
// C is a 4x4 fp32 matrix; rows 0-1 for one threadgroup, rows 2-3 for the
|
||||||
|
// other
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
wire [3:0][3:0][31:0] D_tile;
|
wire [3:0][3:0][31:0] D_tile;
|
||||||
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
C_tile[3] = { halves.C_half[7], C_buffer[operands_wid_buf][7], halves.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||||
C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] };
|
C_tile[2] = { halves.C_half[6], C_buffer[operands_wid_buf][6], halves.C_half[4], C_buffer[operands_wid_buf][4] };
|
||||||
C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] };
|
C_tile[1] = { halves.C_half[3], C_buffer[operands_wid_buf][3], halves.C_half[1], C_buffer[operands_wid_buf][1] };
|
||||||
C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] };
|
C_tile[0] = { halves.C_half[2], C_buffer[operands_wid_buf][2], halves.C_half[0], C_buffer[operands_wid_buf][0] };
|
||||||
end
|
end
|
||||||
|
|
||||||
wire dpu_valid;
|
wire dpu_valid;
|
||||||
|
|||||||
@@ -163,7 +163,7 @@ endmodule
|
|||||||
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
// does (m,n,k) = (2,4,2) matmul compute over 2 cycles.
|
||||||
// see Figure 10(b) of the paper.
|
// see Figure 10(b) of the paper.
|
||||||
module VX_tensor_threadgroup #(
|
module VX_tensor_threadgroup #(
|
||||||
parameter HALF_PRECISION = 1
|
parameter HALF_PRECISION = 0
|
||||||
) (
|
) (
|
||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
|
|||||||
Reference in New Issue
Block a user