tensor: Rename & docs
This commit is contained in:
@@ -75,8 +75,9 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
|||||||
);
|
);
|
||||||
localparam NUM_OCTETS = (`NUM_THREADS / 8);
|
localparam NUM_OCTETS = (`NUM_THREADS / 8);
|
||||||
// offet in the lane numbers that get mapped to the two threadgroups in an
|
// offet in the lane numbers that get mapped to the two threadgroups in an
|
||||||
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
|
// octet. E.g. two tgs map lane 0-3 and lane 16-19 ->
|
||||||
// FIXME: not sure this is the right logic. just filling in what works
|
// LANE_OFFSET_THREADGROUP = 16
|
||||||
|
// FIXME: check logic; only verified for single octet
|
||||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||||
// this is only a rule of thumb
|
// this is only a rule of thumb
|
||||||
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA;
|
||||||
@@ -147,6 +148,10 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #(
|
|||||||
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
|
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
|
||||||
// to the octet can only do 8 fp32 writeback at a time; so we need to
|
// to the octet can only do 8 fp32 writeback at a time; so we need to
|
||||||
// split writeback over two cycles
|
// split writeback over two cycles
|
||||||
|
//
|
||||||
|
// octet_D matches the mathematical layout of the matrix (4x4 output
|
||||||
|
// per octet). The logic below replicates the jagged 1x2 mapping in
|
||||||
|
// Figure 7(b) to map values to the lanes.
|
||||||
assign wb_data_0[4*i+0] = octet_D[0][0];
|
assign wb_data_0[4*i+0] = octet_D[0][0];
|
||||||
assign wb_data_0[4*i+1] = octet_D[1][0];
|
assign wb_data_0[4*i+1] = octet_D[1][0];
|
||||||
assign wb_data_0[4*i+2] = octet_D[0][2];
|
assign wb_data_0[4*i+2] = octet_D[0][2];
|
||||||
@@ -511,7 +516,7 @@ module VX_tensor_octet #(
|
|||||||
wire dpu_valid;
|
wire dpu_valid;
|
||||||
|
|
||||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||||
VX_tensor_dpu #(
|
VX_tensor_threadgroups #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
.OCTET(OCTET),
|
.OCTET(OCTET),
|
||||||
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/)
|
.OPERAND_BUFFER_DEPTH(4 /*@perf: arbtirary*/)
|
||||||
@@ -581,14 +586,14 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
`ifdef PERF_ENABLE
|
`ifdef PERF_ENABLE
|
||||||
logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total;
|
logic [`PERF_CTR_BITS-1:0] perf_tensor_ops_total;
|
||||||
|
|
||||||
always @(posedge clk) begin
|
always @(posedge clk) begin
|
||||||
if (reset) begin
|
if (reset) begin
|
||||||
perf_tensor_dpu_total <= '0;
|
perf_tensor_ops_total <= '0;
|
||||||
end else begin
|
end else begin
|
||||||
if (do_hmma) begin
|
if (do_hmma) begin
|
||||||
perf_tensor_dpu_total <= perf_tensor_dpu_total + 2'd2;
|
perf_tensor_ops_total <= perf_tensor_ops_total + 2'd2;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
`ifdef EXT_T_ENABLE
|
`ifdef EXT_T_ENABLE
|
||||||
`include "VX_fpu_define.vh"
|
`include "VX_fpu_define.vh"
|
||||||
|
|
||||||
module VX_tensor_dpu #(
|
// Module that contains the threadgroups with DPUs + operand buffer.
|
||||||
|
module VX_tensor_threadgroups #(
|
||||||
parameter ISW,
|
parameter ISW,
|
||||||
parameter OCTET,
|
parameter OCTET,
|
||||||
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
// @perf: has big impact on throughput. A rule of thumb is to set it to
|
||||||
@@ -15,6 +16,7 @@ module VX_tensor_dpu #(
|
|||||||
input valid_in,
|
input valid_in,
|
||||||
output ready_in,
|
output ready_in,
|
||||||
// [rows][cols][dtype]
|
// [rows][cols][dtype]
|
||||||
|
// (m,n,k) = (4,4,2)
|
||||||
input [3:0][1:0][31:0] A_tile,
|
input [3:0][1:0][31:0] A_tile,
|
||||||
input [1:0][3:0][31:0] B_tile,
|
input [1:0][3:0][31:0] B_tile,
|
||||||
input [3:0][3:0][31:0] C_tile,
|
input [3:0][3:0][31:0] C_tile,
|
||||||
@@ -172,6 +174,7 @@ module VX_tensor_threadgroup #(
|
|||||||
output ready_in,
|
output ready_in,
|
||||||
input stall,
|
input stall,
|
||||||
// all *_frag are row-major
|
// all *_frag are row-major
|
||||||
|
// (m,n,k) = (2,4,2)
|
||||||
input [1:0][1:0][31:0] A_frag,
|
input [1:0][1:0][31:0] A_frag,
|
||||||
input [1:0][3:0][31:0] B_frag,
|
input [1:0][3:0][31:0] B_frag,
|
||||||
input [1:0][3:0][31:0] C_frag,
|
input [1:0][3:0][31:0] C_frag,
|
||||||
@@ -269,8 +272,11 @@ module VX_tensor_threadgroup #(
|
|||||||
|
|
||||||
// 4 FEDPs per threadgroup
|
// 4 FEDPs per threadgroup
|
||||||
for (genvar i = 0; i < 4; ++i) begin
|
for (genvar i = 0; i < 4; ++i) begin
|
||||||
// at substep == 0, the 0th and 2nd columns of D begins compute;
|
// Determine which elements in the D matrix the dot-product units get
|
||||||
// at substep == 1, the 1st and 3rd columns of D begins compute.
|
// mapped to.
|
||||||
|
//
|
||||||
|
// At substep == 0, the 0th and 2nd columns of D begins compute;
|
||||||
|
// At substep == 1, the 1st and 3rd columns of D begins compute.
|
||||||
// There are two row elements for each column, rounding out to
|
// There are two row elements for each column, rounding out to
|
||||||
// 4 elements computed by 4 FEDPs at every cycle
|
// 4 elements computed by 4 FEDPs at every cycle
|
||||||
// (see Figure 10(b)).
|
// (see Figure 10(b)).
|
||||||
|
|||||||
Reference in New Issue
Block a user