tensor: Fix consecutive commits to write to same warp

... by splitting the pending_uops queue across warps.
This commit is contained in:
Hansung Kim
2024-05-25 20:04:31 -07:00
parent 5a95eba1f5
commit 864265bda5
2 changed files with 53 additions and 33 deletions

View File

@@ -32,10 +32,6 @@ module VX_tensor_core import VX_gpu_pkg::*; #(
.execute_if (execute_if) .execute_if (execute_if)
); );
// FIXME: when multiple warps are running, step0_0 from multiple warps can
// get interleaved before the first warp advances to step0_1, fucking
// everything up
VX_commit_if #( VX_commit_if #(
.NUM_LANES (NUM_LANES) .NUM_LANES (NUM_LANES)
) commit_block_if[BLOCK_SIZE](); ) commit_block_if[BLOCK_SIZE]();
@@ -83,7 +79,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS);
wire [1:0] step = 2'(execute_if.data.op_type); wire [1:0] step = 2'(execute_if.data.op_type);
wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1));
logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_valid;
logic [NUM_OCTETS-1:0] octet_results_ready; logic [NUM_OCTETS-1:0] octet_results_ready;
@@ -91,6 +87,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// FIXME: should be NUM_LANES? // FIXME: should be NUM_LANES?
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0;
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1;
wire [`NW_WIDTH-1:0] wb_wid;
assign execute_if.ready = &octet_operands_ready; assign execute_if.ready = &octet_operands_ready;
@@ -127,12 +124,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
.C_in(octet_C), .C_in(octet_C),
.operands_valid(execute_if.valid), .operands_valid(execute_if.valid),
.operands_wid(execute_if.data.wid), .operands_wid(execute_if.data.wid),
.operands_last_in_pair(operands_last_in_pair), .operands_last_in_pair(last_in_pair),
.operands_ready(octet_operands_ready[i]), .operands_ready(octet_operands_ready[i]),
.step(step), .step(step),
.D_out(octet_D), .D_out(octet_D),
.D_wid(wb_wid),
.result_valid(result_valid), .result_valid(result_valid),
.result_ready(result_ready) .result_ready(result_ready)
); );
@@ -188,33 +186,49 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #(
// pid/sop/eop set later // pid/sop/eop set later
}; };
wire [DATAW-1:0] execute_if_data_deq; wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq;
// this is probably a little oversized for (genvar i = 0; i < `NUM_WARPS; i++) begin
VX_fifo_queue #( wire enq = execute_if_fire && (execute_if.data.wid == i);
.DATAW(DATAW), wire deq = commit_if_fire && (wb_wid == i);
.DEPTH(16) logic full;
) pending_uops (
.clk(clk),
.reset(reset),
.push(execute_if_fire),
.pop(commit_if_fire),
.data_in(execute_if_data_enq),
.data_out(execute_if_data_deq),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
`UNUSED_PIN(full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
// execute_if request queue.
// This has to be separated per-warp, as otherwise requests from
// multiple warps can be enqueued interleaved, which makes it hard to
// ensure two consecutive dequeues are associated to the same warp for
// commit.
VX_fifo_queue #(
.DATAW(DATAW),
.DEPTH(4 /* FIXME: arbitrary */)
) pending_uops (
.clk(clk),
.reset(reset),
.push(enq),
.pop(deq),
.data_in(execute_if_data_enq),
.data_out(execute_if_data_deq[i]),
`UNUSED_PIN(empty),
`UNUSED_PIN(alm_empty),
.full(full), // should be impossible to overflow
`UNUSED_PIN(alm_full),
`UNUSED_PIN(size)
);
`RUNTIME_ASSERT(!full, ("tensor core uop queue is full!"));
end
// unlike execute which can be interleaved between warps, commit is
// serialized and completed one-warp-by-warp, therefore we only need to
// keep one subcommit state bit unlike for `substeps`
logic subcommit, subcommit_n; logic subcommit, subcommit_n;
wire all_valid = (& octet_results_valid); wire all_valid = (& octet_results_valid);
assign commit_if.valid = all_valid; assign commit_if.valid = all_valid;
localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1;
wire [COMMIT_DATAW-1:0] commit_if_data = { wire [COMMIT_DATAW-1:0] commit_if_data = {
execute_if_data_deq, /* uuid ~ rd */ execute_if_data_deq[wb_wid], /* uuid ~ rd */
subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */
1'b0, /* pid */ 1'b0, /* pid */
1'b1, /* sop */ 1'b1, /* sop */
@@ -263,6 +277,7 @@ module VX_tensor_octet #(
input [1:0] step, input [1:0] step,
output [3:0][3:0][31:0] D_out, output [3:0][3:0][31:0] D_out,
output [`NW_WIDTH-1:0] D_wid,
output result_valid, output result_valid,
input result_ready input result_ready
); );
@@ -380,6 +395,7 @@ module VX_tensor_octet #(
// C is 4x4 fp32 matrix // C is 4x4 fp32 matrix
logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] C_tile;
logic [3:0][3:0][31:0] D_tile; logic [3:0][3:0][31:0] D_tile;
logic [`NW_WIDTH-1:0] D_warp_id;
always @(*) begin always @(*) begin
C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] }; C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] };
@@ -407,23 +423,25 @@ module VX_tensor_octet #(
.A_tile(A_tile), .A_tile(A_tile),
.B_tile(B_tile), .B_tile(B_tile),
.C_tile(C_tile), .C_tile(C_tile),
.warp_id(operands_wid),
.valid_out(dpu_valid), .valid_out(dpu_valid),
.D_tile(D_tile) .D_tile(D_tile),
.D_warp_id(D_warp_id)
); );
// buffer to stage the result tile for 2 cycles until commit/writeback is // buffer to stage the result tile for 2 cycles until commit/writeback is
// complete // complete
VX_stream_buffer #( VX_stream_buffer #(
.DATAW ($bits(D_out)), .DATAW ($bits(D_wid) + $bits(D_out)),
.OUT_REG (1) // not sure this is necessary .OUT_REG (1) // not sure this is necessary
) output_buffer ( ) output_buffer (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.valid_in (dpu_valid), .valid_in (dpu_valid),
.ready_in (outbuf_ready_in), .ready_in (outbuf_ready_in),
.data_in (D_tile), .data_in ({D_warp_id, D_tile}),
.data_out (D_out), .data_out ({D_wid, D_out}),
.ready_out (result_ready), .ready_out (result_ready),
.valid_out (result_valid) .valid_out (result_valid)
); );

View File

@@ -15,9 +15,11 @@ module VX_tensor_dpu #(
input [3:0][1:0][31:0] A_tile, input [3:0][1:0][31:0] A_tile,
input [1:0][3:0][31:0] B_tile, input [1:0][3:0][31:0] B_tile,
input [3:0][3:0][31:0] C_tile, input [3:0][3:0][31:0] C_tile,
input [`NW_WIDTH-1:0] warp_id,
output valid_out, output valid_out,
output [3:0][3:0][31:0] D_tile output [3:0][3:0][31:0] D_tile,
output [`NW_WIDTH-1:0] D_warp_id
); );
logic [3:0][3:0][31:0] result_hmma; logic [3:0][3:0][31:0] result_hmma;
@@ -42,15 +44,15 @@ module VX_tensor_dpu #(
// fixed-latency model // fixed-latency model
VX_shift_register #( VX_shift_register #(
.DATAW (1 + $bits(D_tile)), .DATAW (1 + $bits(warp_id) + $bits(D_tile)),
.DEPTH (`LATENCY_HMMA), .DEPTH (`LATENCY_HMMA),
.RESETW (1) .RESETW (1)
) shift_reg ( ) shift_reg (
.clk (clk), .clk (clk),
.reset (reset), .reset (reset),
.enable (~stall), .enable (~stall),
.data_in ({valid_in, result_hmma}), .data_in ({valid_in, warp_id, result_hmma}),
.data_out ({valid_out, D_tile}) .data_out ({valid_out, D_warp_id, D_tile})
); );
endmodule endmodule
`endif `endif