Stage half-operands per warp

An easy solution to handle multiple concurrent warp operations by
staging half-operands in their own per-warp register.  This might
increase area requirement by quite a bit.

TODO: Commit is not being handled correctly yet
This commit is contained in:
Hansung Kim
2024-05-25 19:08:17 -07:00
parent 45d86b26a2
commit 8775458a8f

View File

@@ -83,6 +83,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(execute_if.data.op_type); wire [1:0] step = 2'(execute_if.data.op_type);
wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_results_ready;
logic [NUM_OCTETS-1:0] octet_operands_ready; logic [NUM_OCTETS-1:0] octet_operands_ready;
@@ -111,6 +113,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
logic [3:0][3:0][31:0] octet_D; logic [3:0][3:0][31:0] octet_D;
logic result_valid; logic result_valid;
logic result_ready; logic result_ready;
// op_mod is reused to indicate instruction's id in pair
VX_tensor_octet #( VX_tensor_octet #(
.ISW(ISW), .ISW(ISW),
.OCTET(i) .OCTET(i)
@@ -122,6 +126,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.B_in(octet_B), .B_in(octet_B),
.C_in(octet_C), .C_in(octet_C),
.operands_valid(execute_if.valid), .operands_valid(execute_if.valid),
.operands_wid(execute_if.data.wid),
.operands_last_in_pair(operands_last_in_pair),
.operands_ready(octet_operands_ready[i]), .operands_ready(octet_operands_ready[i]),
.step(step), .step(step),
@@ -245,11 +251,14 @@ module VX_tensor_octet #(
input clk, input clk,
input reset, input reset,
input [7:0][31:0] A_in, input [7:0][31:0] A_in,
input [7:0][31:0] B_in, input [7:0][31:0] B_in,
input [7:0][31:0] C_in, input [7:0][31:0] C_in,
input operands_valid, // we have to backpressure due to there potentially being contention over commit input operands_valid,
output operands_ready, input [`NW_WIDTH-1:0] operands_wid,
input operands_last_in_pair,
// we have to backpressure due to there potentially being contention over commit
output operands_ready,
input [1:0] step, input [1:0] step,
@@ -258,9 +267,9 @@ module VX_tensor_octet #(
input result_ready input result_ready
); );
// 512 bits/octet * 4 octets per warp // 512 bits/octet * 4 octets per warp
logic [3:0][31:0] A_buffer, A_buffer_n; logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
logic [3:0][31:0] B_buffer, B_buffer_n; logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
logic [7:0][31:0] C_buffer, C_buffer_n; logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n;
// half the inputs are buffered, half are not (instead coming straight // half the inputs are buffered, half are not (instead coming straight
// from operand bus) unlike the real tensor core. // from operand bus) unlike the real tensor core.
@@ -268,6 +277,10 @@ module VX_tensor_octet #(
logic [3:0][31:0] A_half; logic [3:0][31:0] A_half;
logic [3:0][31:0] B_half; logic [3:0][31:0] B_half;
logic [7:0][31:0] C_half; logic [7:0][31:0] C_half;
logic [`NUM_WARPS-1:0] substeps;
logic [`NUM_WARPS-1:0] substeps_n;
always @(*) begin always @(*) begin
// note that not all lanes participate at every step // note that not all lanes participate at every step
case (step) case (step)
@@ -296,18 +309,29 @@ module VX_tensor_octet #(
end end
logic substep; logic substep;
wire substep_n = (operands_ready && operands_valid) ? ~substep : substep; wire operands_fire = operands_ready && operands_valid;
wire substep_n = operands_fire && operands_last_in_pair;
always @(*) begin always @(*) begin
A_buffer_n = A_buffer; A_buffer_n = A_buffer;
B_buffer_n = B_buffer; B_buffer_n = B_buffer;
C_buffer_n = C_buffer; C_buffer_n = C_buffer;
substeps_n = substeps;
if (substep == 1'b0) begin if (operands_fire) begin
A_buffer_n = A_half; substeps_n[operands_wid] = ~substeps[operands_wid];
B_buffer_n = B_half; if (!operands_last_in_pair) begin
C_buffer_n = C_half; A_buffer_n[operands_wid] = A_half;
B_buffer_n[operands_wid] = B_half;
C_buffer_n[operands_wid] = C_half;
end
end end
// if (operands_fire && (substep == 1'b0)) begin
// A_buffer_n[operands_wid] = A_half;
// B_buffer_n[operands_wid] = B_half;
// C_buffer_n[operands_wid] = C_half;
// end
end end
always @(posedge clk) begin always @(posedge clk) begin
@@ -315,13 +339,17 @@ module VX_tensor_octet #(
A_buffer <= '0; A_buffer <= '0;
B_buffer <= '0; B_buffer <= '0;
C_buffer <= '0; C_buffer <= '0;
substep <= '0; substep <= '0;
substeps <= '0;
end end
else begin else begin
A_buffer <= A_buffer_n; A_buffer <= A_buffer_n;
B_buffer <= B_buffer_n; B_buffer <= B_buffer_n;
C_buffer <= C_buffer_n; C_buffer <= C_buffer_n;
substep <= substep_n; substep <= substep_n;
substeps <= substeps_n;
end end
end end
@@ -330,39 +358,38 @@ module VX_tensor_octet #(
// wire stall = result_valid && ~result_ready; // wire stall = result_valid && ~result_ready;
// backpressure from commit // backpressure from commit
wire stall = ~outbuf_ready_in; wire stall = ~outbuf_ready_in;
assign operands_ready = ~stall; // assign operands_ready = ~stall;
// TODO: Below line is to only allow 1 warp to occupy the octet at a time; // TODO: Below line is to only allow 1 warp to occupy the octet at a time;
// currently, dpu is fully-pipelined and allows concurrency between // currently, dpu is fully-pipelined and allows concurrency between
// multiple warps. This seems to be not a problem though given that the // multiple warps. This seems to be not a problem though given that the
// RF operand read takes >=2 cycles, which should be the end-to-end // RF operand read takes >=2 cycles, which should be the end-to-end
// latency of the DPU anyways // latency of the DPU anyways
// assign operands_ready = hmma_ready && ~stall; assign operands_ready = hmma_ready && ~stall;
// A is 4x2 fp32 matrix // A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = { wire [3:0][1:0][31:0] A_tile = {
{ A_half[3], A_buffer[3] }, { A_half[3], A_buffer[operands_wid][3] },
{ A_half[2], A_buffer[2] }, { A_half[2], A_buffer[operands_wid][2] },
{ A_half[1], A_buffer[1] }, { A_half[1], A_buffer[operands_wid][1] },
{ A_half[0], A_buffer[0] } { A_half[0], A_buffer[operands_wid][0] }
}; };
// B is 2x4 fp32 matrix // B is 2x4 fp32 matrix
wire [1:0][3:0][31:0] B_tile = { wire [1:0][3:0][31:0] B_tile = {
B_half, B_buffer B_half, B_buffer[operands_wid]
}; };
// C is 4x4 fp32 matrix // C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile; logic [3:0][3:0][31:0] D_tile;
always @(*) begin always @(*) begin
C_tile = { C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
C_half[7], C_buffer[7], C_half[5], C_buffer[5], C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] };
C_half[6], C_buffer[6], C_half[4], C_buffer[4], C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] };
C_half[3], C_buffer[3], C_half[1], C_buffer[1], C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] };
C_half[2], C_buffer[2], C_half[0], C_buffer[0]
};
end end
wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); // wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1);
wire do_hmma = operands_fire && operands_last_in_pair;
wire dpu_valid; wire dpu_valid;
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet