tensor: Properly stall dpu upon commit backpressure
& better-reasoned queue depths
This commit is contained in:
@@ -77,6 +77,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
|
// octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16
|
||||||
// FIXME: not sure this is the right logic. just filling in what works
|
// FIXME: not sure this is the right logic. just filling in what works
|
||||||
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
|
||||||
|
localparam REQ_QUEUE_DEPTH = 4;
|
||||||
|
|
||||||
wire [1:0] step = 2'(execute_if.data.op_type);
|
wire [1:0] step = 2'(execute_if.data.op_type);
|
||||||
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
|
||||||
@@ -219,7 +220,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
|
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW(DATAW),
|
.DATAW(DATAW),
|
||||||
.DEPTH(8 /* FIXME: arbitrary */)
|
.DEPTH(REQ_QUEUE_DEPTH)
|
||||||
) pending_uops (
|
) pending_uops (
|
||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
@@ -234,6 +235,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
`UNUSED_PIN(size)
|
`UNUSED_PIN(size)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// this shouldn't really happen unless there's a big contention over
|
||||||
|
// the commit stage
|
||||||
`RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!"));
|
`RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!"));
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -300,6 +303,8 @@ module VX_tensor_octet #(
|
|||||||
output result_valid,
|
output result_valid,
|
||||||
input result_ready
|
input result_ready
|
||||||
);
|
);
|
||||||
|
localparam ISSUE_QUEUE_DEPTH = 4;
|
||||||
|
|
||||||
// 512 bits/octet * 4 octets per warp
|
// 512 bits/octet * 4 octets per warp
|
||||||
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
|
logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n;
|
||||||
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
|
logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n;
|
||||||
@@ -351,7 +356,7 @@ module VX_tensor_octet #(
|
|||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
|
.DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
|
||||||
$bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
|
$bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
|
||||||
.DEPTH (8 /* FIXME: arbitrary */)
|
.DEPTH (ISSUE_QUEUE_DEPTH)
|
||||||
) input_buffer (
|
) input_buffer (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
@@ -451,17 +456,8 @@ module VX_tensor_octet #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
wire outbuf_ready_in;
|
wire outbuf_ready_in;
|
||||||
// backpressure from commit
|
|
||||||
wire stall = ~outbuf_ready_in;
|
|
||||||
wire hmma_ready;
|
wire hmma_ready;
|
||||||
|
assign operands_ready_buf = hmma_ready;
|
||||||
// assign operands_ready = ~stall;
|
|
||||||
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
|
|
||||||
// currently, dpu is fully-pipelined and allows concurrency between
|
|
||||||
// multiple warps. This seems to be not a problem though given that the
|
|
||||||
// RF operand read takes >=2 cycles, which should be the end-to-end
|
|
||||||
// latency of the DPU anyways
|
|
||||||
assign operands_ready_buf = hmma_ready && ~stall;
|
|
||||||
|
|
||||||
// A is 4x2 fp32 matrix
|
// A is 4x2 fp32 matrix
|
||||||
wire [3:0][1:0][31:0] A_tile = {
|
wire [3:0][1:0][31:0] A_tile = {
|
||||||
@@ -496,8 +492,6 @@ module VX_tensor_octet #(
|
|||||||
.clk(clk),
|
.clk(clk),
|
||||||
.reset(reset),
|
.reset(reset),
|
||||||
|
|
||||||
.stall(stall),
|
|
||||||
|
|
||||||
.valid_in(do_hmma),
|
.valid_in(do_hmma),
|
||||||
.ready_in(hmma_ready),
|
.ready_in(hmma_ready),
|
||||||
.A_tile(A_tile),
|
.A_tile(A_tile),
|
||||||
@@ -506,12 +500,14 @@ module VX_tensor_octet #(
|
|||||||
.wid(operands_wid_buf),
|
.wid(operands_wid_buf),
|
||||||
|
|
||||||
.valid_out(dpu_valid),
|
.valid_out(dpu_valid),
|
||||||
|
.ready_out(outbuf_ready_in),
|
||||||
.D_tile(D_tile),
|
.D_tile(D_tile),
|
||||||
.D_wid(D_wid_dpu)
|
.D_wid(D_wid_dpu)
|
||||||
);
|
);
|
||||||
|
|
||||||
wire outbuf_empty;
|
wire outbuf_empty;
|
||||||
wire outbuf_full;
|
wire outbuf_full;
|
||||||
|
// backpressure from commit
|
||||||
assign outbuf_ready_in = ~outbuf_full;
|
assign outbuf_ready_in = ~outbuf_full;
|
||||||
assign result_valid = ~outbuf_empty;
|
assign result_valid = ~outbuf_empty;
|
||||||
|
|
||||||
@@ -525,7 +521,10 @@ module VX_tensor_octet #(
|
|||||||
// TODO: This is probably oversized.
|
// TODO: This is probably oversized.
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW ($bits(D_wid) + $bits(D_out)),
|
.DATAW ($bits(D_wid) + $bits(D_out)),
|
||||||
.DEPTH (8 /* FIXME: arbitrary */)
|
// depth of this queue should ideally be deeper than the dpu pipeline
|
||||||
|
// latency, since the dpu is fully-pipelined and it can output the
|
||||||
|
// latency-number of outputs in a burst-y way.
|
||||||
|
.DEPTH (`LATENCY_HMMA + `LATENCY_HMMA)
|
||||||
) output_buffer (
|
) output_buffer (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ module VX_tensor_dpu #(
|
|||||||
input clk,
|
input clk,
|
||||||
input reset,
|
input reset,
|
||||||
|
|
||||||
input stall,
|
|
||||||
|
|
||||||
input valid_in,
|
input valid_in,
|
||||||
output ready_in,
|
output ready_in,
|
||||||
input [3:0][1:0][31:0] A_tile,
|
input [3:0][1:0][31:0] A_tile,
|
||||||
@@ -18,6 +16,7 @@ module VX_tensor_dpu #(
|
|||||||
input [`NW_WIDTH-1:0] wid,
|
input [`NW_WIDTH-1:0] wid,
|
||||||
|
|
||||||
output valid_out,
|
output valid_out,
|
||||||
|
input ready_out,
|
||||||
output [3:0][3:0][31:0] D_tile,
|
output [3:0][3:0][31:0] D_tile,
|
||||||
output [`NW_WIDTH-1:0] D_wid
|
output [`NW_WIDTH-1:0] D_wid
|
||||||
);
|
);
|
||||||
@@ -40,10 +39,11 @@ module VX_tensor_dpu #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
// ready as soon as valid_out
|
// ready as soon as valid_out
|
||||||
assign ready_in = ready_reg || valid_out;
|
// assign ready_in = ready_reg || valid_out;
|
||||||
|
|
||||||
// fully pipelined; always ready
|
// fully pipelined; ready_in is coupled to ready_out by immediately
|
||||||
// assign ready_in = 1'b1;
|
// stalling
|
||||||
|
assign ready_in = ready_out;
|
||||||
|
|
||||||
// wire dpu_valid;
|
// wire dpu_valid;
|
||||||
// wire [31:0] dpu_data;
|
// wire [31:0] dpu_data;
|
||||||
@@ -70,8 +70,8 @@ module VX_tensor_dpu #(
|
|||||||
) threadgroup_0 (
|
) threadgroup_0 (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.valid_in (valid_in && ready_in),
|
.valid_in (valid_in),
|
||||||
.stall (stall),
|
.stall (!ready_out),
|
||||||
.A_frag (A_tile[1:0]),
|
.A_frag (A_tile[1:0]),
|
||||||
.B_frag (B_tile),
|
.B_frag (B_tile),
|
||||||
.C_frag (C_tile[1:0]),
|
.C_frag (C_tile[1:0]),
|
||||||
@@ -82,8 +82,8 @@ module VX_tensor_dpu #(
|
|||||||
) threadgroup_1 (
|
) threadgroup_1 (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.valid_in (valid_in && ready_in),
|
.valid_in (valid_in),
|
||||||
.stall (stall),
|
.stall (!ready_out),
|
||||||
.A_frag (A_tile[3:2]),
|
.A_frag (A_tile[3:2]),
|
||||||
.B_frag (B_tile),
|
.B_frag (B_tile),
|
||||||
.C_frag (C_tile[3:2]),
|
.C_frag (C_tile[3:2]),
|
||||||
@@ -94,18 +94,16 @@ module VX_tensor_dpu #(
|
|||||||
// fixed-latency queue
|
// fixed-latency queue
|
||||||
VX_shift_register #(
|
VX_shift_register #(
|
||||||
.DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
|
.DATAW (1 + $bits(wid)/* + $bits(D_tile)*/),
|
||||||
// .DEPTH (`LATENCY_HMMA),
|
.DEPTH (`LATENCY_HMMA),
|
||||||
.DEPTH (4),
|
|
||||||
.RESETW (1)
|
.RESETW (1)
|
||||||
) shift_reg (
|
) shift_reg (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.enable (~stall),
|
.enable (ready_out),
|
||||||
.data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
|
.data_in ({valid_in && ready_in, wid /*, result_hmma*/}),
|
||||||
.data_out ({valid_out, D_wid/*, D_tile */})
|
.data_out ({valid_out, D_wid/*, D_tile */})
|
||||||
);
|
);
|
||||||
|
|
||||||
// FIXME: breaks when stall is on!
|
|
||||||
`RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out),
|
`RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out),
|
||||||
("FEDP and metadata queue went out of sync!"))
|
("FEDP and metadata queue went out of sync!"))
|
||||||
endmodule
|
endmodule
|
||||||
@@ -146,7 +144,7 @@ module VX_tensor_threadgroup #(
|
|||||||
.io_in_bits_b_2 (32'h0),
|
.io_in_bits_b_2 (32'h0),
|
||||||
.io_in_bits_b_3 (32'h0),
|
.io_in_bits_b_3 (32'h0),
|
||||||
.io_in_bits_c (C_frag[D_row][D_col]),
|
.io_in_bits_c (C_frag[D_row][D_col]),
|
||||||
.io_stall (1'b0), // FIXME
|
.io_stall (stall),
|
||||||
.io_out_valid (valids[D_row][D_col]),
|
.io_out_valid (valids[D_row][D_col]),
|
||||||
.io_out_bits_data (D_frag[D_row][D_col])
|
.io_out_bits_data (D_frag[D_row][D_col])
|
||||||
);
|
);
|
||||||
@@ -154,8 +152,6 @@ module VX_tensor_threadgroup #(
|
|||||||
end
|
end
|
||||||
|
|
||||||
assign valid_out = (&(valids[0])) && (&(valids[1]));
|
assign valid_out = (&(valids[0])) && (&(valids[1]));
|
||||||
|
|
||||||
`RUNTIME_ASSERT(reset || !stall, ("stall not supported yet in tensor dpu!"))
|
|
||||||
endmodule
|
endmodule
|
||||||
|
|
||||||
`endif
|
`endif
|
||||||
|
|||||||
Reference in New Issue
Block a user