tensor: Issue queue for dpu to improve utilization
This commit is contained in:
@@ -125,10 +125,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
.operands_valid(execute_if.valid),
|
.operands_valid(execute_if.valid),
|
||||||
.operands_wid(execute_if.data.wid),
|
.operands_wid(execute_if.data.wid),
|
||||||
.operands_last_in_pair(last_in_pair),
|
.operands_last_in_pair(last_in_pair),
|
||||||
|
.operands_step(step),
|
||||||
.operands_ready(octet_operands_ready[i]),
|
.operands_ready(octet_operands_ready[i]),
|
||||||
|
|
||||||
.step(step),
|
|
||||||
|
|
||||||
.D_out(octet_D),
|
.D_out(octet_D),
|
||||||
.D_wid(wb_wid),
|
.D_wid(wb_wid),
|
||||||
.result_valid(result_valid),
|
.result_valid(result_valid),
|
||||||
@@ -186,18 +185,38 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
// pid/sop/eop set later
|
// pid/sop/eop set later
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// wire [DATAW-1:0] execute_if_data_deq;
|
||||||
|
|
||||||
|
// VX_fifo_queue #(
|
||||||
|
// .DATAW(DATAW),
|
||||||
|
// .DEPTH(4 /* FIXME: arbitrary */)
|
||||||
|
// ) pending_uops (
|
||||||
|
// .clk(clk),
|
||||||
|
// .reset(reset),
|
||||||
|
// .push(execute_if_fire),
|
||||||
|
// .pop(commit_if_fire),
|
||||||
|
// .data_in(execute_if_data_enq),
|
||||||
|
// .data_out(execute_if_data_deq),
|
||||||
|
// `UNUSED_PIN(empty),
|
||||||
|
// `UNUSED_PIN(alm_empty),
|
||||||
|
// `UNUSED_PIN(full), // should be impossible to overflow
|
||||||
|
// `UNUSED_PIN(alm_full),
|
||||||
|
// `UNUSED_PIN(size)
|
||||||
|
// );
|
||||||
|
|
||||||
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
|
wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
|
||||||
|
|
||||||
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
for (genvar i = 0; i < `NUM_WARPS; i++) begin
|
||||||
wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
|
||||||
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
|
|
||||||
logic full;
|
|
||||||
|
|
||||||
// execute_if request queue.
|
// execute_if request queue.
|
||||||
// This has to be separated per-warp, as otherwise requests from
|
// This has to be separated per-warp, as otherwise requests from
|
||||||
// multiple warps can be enqueued interleaved, which makes it hard to
|
// multiple warps can be enqueued interleaved, which makes it hard to
|
||||||
// ensure two consecutive dequeues are associated to the same warp for
|
// ensure two consecutive dequeues are associated with the same warp for
|
||||||
// commit.
|
// commit.
|
||||||
|
|
||||||
|
wire enq = execute_if_fire && (execute_if.data.wid == `NW_WIDTH'(i));
|
||||||
|
wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i));
|
||||||
|
wire full;
|
||||||
|
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW(DATAW),
|
.DATAW(DATAW),
|
||||||
.DEPTH(4 /* FIXME: arbitrary */)
|
.DEPTH(4 /* FIXME: arbitrary */)
|
||||||
@@ -215,7 +234,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
`UNUSED_PIN(size)
|
`UNUSED_PIN(size)
|
||||||
);
|
);
|
||||||
|
|
||||||
`RUNTIME_ASSERT(!full, ("tensor core uop queue is full!"));
|
`RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!"));
|
||||||
end
|
end
|
||||||
|
|
||||||
// unlike execute which can be interleaved between warps, commit is
|
// unlike execute which can be interleaved between warps, commit is
|
||||||
@@ -229,6 +248,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
|
|||||||
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
|
||||||
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
wire [COMMIT_DATAW-1:0] commit_if_data = {
|
||||||
execute_if_data_deq[wb_wid], /* uuid ~ rd */
|
execute_if_data_deq[wb_wid], /* uuid ~ rd */
|
||||||
|
// execute_if_data_deq, /* uuid ~ rd */
|
||||||
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
|
||||||
1'b0, /* pid */
|
1'b0, /* pid */
|
||||||
1'b1, /* sop */
|
1'b1, /* sop */
|
||||||
@@ -271,11 +291,10 @@ module VX_tensor_octet #(
|
|||||||
input operands_valid,
|
input operands_valid,
|
||||||
input [`NW_WIDTH-1:0] operands_wid,
|
input [`NW_WIDTH-1:0] operands_wid,
|
||||||
input operands_last_in_pair,
|
input operands_last_in_pair,
|
||||||
|
input [1:0] operands_step,
|
||||||
// we have to backpressure due to there potentially being contention over commit
|
// we have to backpressure due to there potentially being contention over commit
|
||||||
output operands_ready,
|
output operands_ready,
|
||||||
|
|
||||||
input [1:0] step,
|
|
||||||
|
|
||||||
output [3:0][3:0][31:0] D_out,
|
output [3:0][3:0][31:0] D_out,
|
||||||
output [`NW_WIDTH-1:0] D_wid,
|
output [`NW_WIDTH-1:0] D_wid,
|
||||||
output result_valid,
|
output result_valid,
|
||||||
@@ -292,11 +311,73 @@ module VX_tensor_octet #(
|
|||||||
logic [3:0][31:0] A_half;
|
logic [3:0][31:0] A_half;
|
||||||
logic [3:0][31:0] B_half;
|
logic [3:0][31:0] B_half;
|
||||||
logic [7:0][31:0] C_half;
|
logic [7:0][31:0] C_half;
|
||||||
|
logic [3:0][31:0] A_half_buf;
|
||||||
|
logic [3:0][31:0] B_half_buf;
|
||||||
|
logic [7:0][31:0] C_half_buf;
|
||||||
|
|
||||||
|
|
||||||
logic [`NUM_WARPS-1:0] substeps;
|
logic [`NUM_WARPS-1:0] substeps;
|
||||||
logic [`NUM_WARPS-1:0] substeps_n;
|
logic [`NUM_WARPS-1:0] substeps_n;
|
||||||
|
|
||||||
always @(*) begin
|
wire [7:0][31:0] A_in_buf;
|
||||||
|
wire [7:0][31:0] B_in_buf;
|
||||||
|
wire [7:0][31:0] C_in_buf;
|
||||||
|
wire operands_valid_buf;
|
||||||
|
wire operands_ready_buf;
|
||||||
|
wire [`NW_WIDTH-1:0] operands_wid_buf;
|
||||||
|
wire operands_last_in_pair_buf;
|
||||||
|
wire [1:0] operands_step_buf;
|
||||||
|
|
||||||
|
wire inbuf_empty;
|
||||||
|
wire inbuf_full;
|
||||||
|
wire inbuf_ready_in;
|
||||||
|
assign inbuf_ready_in = !inbuf_full;
|
||||||
|
assign operands_ready = inbuf_ready_in;
|
||||||
|
assign operands_valid_buf = !inbuf_empty;
|
||||||
|
|
||||||
|
wire inbuf_enq = operands_ready && operands_valid && operands_last_in_pair;
|
||||||
|
wire inbuf_deq = operands_valid_buf && operands_ready_buf;
|
||||||
|
|
||||||
|
// the 'issue queue' for the dpu.
|
||||||
|
// This exists to decouple the input of the dot-product unit from
|
||||||
|
// execute_if.ready. execute_if can arrive intermittently according to
|
||||||
|
// the frontend's behavior, and since the dpu can also stall for a fixed
|
||||||
|
// initiation latency, we need to decouple the two to efficiently feed the
|
||||||
|
// dpu.
|
||||||
|
// This only applies to the last instruction in a pair, since the first
|
||||||
|
// instruction only acts to buffer the operands and can execute
|
||||||
|
// immediately without backpressure. So we don't enqueue them.
|
||||||
|
VX_fifo_queue #(
|
||||||
|
.DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) +
|
||||||
|
$bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)),
|
||||||
|
.DEPTH (4 /* FIXME: arbitrary */)
|
||||||
|
) input_buffer (
|
||||||
|
.clk (clk),
|
||||||
|
.reset (reset),
|
||||||
|
.push (inbuf_enq),
|
||||||
|
.pop (inbuf_deq),
|
||||||
|
.data_in ({A_in, B_in, C_in, operands_wid, operands_step, operands_last_in_pair}),
|
||||||
|
.data_out ({A_in_buf, B_in_buf, C_in_buf, operands_wid_buf, operands_step_buf, operands_last_in_pair_buf}),
|
||||||
|
.empty (inbuf_empty),
|
||||||
|
`UNUSED_PIN(alm_empty),
|
||||||
|
.full (inbuf_full),
|
||||||
|
`UNUSED_PIN(alm_full),
|
||||||
|
`UNUSED_PIN(size)
|
||||||
|
);
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
logic [3:0][31:0] A_half;
|
||||||
|
logic [3:0][31:0] B_half;
|
||||||
|
logic [7:0][31:0] C_half;
|
||||||
|
} half_t;
|
||||||
|
|
||||||
|
function half_t get_operand_half(
|
||||||
|
input logic [1:0] step,
|
||||||
|
input logic [7:0][31:0] A_in,
|
||||||
|
input logic [7:0][31:0] B_in,
|
||||||
|
input logic [7:0][31:0] C_in
|
||||||
|
);
|
||||||
|
half_t half;
|
||||||
// note that not all lanes participate at every step
|
// note that not all lanes participate at every step
|
||||||
case (step)
|
case (step)
|
||||||
2'b00: begin
|
2'b00: begin
|
||||||
@@ -304,28 +385,34 @@ module VX_tensor_octet #(
|
|||||||
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
|
// by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of
|
||||||
// Figure 10(b). B_in OTOH is shared by two threadgroups.
|
// Figure 10(b). B_in OTOH is shared by two threadgroups.
|
||||||
// Note k-dimension is shrunk from 4 to 2.
|
// Note k-dimension is shrunk from 4 to 2.
|
||||||
A_half = { A_in[5:4], A_in[1:0] };
|
half.A_half = { A_in[5:4], A_in[1:0] };
|
||||||
B_half = B_in[3:0];
|
half.B_half = B_in[3:0];
|
||||||
end
|
end
|
||||||
2'b01: begin
|
2'b01: begin
|
||||||
A_half = { A_in[7:6], A_in[3:2] };
|
half.A_half = { A_in[7:6], A_in[3:2] };
|
||||||
B_half = B_in[3:0];
|
half.B_half = B_in[3:0];
|
||||||
end
|
end
|
||||||
2'b10: begin
|
2'b10: begin
|
||||||
A_half = { A_in[5:4], A_in[1:0] };
|
half.A_half = { A_in[5:4], A_in[1:0] };
|
||||||
B_half = B_in[7:4];
|
half.B_half = B_in[7:4];
|
||||||
end
|
end
|
||||||
2'b11: begin
|
2'b11: begin
|
||||||
A_half = { A_in[7:6], A_in[3:2] };
|
half.A_half = { A_in[7:6], A_in[3:2] };
|
||||||
B_half = B_in[7:4];
|
half.B_half = B_in[7:4];
|
||||||
end
|
end
|
||||||
endcase
|
endcase
|
||||||
C_half = C_in;
|
half.C_half = C_in;
|
||||||
end
|
return half;
|
||||||
|
endfunction
|
||||||
|
|
||||||
logic substep;
|
half_t halves;
|
||||||
wire operands_fire = operands_ready && operands_valid;
|
half_t halves_buf;
|
||||||
wire substep_n = operands_fire && operands_last_in_pair;
|
assign halves = get_operand_half(operands_step, A_in, B_in, C_in);
|
||||||
|
assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf);
|
||||||
|
|
||||||
|
wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf;
|
||||||
|
wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair);
|
||||||
|
// wire operands_first_in_pair_fire = operands_ready && operands_valid;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
A_buffer_n = A_buffer;
|
A_buffer_n = A_buffer;
|
||||||
@@ -333,20 +420,15 @@ module VX_tensor_octet #(
|
|||||||
C_buffer_n = C_buffer;
|
C_buffer_n = C_buffer;
|
||||||
substeps_n = substeps;
|
substeps_n = substeps;
|
||||||
|
|
||||||
if (operands_fire) begin
|
if (operands_first_in_pair_fire) begin
|
||||||
substeps_n[operands_wid] = ~substeps[operands_wid];
|
substeps_n[operands_wid] = 1'b1; // ready for hmma
|
||||||
if (!operands_last_in_pair) begin
|
A_buffer_n[operands_wid] = halves.A_half;
|
||||||
A_buffer_n[operands_wid] = A_half;
|
B_buffer_n[operands_wid] = halves.B_half;
|
||||||
B_buffer_n[operands_wid] = B_half;
|
C_buffer_n[operands_wid] = halves.C_half;
|
||||||
C_buffer_n[operands_wid] = C_half;
|
end
|
||||||
end
|
if (do_hmma) begin
|
||||||
|
substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand
|
||||||
end
|
end
|
||||||
|
|
||||||
// if (operands_fire && (substep == 1'b0)) begin
|
|
||||||
// A_buffer_n[operands_wid] = A_half;
|
|
||||||
// B_buffer_n[operands_wid] = B_half;
|
|
||||||
// C_buffer_n[operands_wid] = C_half;
|
|
||||||
// end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
always @(posedge clk) begin
|
always @(posedge clk) begin
|
||||||
@@ -354,43 +436,39 @@ module VX_tensor_octet #(
|
|||||||
A_buffer <= '0;
|
A_buffer <= '0;
|
||||||
B_buffer <= '0;
|
B_buffer <= '0;
|
||||||
C_buffer <= '0;
|
C_buffer <= '0;
|
||||||
|
|
||||||
substep <= '0;
|
|
||||||
substeps <= '0;
|
substeps <= '0;
|
||||||
end
|
end
|
||||||
else begin
|
else begin
|
||||||
A_buffer <= A_buffer_n;
|
A_buffer <= A_buffer_n;
|
||||||
B_buffer <= B_buffer_n;
|
B_buffer <= B_buffer_n;
|
||||||
C_buffer <= C_buffer_n;
|
C_buffer <= C_buffer_n;
|
||||||
|
|
||||||
substep <= substep_n;
|
|
||||||
substeps <= substeps_n;
|
substeps <= substeps_n;
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
wire hmma_ready;
|
|
||||||
wire outbuf_ready_in;
|
wire outbuf_ready_in;
|
||||||
// wire stall = result_valid && ~result_ready;
|
|
||||||
// backpressure from commit
|
// backpressure from commit
|
||||||
wire stall = ~outbuf_ready_in;
|
wire stall = ~outbuf_ready_in;
|
||||||
|
wire hmma_ready;
|
||||||
|
|
||||||
// assign operands_ready = ~stall;
|
// assign operands_ready = ~stall;
|
||||||
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
|
// TODO: Below line is to only allow 1 warp to occupy the octet at a time;
|
||||||
// currently, dpu is fully-pipelined and allows concurrency between
|
// currently, dpu is fully-pipelined and allows concurrency between
|
||||||
// multiple warps. This seems to be not a problem though given that the
|
// multiple warps. This seems to be not a problem though given that the
|
||||||
// RF operand read takes >=2 cycles, which should be the end-to-end
|
// RF operand read takes >=2 cycles, which should be the end-to-end
|
||||||
// latency of the DPU anyways
|
// latency of the DPU anyways
|
||||||
assign operands_ready = hmma_ready && ~stall;
|
assign operands_ready_buf = hmma_ready && ~stall;
|
||||||
|
|
||||||
// A is 4x2 fp32 matrix
|
// A is 4x2 fp32 matrix
|
||||||
wire [3:0][1:0][31:0] A_tile = {
|
wire [3:0][1:0][31:0] A_tile = {
|
||||||
{ A_half[3], A_buffer[operands_wid][3] },
|
{ halves_buf.A_half[3], A_buffer[operands_wid_buf][3] },
|
||||||
{ A_half[2], A_buffer[operands_wid][2] },
|
{ halves_buf.A_half[2], A_buffer[operands_wid_buf][2] },
|
||||||
{ A_half[1], A_buffer[operands_wid][1] },
|
{ halves_buf.A_half[1], A_buffer[operands_wid_buf][1] },
|
||||||
{ A_half[0], A_buffer[operands_wid][0] }
|
{ halves_buf.A_half[0], A_buffer[operands_wid_buf][0] }
|
||||||
};
|
};
|
||||||
// B is 2x4 fp32 matrix
|
// B is 2x4 fp32 matrix
|
||||||
wire [1:0][3:0][31:0] B_tile = {
|
wire [1:0][3:0][31:0] B_tile = {
|
||||||
B_half, B_buffer[operands_wid]
|
halves_buf.B_half, B_buffer[operands_wid_buf]
|
||||||
};
|
};
|
||||||
// C is 4x4 fp32 matrix
|
// C is 4x4 fp32 matrix
|
||||||
logic [3:0][3:0][31:0] C_tile;
|
logic [3:0][3:0][31:0] C_tile;
|
||||||
@@ -398,14 +476,12 @@ module VX_tensor_octet #(
|
|||||||
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
logic [`NW_WIDTH-1:0] D_wid_dpu;
|
||||||
|
|
||||||
always @(*) begin
|
always @(*) begin
|
||||||
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
|
C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] };
|
||||||
C_tile[2] = { C_half[6], C_buffer[operands_wid][6], C_half[4], C_buffer[operands_wid][4] };
|
C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] };
|
||||||
C_tile[1] = { C_half[3], C_buffer[operands_wid][3], C_half[1], C_buffer[operands_wid][1] };
|
C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] };
|
||||||
C_tile[0] = { C_half[2], C_buffer[operands_wid][2], C_half[0], C_buffer[operands_wid][0] };
|
C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] };
|
||||||
end
|
end
|
||||||
|
|
||||||
// wire do_hmma = operands_fire && (substeps[operands_wid] == 1'b1);
|
|
||||||
wire do_hmma = operands_fire && operands_last_in_pair;
|
|
||||||
wire dpu_valid;
|
wire dpu_valid;
|
||||||
|
|
||||||
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
// this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet
|
||||||
@@ -423,7 +499,7 @@ module VX_tensor_octet #(
|
|||||||
.A_tile(A_tile),
|
.A_tile(A_tile),
|
||||||
.B_tile(B_tile),
|
.B_tile(B_tile),
|
||||||
.C_tile(C_tile),
|
.C_tile(C_tile),
|
||||||
.wid(operands_wid),
|
.wid(operands_wid_buf),
|
||||||
|
|
||||||
.valid_out(dpu_valid),
|
.valid_out(dpu_valid),
|
||||||
.D_tile(D_tile),
|
.D_tile(D_tile),
|
||||||
@@ -438,14 +514,14 @@ module VX_tensor_octet #(
|
|||||||
wire outbuf_enq = outbuf_ready_in && dpu_valid;
|
wire outbuf_enq = outbuf_ready_in && dpu_valid;
|
||||||
wire outbuf_deq = result_valid && result_ready;
|
wire outbuf_deq = result_valid && result_ready;
|
||||||
|
|
||||||
// buffer to stage the result tile for 2 cycles until commit/writeback is
|
// buffer to stage the result D tile for 2 cycles until commit/writeback
|
||||||
// complete. This decouples the irregular dpu output traffic from the
|
// is complete. This decouples the irregular dpu output traffic from the
|
||||||
// regular, every-2-cycle commit traffic and thereby ensures the commit
|
// regular, every-2-cycle commit traffic to ensure the commit pipeline is
|
||||||
// pipeline is used more efficiently.
|
// used more efficiently.
|
||||||
// TODO: This is probably oversized.
|
// TODO: This is probably oversized.
|
||||||
VX_fifo_queue #(
|
VX_fifo_queue #(
|
||||||
.DATAW ($bits(D_wid) + $bits(D_out)),
|
.DATAW ($bits(D_wid) + $bits(D_out)),
|
||||||
.DEPTH (8 /* FIXME: arbitrary */)
|
.DEPTH (4 /* FIXME: arbitrary */)
|
||||||
) output_buffer (
|
) output_buffer (
|
||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ module VX_tensor_dpu #(
|
|||||||
.clk (clk),
|
.clk (clk),
|
||||||
.reset (reset),
|
.reset (reset),
|
||||||
.enable (~stall),
|
.enable (~stall),
|
||||||
.data_in ({valid_in, wid, result_hmma}),
|
.data_in ({valid_in && ready_in, wid, result_hmma}),
|
||||||
.data_out ({valid_out, D_wid, D_tile})
|
.data_out ({valid_out, D_wid, D_tile})
|
||||||
);
|
);
|
||||||
endmodule
|
endmodule
|
||||||
|
|||||||
Reference in New Issue
Block a user