From 864265bda5ee5115d0de15939ea59ba92145295b Mon Sep 17 00:00:00 2001 From: Hansung Kim Date: Sat, 25 May 2024 20:04:31 -0700 Subject: [PATCH] tensor: Fix consecutive commits to write to same warp ... by splitting the pending_uops queue across warps. --- hw/rtl/core/VX_tensor_core.sv | 76 ++++++++++++++++++++++------------- hw/rtl/fpu/VX_tensor_dpu.sv | 10 +++-- 2 files changed, 53 insertions(+), 33 deletions(-) diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index b6b11754..d1ee3b38 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -32,10 +32,6 @@ module VX_tensor_core import VX_gpu_pkg::*; #( .execute_if (execute_if) ); - // FIXME: when multiple warps are running, step0_0 from multiple warps can - // get interleaved before the first warp advances to step0_1, fucking - // everything up - VX_commit_if #( .NUM_LANES (NUM_LANES) ) commit_block_if[BLOCK_SIZE](); @@ -83,7 +79,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); wire [1:0] step = 2'(execute_if.data.op_type); - wire operands_last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); + wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); logic [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_ready; @@ -91,6 +87,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // FIXME: should be NUM_LANES? 
logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; + wire [`NW_WIDTH-1:0] wb_wid; assign execute_if.ready = &octet_operands_ready; @@ -127,12 +124,13 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( .C_in(octet_C), .operands_valid(execute_if.valid), .operands_wid(execute_if.data.wid), - .operands_last_in_pair(operands_last_in_pair), + .operands_last_in_pair(last_in_pair), .operands_ready(octet_operands_ready[i]), .step(step), .D_out(octet_D), + .D_wid(wb_wid), .result_valid(result_valid), .result_ready(result_ready) ); @@ -188,33 +186,49 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // pid/sop/eop set later }; - wire [DATAW-1:0] execute_if_data_deq; + wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq; - // this is probably a little oversized - VX_fifo_queue #( - .DATAW(DATAW), - .DEPTH(16) - ) pending_uops ( - .clk(clk), - .reset(reset), - .push(execute_if_fire), - .pop(commit_if_fire), - .data_in(execute_if_data_enq), - .data_out(execute_if_data_deq), - `UNUSED_PIN(empty), - `UNUSED_PIN(alm_empty), - `UNUSED_PIN(full), // should be impossible to overflow - `UNUSED_PIN(alm_full), - `UNUSED_PIN(size) - ); + for (genvar i = 0; i < `NUM_WARPS; i++) begin + wire enq = execute_if_fire && (execute_if.data.wid == i); + wire deq = commit_if_fire && (wb_wid == i); + logic full; + // execute_if request queue. + // This has to be separated per-warp, as otherwise requests from + // multiple warps can be interleaved in the queue, which makes it hard to + // ensure two consecutive dequeues are associated with the same warp for + // commit. 
+ VX_fifo_queue #( + .DATAW(DATAW), + .DEPTH(4 /* FIXME: arbitrary */) + ) pending_uops ( + .clk(clk), + .reset(reset), + .push(enq), + .pop(deq), + .data_in(execute_if_data_enq), + .data_out(execute_if_data_deq[i]), + `UNUSED_PIN(empty), + `UNUSED_PIN(alm_empty), + .full(full), // should be impossible to overflow + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + `RUNTIME_ASSERT(!full, ("tensor core uop queue is full!")); + end + + // Unlike execute, which can be interleaved between warps, commit is + // serialized and completed one warp at a time, so we only need to + // keep one subcommit state bit (unlike for `substeps`). logic subcommit, subcommit_n; + wire all_valid = (& octet_results_valid); assign commit_if.valid = all_valid; localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; wire [COMMIT_DATAW-1:0] commit_if_data = { - execute_if_data_deq, /* uuid ~ rd */ + execute_if_data_deq[wb_wid], /* uuid ~ rd */ subcommit == 1'b0 ? 
wb_data_0 : wb_data_1, /* data */ 1'b0, /* pid */ 1'b1, /* sop */ @@ -263,6 +277,7 @@ module VX_tensor_octet #( input [1:0] step, output [3:0][3:0][31:0] D_out, + output [`NW_WIDTH-1:0] D_wid, output result_valid, input result_ready ); @@ -380,6 +395,7 @@ module VX_tensor_octet #( // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; logic [3:0][3:0][31:0] D_tile; + logic [`NW_WIDTH-1:0] D_warp_id; always @(*) begin C_tile[3] = { C_half[7], C_buffer[operands_wid][7], C_half[5], C_buffer[operands_wid][5] }; @@ -407,23 +423,25 @@ module VX_tensor_octet #( .A_tile(A_tile), .B_tile(B_tile), .C_tile(C_tile), + .warp_id(operands_wid), .valid_out(dpu_valid), - .D_tile(D_tile) + .D_tile(D_tile), + .D_warp_id(D_warp_id) ); // buffer to stage the result tile for 2 cycles until commit/writeback is // complete VX_stream_buffer #( - .DATAW ($bits(D_out)), + .DATAW ($bits(D_wid) + $bits(D_out)), .OUT_REG (1) // not sure this is necessary ) output_buffer ( .clk (clk), .reset (reset), .valid_in (dpu_valid), .ready_in (outbuf_ready_in), - .data_in (D_tile), - .data_out (D_out), + .data_in ({D_warp_id, D_tile}), + .data_out ({D_wid, D_out}), .ready_out (result_ready), .valid_out (result_valid) ); diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 4130fb98..1ffbb6d3 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -15,9 +15,11 @@ module VX_tensor_dpu #( input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, input [3:0][3:0][31:0] C_tile, + input [`NW_WIDTH-1:0] warp_id, output valid_out, - output [3:0][3:0][31:0] D_tile + output [3:0][3:0][31:0] D_tile, + output [`NW_WIDTH-1:0] D_warp_id ); logic [3:0][3:0][31:0] result_hmma; @@ -42,15 +44,15 @@ module VX_tensor_dpu #( // fixed-latency model VX_shift_register #( - .DATAW (1 + $bits(D_tile)), + .DATAW (1 + $bits(warp_id) + $bits(D_tile)), .DEPTH (`LATENCY_HMMA), .RESETW (1) ) shift_reg ( .clk (clk), .reset (reset), .enable (~stall), - .data_in ({valid_in, result_hmma}), - 
.data_out ({valid_out, D_tile}) + .data_in ({valid_in, warp_id, result_hmma}), + .data_out ({valid_out, D_warp_id, D_tile}) ); endmodule `endif