Stage half-operands per warp
An easy solution to handle multiple concurrent warp operations by staging half-operands in their own per-warp register. This might increase area requirement by quite a bit. TODO: Commit is not being handled correctly yet
This commit is contained in:
@@ -83,6 +83,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||
|
||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||
wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
||||
|
||||
logic [NUM_OCTETS-1:0] octet_results_valid;
|
||||
logic [NUM_OCTETS-1:0] octet_results_ready;
|
||||
logic [NUM_OCTETS-1:0] octet_operands_ready;
|
||||
@@ -111,6 +113,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
logic [3:0][3:0][31:0] octet_D;
|
||||
logic result_valid;
|
||||
logic result_ready;
|
||||
|
||||
// op_mod is reused to indicate instruction's id in pair
|
||||
VX_tensor_octet #(
|
||||
.ISW(ISW),
|
||||
.OCTET(i)
|
||||
@@ -122,6 +126,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
||||
.B_in(octet_B),
|
||||
.C_in(octet_C),
|
||||
.operands_valid(execute_if.valid),
|
||||
.operands_wid(execute_if.data.wid),
|
||||
.operands_last_in_pair(operands_last_in_pair),
|
||||
.operands_ready(octet_operands_ready[i]),
|
||||
|
||||
.step(step),
|
||||
@@ -245,11 +251,14 @@ module VX_tensor_octet #(
|
||||
input clk,
|
||||
input reset,
|
||||
|
||||
input [7:0][31:0] A_in,
|
||||
input [7:0][31:0] B_in,
|
||||
input [7:0][31:0] C_in,
|
||||
input operands_valid, // we have to backpressure due to there potentially being contention over commit
|
||||
output operands_ready,
|
||||
input [7:0][31:0] A_in,
|
||||
input [7:0][31:0] B_in,
|
||||
input [7:0][31:0] C_in,
|
||||
input operands_valid,
|
||||
input [`NW_WIDTH-1:0] operands_wid,
|
||||
input operands_last_in_pair,
|
||||
// we have to backpressure due to there potentially being contention over commit
|
||||
output operands_ready,
|
||||
|
||||
input [1:0] step,
|
||||
|
||||
@@ -258,9 +267,9 @@ module VX_tensor_octet #(
|
||||
input result_ready
|
||||
);
|
||||
// 512 bits/octet * 4 octets per warp
|
||||
logic [3:0][31:0] A_buffer, A_buffer_n;
|
||||
logic [3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [7:0][31:0] C_buffer, C_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
|
||||
logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n;
|
||||
|
||||
// half the inputs are buffered, half are not (instead coming straight
|
||||
// from operand bus) unlike the real tensor core.
|
||||
@@ -268,6 +277,10 @@ module VX_tensor_octet #(
|
||||
logic [3:0][31:0] A_half;
|
||||
logic [3:0][31:0] B_half;
|
||||
logic [7:0][31:0] C_half;
|
||||
|
||||
logic [`NUM_WARPS-1:0] substeps;
|
||||
logic [`NUM_WARPS-1:0] substeps_n;
|
||||
|
||||
always @(*) begin
|
||||
// note that not all lanes participate at every step
|
||||
case (step)
|
||||
@@ -296,18 +309,29 @@ module VX_tensor_octet #(
|
||||
end
|
||||
|
||||
logic substep;
|
||||
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep;
|
||||
wire operands_fire = operands_ready && operands_valid;
|
||||
wire substep_n = operands_fire && operands_last_in_pair;
|
||||
|
||||
always @(*) begin
|
||||
A_buffer_n = A_buffer;
|
||||
B_buffer_n = B_buffer;
|
||||
C_buffer_n = C_buffer;
|
||||
substeps_n = substeps;
|
||||
|
||||
if (substep == 1'b0) begin
|
||||
A_buffer_n = A_half;
|
||||
B_buffer_n = B_half;
|
||||
C_buffer_n = C_half;
|
||||
if (operands_fire) begin
|
||||
substeps_n[operands_wid] = ~substeps[operands_wid];
|
||||
if (!operands_last_in_pair) begin
|
||||
A_buffer_n[operands_wid] = A_half;
|
||||
B_buffer_n[operands_wid] = B_half;
|
||||
C_buffer_n[operands_wid] = C_half;
|
||||
end
|
||||
end
|
||||
|
||||
// if (operands_fire && (substep == 1'b0)) begin
|
||||
// A_buffer_n[operands_wid] = A_half;
|
||||
// B_buffer_n[operands_wid] = B_half;
|
||||
// C_buffer_n[operands_wid] = C_half;
|
||||
// end
|
||||
end
|
||||
|
||||
always @(posedge clk) begin
|
||||
@@ -315,13 +339,17 @@ module VX_tensor_octet #(
|
||||
A_buffer <= '0;
|
||||
B_buffer <= '0;
|
||||
C_buffer <= '0;
|
||||
|
||||
substep <= '0;
|
||||
substeps <= '0;
|
||||
end
|
||||
else begin
|
||||
A_buffer <= A_buffer_n;
|
||||
B_buffer <= B_buffer_n;
|
||||
C_buffer <= C_buffer_n;
|
||||
|
||||
substep <= substep_n;
|
||||
substeps <= substeps_n;
|
||||
end
|
||||
end
|
||||
|
||||
@@ -330,39 +358,38 @@ module VX_tensor_octet #(
|
||||
// wire stall = result_valid && ~result_ready;
|
||||
// backpressure from commit
|
||||
wire stall = ~outbuf_ready_in;
|
||||
assign operands_ready = ~stall;
|
||||
// assign operands_ready = ~stall;
|
||||
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
|
||||
// currently, dpu is fully-pipelined and allows concurrency between
|
||||
// multiple warps. This seems to be not a problem though given that the
|
||||
// RF operand read takes >=2 cycles, which should be the end-to-end
|
||||
// latency of the DPU anyways
|
||||
// assign operands_ready = hmma_ready && ~stall;
|
||||
assign operands_ready = hmma_ready && ~stall;
|
||||
|
||||
// A is 4x2 fp32 matrix
|
||||
wire [3:0][1:0][31:0] A_tile = {
|
||||
{ A_half[3], A_buffer[3] },
|
||||
{ A_half[2], A_buffer[2] },
|
||||
{ A_half[1], A_buffer[1] },
|
||||
{ A_half[0], A_buffer[0] }
|
||||
{ A_half[3], A_buffer[operands_wid][3] },
|
||||
{ A_half[2], A_buffer[operands_wid][2] },
|
||||
{ A_half[1], A_buffer[operands_wid][1] },
|
||||
{ A_half[0], A_buffer[operands_wid][0] }
|
||||
};
|
||||
// B is 2x4 fp32 matrix
|
||||
wire [1:0][3:0][31:0] B_tile = {
|
||||
B_half, B_buffer
|
||||
B_half, B_buffer[operands_wid]
|
||||
};
|
||||
// C is 4x4 fp32 matrix
|
||||
logic [3:0][3:0][31:0] C_tile;
|
||||
logic [3:0][3:0][31:0] D_tile;
|
||||
|
||||
always @(*) begin
|
||||
C_tile = {
|
||||
C_half[7], C_buffer[7], C_half[5], C_buffer[5],
|
||||
C_half[6], C_buffer[6], C_half[4], C_buffer[4],
|
||||
C_half[3], C_buffer[3], C_half[1], C_buffer[1],
|
||||
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
|
||||
};
|
||||
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
|
||||
C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] };
|
||||
C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] };
|
||||
C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] };
|
||||
end
|
||||
|
||||
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready);
|
||||
// wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1);
|
||||
wire do_hmma = operands_fire && operands_last_in_pair;
|
||||
wire dpu_valid;
|
||||
|
||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||
|
||||
Reference in New Issue
Block a user