tensor: doc
This commit is contained in:
@@ -43,7 +43,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
assign dispatch_if.ready = &octet_operands_ready;
|
assign dispatch_if.ready = &octet_operands_ready;
|
||||||
|
|
||||||
for (genvar i = 0; i < 4; ++i) begin
|
for (genvar i = 0; i < 4/*octets*/; ++i) begin
|
||||||
|
// lane-to-octet mapping; see figure 13 of the paper
|
||||||
wire [7:0][31:0] octet_A = {
|
wire [7:0][31:0] octet_A = {
|
||||||
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4]
|
||||||
};
|
};
|
||||||
@@ -81,6 +82,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
assign octet_results_valid[i] = result_valid;
|
assign octet_results_valid[i] = result_valid;
|
||||||
assign result_ready = octet_results_ready[i];
|
assign result_ready = octet_results_ready[i];
|
||||||
|
|
||||||
|
// each octet produces 4x4 output partial sum, but the 8 lanes mapped
|
||||||
|
// to the octet can only do 8 fp32 writeback at a time; so we need to
|
||||||
|
// split writeback over two cycles
|
||||||
assign wb_data_0[4*i+0] = octet_D[0][0];
|
assign wb_data_0[4*i+0] = octet_D[0][0];
|
||||||
assign wb_data_0[4*i+1] = octet_D[1][0];
|
assign wb_data_0[4*i+1] = octet_D[1][0];
|
||||||
assign wb_data_0[4*i+2] = octet_D[0][2];
|
assign wb_data_0[4*i+2] = octet_D[0][2];
|
||||||
@@ -150,11 +154,11 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||||
dispatch_if_data_deq,
|
dispatch_if_data_deq, /* uuid ~ rd */
|
||||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1,
|
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||||
1'b0,
|
1'b0, /* pid */
|
||||||
1'b1,
|
1'b1, /* sop */
|
||||||
1'b1
|
1'b1 /* eop */
|
||||||
};
|
};
|
||||||
|
|
||||||
assign commit_if.data = commit_if_data;
|
assign commit_if.data = commit_if_data;
|
||||||
@@ -204,8 +208,11 @@ module VX_tensor_octet #(
|
|||||||
logic [3:0][31:0] B_buffer, B_buffer_n;
|
logic [3:0][31:0] B_buffer, B_buffer_n;
|
||||||
logic [7:0][31:0] C_buffer, C_buffer_n;
|
logic [7:0][31:0] C_buffer, C_buffer_n;
|
||||||
|
|
||||||
// half the inputs are buffered, half are not (instead coming straight from operand bus)
|
// half the inputs are buffered, half are not (instead coming straight
|
||||||
// unlike the real tensor core, the banks are only 32 bit rather than 64 bit
|
// from operand bus) unlike the real tensor core.
|
||||||
|
// the banks are only 32 bit rather than 64 bit (a pair of fp32 regs).
|
||||||
|
// since A and B are supplied by 4 lanes each, we get 4 fp32's at a time
|
||||||
|
// (8 for C).
|
||||||
logic [3:0][31:0] A_half;
|
logic [3:0][31:0] A_half;
|
||||||
logic [3:0][31:0] B_half;
|
logic [3:0][31:0] B_half;
|
||||||
logic [7:0][31:0] C_half;
|
logic [7:0][31:0] C_half;
|
||||||
@@ -265,15 +272,18 @@ module VX_tensor_octet #(
|
|||||||
wire stall = result_valid && ~result_ready;
|
wire stall = result_valid && ~result_ready;
|
||||||
assign operands_ready = ~stall;
|
assign operands_ready = ~stall;
|
||||||
|
|
||||||
|
// A is 4x2 fp32 matrix
|
||||||
wire [3:0][1:0][31:0] A_tile = {
|
wire [3:0][1:0][31:0] A_tile = {
|
||||||
{ A_half[3], A_buffer[3] },
|
{ A_half[3], A_buffer[3] },
|
||||||
{ A_half[2], A_buffer[2] },
|
{ A_half[2], A_buffer[2] },
|
||||||
{ A_half[1], A_buffer[1] },
|
{ A_half[1], A_buffer[1] },
|
||||||
{ A_half[0], A_buffer[0] }
|
{ A_half[0], A_buffer[0] }
|
||||||
};
|
};
|
||||||
|
// B is 2x4 fp32 matrix
|
||||||
wire [1:0][3:0][31:0] B_tile = {
|
wire [1:0][3:0][31:0] B_tile = {
|
||||||
B_half, B_buffer
|
B_half, B_buffer
|
||||||
};
|
};
|
||||||
|
// C is 4x4 fp32 matrix
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
@@ -286,6 +296,8 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
||||||
|
|
||||||
|
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||||
VX_tensor_dpu #(
|
VX_tensor_dpu #(
|
||||||
.ISW(ISW),
|
.ISW(ISW),
|
||||||
.OCTET(OCTET)
|
.OCTET(OCTET)
|
||||||
|
|||||||
Reference in New Issue
Block a user