diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index e9976085..ea97a361 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -486,7 +486,8 @@ module VX_tensor_octet #( { halves.A_half[1], A_buffer[operands_wid_buf][1] }, { halves.A_half[0], A_buffer[operands_wid_buf][0] } }; - // B is a 2x4 fp32 matrix, shared between the two threadgroups + // B is a 2x4 fp32 matrix, shared between the two threadgroups. + // The two rows (along k) are combined between buffered and current data. wire [1:0][3:0][31:0] B_tile = { halves.B_half, B_buffer[operands_wid_buf] @@ -497,6 +498,9 @@ module VX_tensor_octet #( wire [3:0][3:0][31:0] D_tile; wire [`NW_WIDTH-1:0] D_wid_dpu; + // C follows the 1x2 "jagged" mapping in Figure 7(b). + // Buffered data are combined with the current data along the rows, + // forming a 1x2 block for each lane. always @(*) begin C_tile[3] = { halves.C_half[7], C_buffer[operands_wid_buf][7], halves.C_half[5], C_buffer[operands_wid_buf][5] }; C_tile[2] = { halves.C_half[6], C_buffer[operands_wid_buf][6], halves.C_half[4], C_buffer[operands_wid_buf][4] }; diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 0d8059ba..735aec87 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -97,8 +97,8 @@ module VX_tensor_dpu #( ); // Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two - // threadgroups; B_tile is shared across the two threadgroups. See Figure - // 13 in paper + // threadgroup DPUs; B_tile is shared across the two threadgroups. See + // Figure 13 in paper VX_tensor_threadgroup #( ) threadgroup_0 ( .clk (clk), @@ -196,7 +196,8 @@ module VX_tensor_threadgroup #( // + C_frag. substep_in and substep_out keeps track of which cycle they're at // & when they have to pop from input queue and push to result queue. 
// Note that substep is different from the "step" defined in the HMMA - // instruction set; it is a purely microarchitectural construct. + // instruction set; it is similar in meaning to the substeps in + // VX_tensor_octet. // // substep_in == 0: FEDP uses first half from operand buffer // substep_in == 1: FEDP uses last half and pops from operand buffer @@ -270,12 +271,21 @@ module VX_tensor_threadgroup #( for (genvar i = 0; i < 4; ++i) begin // at substep == 0, the 0th and 2nd columns of D begins compute; // at substep == 1, the 1st and 3rd columns of D begins compute. - // there are two row elements for each column, rounding out - // 4 elements being computed by 4 FEDPs at every cycle + // There are two row elements for each column, rounding out to + // 4 elements computed by 4 FEDPs at every cycle // (see Figure 10(b)). - // d_row: 0, 0, 1, 1 - // d_col: 0, 2, 0, 2 + // i : 0, 1, 2, 3 + // d_row : 0, 0, 1, 1 + // d_col : 0, 2, 0, 2 + // d_col_sel: 1, 3, 1, 3 (at substep 1; equals d_col at substep 0) + // + // substep 0: + // [ 0 x 1 x ] + // [ 2 x 3 x ] + // substep 1: + // [ x 0 x 1 ] + // [ x 2 x 3 ] localparam int d_row = i / 2; localparam int d_col = (i % 2) * 2; wire [31:0] d_col_sel = (substep_in == 1'b0) ? d_col : (d_col + 1);