tensor: Properly stall dpu upon commit backpressure

& better-reasoned queue depths
This commit is contained in:
Hansung Kim
2024-05-29 17:05:12 -07:00
parent f5a9ca5bf3
commit 5ed6041e33
2 changed files with 26 additions and 31 deletions

View File

@@ -77,6 +77,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16 // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
// FIXME: not sure this is the right logic. just filling in what works // FIXME: not sure this is the right logic. just filling in what works
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
localparam REQ_QUEUE_DEPTH = 4;
wire [1:0] step = 2'(execute_if.data.op_type); wire [1:0] step = 2'(execute_if.data.op_type);
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
@@ -219,7 +220,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
VX_fifo_queue #( VX_fifo_queue #(
.DATAW(DATAW), .DATAW(DATAW),
.DEPTH(8 /* FIXME: arbitrary */) .DEPTH(REQ_QUEUE_DEPTH)
) pending_uops ( ) pending_uops (
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
@@ -234,6 +235,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
`UNUSED_PIN(size) `UNUSED_PIN(size)
); );
// this shouldn't really happen unless there's a big contention over
// the commit stage
`RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!")); `RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!"));
end end
@@ -300,6 +303,8 @@ module VX_tensor_octet #(
output result_valid, output result_valid,
input result_ready input result_ready
); );
localparam ISSUE_QUEUE_DEPTH = 4;
// 512 bits/octet * 4 octets per warp // 512 bits/octet * 4 octets per warp
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n; logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n; logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
@@ -351,7 +356,7 @@ module VX_tensor_octet #(
VX_fifo_queue #( VX_fifo_queue #(
.DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) + .DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
$bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)), $bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
.DEPTH (8 /* FIXME: arbitrary */) .DEPTH (ISSUE_QUEUE_DEPTH)
) input_buffer ( ) input_buffer (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
@@ -451,17 +456,8 @@ module VX_tensor_octet #(
end end
wire outbuf_ready_in; wire outbuf_ready_in;
// backpressure from commit
wire stall = ~outbuf_ready_in;
wire hmma_ready; wire hmma_ready;
assign operands_ready_buf = hmma_ready;
// assign operands_ready = ~stall;
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
// currently, dpu is fully-pipelined and allows concurrency between
// multiple warps. This seems to be not a problem though given that the
// RF operand read takes >=2 cycles, which should be the end-to-end
// latency of the DPU anyways
assign operands_ready_buf = hmma_ready && ~stall;
// A is 4x2 fp32 matrix // A is 4x2 fp32 matrix
wire [3:0][1:0][31:0] A_tile = { wire [3:0][1:0][31:0] A_tile = {
@@ -496,8 +492,6 @@ module VX_tensor_octet #(
.clk(clk), .clk(clk),
.reset(reset), .reset(reset),
.stall(stall),
.valid_in(do_hmma), .valid_in(do_hmma),
.ready_in(hmma_ready), .ready_in(hmma_ready),
.A_tile(A_tile), .A_tile(A_tile),
@@ -506,12 +500,14 @@ module VX_tensor_octet #(
.wid(operands_wid_buf), .wid(operands_wid_buf),
.valid_out(dpu_valid), .valid_out(dpu_valid),
.ready_out(outbuf_ready_in),
.D_tile(D_tile), .D_tile(D_tile),
.D_wid(D_wid_dpu) .D_wid(D_wid_dpu)
); );
wire outbuf_empty; wire outbuf_empty;
wire outbuf_full; wire outbuf_full;
// backpressure from commit
assign outbuf_ready_in = ~outbuf_full; assign outbuf_ready_in = ~outbuf_full;
assign result_valid = ~outbuf_empty; assign result_valid = ~outbuf_empty;
@@ -525,7 +521,10 @@ module VX_tensor_octet #(
// TODO: This is probably oversized. // TODO: This is probably oversized.
VX_fifo_queue #( VX_fifo_queue #(
.DATAW ($bits(D_wid) + $bits(D_out)), .DATAW ($bits(D_wid) + $bits(D_out)),
.DEPTH (8 /* FIXME: arbitrary */) // depth of this queue should ideally be deeper than the dpu pipeline
// latency, since the dpu is fully-pipelined and it can output the
// latency-number of outputs in a burst-y way.
.DEPTH (`LATENCY_HMMA + `LATENCY_HMMA)
) output_buffer ( ) output_buffer (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),

View File

@@ -8,8 +8,6 @@ module VX_tensor_dpu #(
input clk, input clk,
input reset, input reset,
input stall,
input valid_in, input valid_in,
output ready_in, output ready_in,
input [3:0][1:0][31:0] A_tile, input [3:0][1:0][31:0] A_tile,
@@ -18,6 +16,7 @@ module VX_tensor_dpu #(
input [`NW_WIDTH-1:0] wid, input [`NW_WIDTH-1:0] wid,
output valid_out, output valid_out,
input ready_out,
output [3:0][3:0][31:0] D_tile, output [3:0][3:0][31:0] D_tile,
output [`NW_WIDTH-1:0] D_wid output [`NW_WIDTH-1:0] D_wid
); );
@@ -40,10 +39,11 @@ module VX_tensor_dpu #(
end end
// ready as soon as valid_out // ready as soon as valid_out
assign ready_in = ready_reg || valid_out; // assign ready_in = ready_reg || valid_out;
// fully pipelined; always ready // fully pipelined; ready_in is coupled to ready_out by immediately
// assign ready_in = 1'b1; // stalling
assign ready_in = ready_out;
// wire dpu_valid; // wire dpu_valid;
// wire [31:0] dpu_data; // wire [31:0] dpu_data;
@@ -70,8 +70,8 @@ module VX_tensor_dpu #(
) threadgroup_0 ( ) threadgroup_0 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.valid_in (valid_in && ready_in), .valid_in (valid_in),
.stall (stall), .stall (!ready_out),
.A_frag (A_tile[1:0]), .A_frag (A_tile[1:0]),
.B_frag (B_tile), .B_frag (B_tile),
.C_frag (C_tile[1:0]), .C_frag (C_tile[1:0]),
@@ -82,8 +82,8 @@ module VX_tensor_dpu #(
) threadgroup_1 ( ) threadgroup_1 (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.valid_in (valid_in && ready_in), .valid_in (valid_in),
.stall (stall), .stall (!ready_out),
.A_frag (A_tile[3:2]), .A_frag (A_tile[3:2]),
.B_frag (B_tile), .B_frag (B_tile),
.C_frag (C_tile[3:2]), .C_frag (C_tile[3:2]),
@@ -94,18 +94,16 @@ module VX_tensor_dpu #(
// fixed-latency queue // fixed-latency queue
VX_shift_register #( VX_shift_register #(
.DATAW (1 + $bits(wid)/* + $bits(D_tile)*/), .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
// .DEPTH (`LATENCY_HMMA), .DEPTH (`LATENCY_HMMA),
.DEPTH (4),
.RESETW (1) .RESETW (1)
) shift_reg ( ) shift_reg (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.enable (~stall), .enable (ready_out),
.data_in ({valid_in && ready_in, wid /*, result_hmma*/}), .data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
.data_out ({valid_out, D_wid/*, D_tile */}) .data_out ({valid_out, D_wid/*, D_tile */})
); );
// FIXME: breaks when stall is on!
`RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out), `RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out),
("FEDP and metadata queue went out of sync!")) ("FEDP and metadata queue went out of sync!"))
endmodule endmodule
@@ -146,7 +144,7 @@ module VX_tensor_threadgroup #(
.io_in_bits_b_2 (32'h0), .io_in_bits_b_2 (32'h0),
.io_in_bits_b_3 (32'h0), .io_in_bits_b_3 (32'h0),
.io_in_bits_c (C_frag[D_row][D_col]), .io_in_bits_c (C_frag[D_row][D_col]),
.io_stall (1'b0), // FIXME .io_stall (stall),
.io_out_valid (valids[D_row][D_col]), .io_out_valid (valids[D_row][D_col]),
.io_out_bits_data (D_frag[D_row][D_col]) .io_out_bits_data (D_frag[D_row][D_col])
); );
@@ -154,8 +152,6 @@ module VX_tensor_threadgroup #(
end end
assign valid_out = (&(valids[0])) && (&(valids[1])); assign valid_out = (&(valids[0])) && (&(valids[1]));
`RUNTIME_ASSERT(reset || !stall, ("stall not supported yet in tensor dpu!"))
endmodule endmodule
`endif `endif