diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index ea97a361..cf5a0071 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -75,8 +75,9 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #( ); localparam NUM_OCTETS = (`NUM_THREADS / 8); // offet in the lane numbers that get mapped to the two threadgroups in an - // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16 - // FIXME: not sure this is the right logic. just filling in what works + // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> + // LANE_OFFSET_THREADGROUP = 16 + // FIXME: check logic; only verified for single octet localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); // this is only a rule of thumb localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA; @@ -147,6 +148,10 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #( // each octet produces 4x4 output partial sum, but the 8 lanes mapped // to the octet can only do 8 fp32 writeback at a time; so we need to // split writeback over two cycles + // + // octet_D matches the mathematical layout of the matrix (4x4 output + // per octet). The logic below replicates the jagged 1x2 mapping in + // Figure 7(b) to map values to the lanes. assign wb_data_0[4*i+0] = octet_D[0][0]; assign wb_data_0[4*i+1] = octet_D[1][0]; assign wb_data_0[4*i+2] = octet_D[0][2]; @@ -511,7 +516,7 @@ module VX_tensor_octet #( wire dpu_valid; // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet - VX_tensor_dpu #( + VX_tensor_threadgroups #( .ISW(ISW), .OCTET(OCTET), .OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/) @@ -581,14 +586,14 @@ module VX_tensor_octet #( end `ifdef PERF_ENABLE - logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total; + logic [`PERF_CTR_BITS-1:0] perf_tensor_ops_total; always @(posedge clk) begin if (reset) begin - perf_tensor_dpu_total <= '0; + perf_tensor_ops_total <= '0; end else begin if (do_hmma) begin - perf_tensor_dpu_total <= perf_tensor_dpu_total + 2'd2; + perf_tensor_ops_total <= perf_tensor_ops_total + 2'd2; end end end diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 735aec87..ad3086d3 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -1,7 +1,8 @@ `ifdef EXT_T_ENABLE `include "VX_fpu_define.vh" -module VX_tensor_dpu #( +// Module that contains the threadgroups with DPUs + operand buffer. +module VX_tensor_threadgroups #( parameter ISW, parameter OCTET, // @perf: has big impact on throughput. A rule of thumb is to set it to @@ -15,6 +16,7 @@ module VX_tensor_dpu #( input valid_in, output ready_in, // [rows][cols][dtype] + // (m,n,k) = (4,4,2) input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, input [3:0][3:0][31:0] C_tile, @@ -172,6 +174,7 @@ module VX_tensor_threadgroup #( output ready_in, input stall, // all *_frag are row-major + // (m,n,k) = (2,4,2) input [1:0][1:0][31:0] A_frag, input [1:0][3:0][31:0] B_frag, input [1:0][3:0][31:0] C_frag, @@ -269,8 +272,11 @@ module VX_tensor_threadgroup #( // 4 FEDPs per threadgroup for (genvar i = 0; i < 4; ++i) begin - // at substep == 0, the 0th and 2nd columns of D begins compute; - // at substep == 1, the 1st and 3rd columns of D begins compute. + // Determine which elements in the D matrix the dot-product units get + // mapped to. + // + // At substep == 0, the 0th and 2nd columns of D begins compute; + // At substep == 1, the 1st and 3rd columns of D begins compute. // There are two row elements for each column, rounding out to // 4 elements computed by 4 FEDPs at every cycle // (see Figure 10(b)).