tensor: Doc comments
This commit is contained in:
@@ -486,7 +486,8 @@ module VX_tensor_octet #(
|
|||||||
{ halves.A_half[1], A_buffer[operands_wid_buf][1] },
|
{ halves.A_half[1], A_buffer[operands_wid_buf][1] },
|
||||||
{ halves.A_half[0], A_buffer[operands_wid_buf][0] }
|
{ halves.A_half[0], A_buffer[operands_wid_buf][0] }
|
||||||
};
|
};
|
||||||
// B is a 2x4 fp32 matrix, shared between the two threadgroups
|
// B is a 2x4 fp32 matrix, shared between the two threadgroups.
|
||||||
|
// The two rows (along k) are combined between buffered and current data.
|
||||||
wire [1:0][3:0][31:0] B_tile = {
|
wire [1:0][3:0][31:0] B_tile = {
|
||||||
halves.B_half,
|
halves.B_half,
|
||||||
B_buffer[operands_wid_buf]
|
B_buffer[operands_wid_buf]
|
||||||
@@ -497,6 +498,9 @@ module VX_tensor_octet #(
|
|||||||
wire [3:0][3:0][31:0] D_tile;
|
wire [3:0][3:0][31:0] D_tile;
|
||||||
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
wire [`NW_WIDTH-1:0] D_wid_dpu;
|
||||||
|
|
||||||
|
// C follows the 1x2 "jagged" mapping in Figure 7(b).
|
||||||
|
// Buffered data are combined with the current data along the rows,
|
||||||
|
// forming a 1x2 block for each lane.
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
C_tile[3] = { halves.C_half[7], C_buffer[operands_wid_buf][7], halves.C_half[5], C_buffer[operands_wid_buf][5] };
|
C_tile[3] = { halves.C_half[7], C_buffer[operands_wid_buf][7], halves.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||||
C_tile[2] = { halves.C_half[6], C_buffer[operands_wid_buf][6], halves.C_half[4], C_buffer[operands_wid_buf][4] };
|
C_tile[2] = { halves.C_half[6], C_buffer[operands_wid_buf][6], halves.C_half[4], C_buffer[operands_wid_buf][4] };
|
||||||
|
|||||||
@@ -97,8 +97,8 @@ module VX_tensor_dpu #(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two
|
// Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two
|
||||||
// threadgroups; B_tile is shared across the two threadgroups. See Figure
|
// threadgroup DPUs; B_tile is shared across the two threadgroups. See
|
||||||
// 13 in paper
|
// Figure 13 in paper
|
||||||
VX_tensor_threadgroup #(
|
VX_tensor_threadgroup #(
|
||||||
) threadgroup_0 (
|
) threadgroup_0 (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
@@ -196,7 +196,8 @@ module VX_tensor_threadgroup #(
|
|||||||
// + C_frag. substep_in and substep_out keep track of which cycle they're at
|
// + C_frag. substep_in and substep_out keep track of which cycle they're at
|
||||||
// & when they have to pop from input queue and push to result queue.
|
// & when they have to pop from input queue and push to result queue.
|
||||||
// Note that substep is different from the "step" defined in the HMMA
|
// Note that substep is different from the "step" defined in the HMMA
|
||||||
// instruction set; it is a purely microarchitectural construct.
|
// instruction set; it is similar in meaning to the substeps in
|
||||||
|
// VX_tensor_octet.
|
||||||
//
|
//
|
||||||
// substep_in == 0: FEDP uses first half from operand buffer
|
// substep_in == 0: FEDP uses first half from operand buffer
|
||||||
// substep_in == 1: FEDP uses last half and pops from operand buffer
|
// substep_in == 1: FEDP uses last half and pops from operand buffer
|
||||||
@@ -270,12 +271,21 @@ module VX_tensor_threadgroup #(
|
|||||||
for (genvar i = 0; i < 4; ++i) begin
|
for (genvar i = 0; i < 4; ++i) begin
|
||||||
// at substep == 0, the 0th and 2nd columns of D begin compute;
|
// at substep == 0, the 0th and 2nd columns of D begin compute;
|
||||||
// at substep == 1, the 1st and 3rd columns of D begin compute.
|
// at substep == 1, the 1st and 3rd columns of D begin compute.
|
||||||
// there are two row elements for each column, rounding out
|
// There are two row elements for each column, rounding out to
|
||||||
// 4 elements being computed by 4 FEDPs at every cycle
|
// 4 elements computed by 4 FEDPs at every cycle
|
||||||
// (see Figure 10(b)).
|
// (see Figure 10(b)).
|
||||||
|
|
||||||
// d_row: 0, 0, 1, 1
|
// i : 0, 1, 2, 3
|
||||||
// d_col: 0, 2, 0, 2
|
// d_row : 0, 0, 1, 1
|
||||||
|
// d_col : 0, 2, 0, 2
|
||||||
|
// d_col_sel: 1, 3, 1, 3 (at substep 1; equals d_col at substep 0)
|
||||||
|
//
|
||||||
|
// substep 0:
|
||||||
|
// [ 0 x 2 x ]
|
||||||
|
// [ 1 x 3 x ]
|
||||||
|
// substep 1:
|
||||||
|
// [ x 0 x 2 ]
|
||||||
|
// [ x 1 x 3 ]
|
||||||
localparam int d_row = i / 2;
|
localparam int d_row = i / 2;
|
||||||
localparam int d_col = (i % 2) * 2;
|
localparam int d_col = (i % 2) * 2;
|
||||||
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);
|
||||||
|
|||||||
Reference in New Issue
Block a user