diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 2fc54fc5..71b17e08 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -77,6 +77,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16 // FIXME: not sure this is the right logic. just filling in what works localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); + localparam REQ_QUEUE_DEPTH = 4; wire [1:0] step = 2'(execute_if.data.op_type); wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); @@ -219,7 +220,7 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( VX_fifo_queue #( .DATAW(DATAW), - .DEPTH(8 /* FIXME: arbitrary */) + .DEPTH(REQ_QUEUE_DEPTH) ) pending_uops ( .clk(clk), .reset(reset), @@ -234,6 +235,8 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( `UNUSED_PIN(size) ); + // this shouldn't really happen unless there's a big contention over + // the commit stage `RUNTIME_ASSERT(!(!reset && full), ("tensor core uop queue is full!")); end @@ -300,6 +303,8 @@ module VX_tensor_octet #( output result_valid, input result_ready ); + localparam ISSUE_QUEUE_DEPTH = 4; + // 512 bits/octet * 4 octets per warp logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n; logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n; @@ -351,7 +356,7 @@ module VX_tensor_octet #( VX_fifo_queue #( .DATAW ($bits(A_in) + $bits(B_in) + $bits(C_in) + $bits(operands_wid) + $bits(operands_step) + $bits(operands_last_in_pair)), - .DEPTH (8 /* FIXME: arbitrary */) + .DEPTH (ISSUE_QUEUE_DEPTH) ) input_buffer ( .clk (clk), .reset (reset), @@ -451,17 +456,8 @@ module VX_tensor_octet #( end wire outbuf_ready_in; - // backpressure from commit - wire stall = ~outbuf_ready_in; wire hmma_ready; - - // assign operands_ready = ~stall; - // TODO: Below line is to only allow 1 warp to occupy the octet at a time; - // currently, dpu is fully-pipelined and allows concurrency between - // multiple warps. 
This seems to be not a problem though given that the - // RF operand read takes >=2 cycles, which should be the end-to-end - // latency of the DPU anyways - assign operands_ready_buf = hmma_ready && ~stall; + assign operands_ready_buf = hmma_ready; // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { @@ -496,8 +492,6 @@ module VX_tensor_octet #( .clk(clk), .reset(reset), - .stall(stall), - .valid_in(do_hmma), .ready_in(hmma_ready), .A_tile(A_tile), @@ -506,12 +500,14 @@ module VX_tensor_octet #( .wid(operands_wid_buf), .valid_out(dpu_valid), + .ready_out(outbuf_ready_in), .D_tile(D_tile), .D_wid(D_wid_dpu) ); wire outbuf_empty; wire outbuf_full; + // backpressure from commit assign outbuf_ready_in = ~outbuf_full; assign result_valid = ~outbuf_empty; @@ -525,7 +521,10 @@ module VX_tensor_octet #( // TODO: This is probably oversized. VX_fifo_queue #( .DATAW ($bits(D_wid) + $bits(D_out)), - .DEPTH (8 /* FIXME: arbitrary */) + // The depth of this queue should ideally exceed the dpu pipeline + // latency: the dpu is fully pipelined, so it can emit as many results + // as its latency in cycles in one back-to-back burst. 
+ .DEPTH (`LATENCY_HMMA + `LATENCY_HMMA) ) output_buffer ( .clk (clk), .reset (reset), diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 90c2c7ed..51112c96 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -8,8 +8,6 @@ module VX_tensor_dpu #( input clk, input reset, - input stall, - input valid_in, output ready_in, input [3:0][1:0][31:0] A_tile, @@ -18,6 +16,7 @@ module VX_tensor_dpu #( input [`NW_WIDTH-1:0] wid, output valid_out, + input ready_out, output [3:0][3:0][31:0] D_tile, output [`NW_WIDTH-1:0] D_wid ); @@ -40,10 +39,11 @@ module VX_tensor_dpu #( end // ready as soon as valid_out - assign ready_in = ready_reg || valid_out; + // assign ready_in = ready_reg || valid_out; - // fully pipelined; always ready - // assign ready_in = 1'b1; + // fully pipelined; ready_in mirrors ready_out, so the pipeline stalls + // in the same cycle that the output side deasserts ready + assign ready_in = ready_out; // wire dpu_valid; // wire [31:0] dpu_data; @@ -70,8 +70,8 @@ module VX_tensor_dpu #( ) threadgroup_0 ( .clk (clk), .reset (reset), - .valid_in (valid_in && ready_in), - .stall (stall), + .valid_in (valid_in), + .stall (!ready_out), .A_frag (A_tile[1:0]), .B_frag (B_tile), .C_frag (C_tile[1:0]), @@ -82,8 +82,8 @@ module VX_tensor_dpu #( ) threadgroup_1 ( .clk (clk), .reset (reset), - .valid_in (valid_in && ready_in), - .stall (stall), + .valid_in (valid_in), + .stall (!ready_out), .A_frag (A_tile[3:2]), .B_frag (B_tile), .C_frag (C_tile[3:2]), @@ -94,18 +94,16 @@ module VX_tensor_dpu #( // fixed-latency queue VX_shift_register #( .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/), - // .DEPTH (`LATENCY_HMMA), - .DEPTH (4), + .DEPTH (`LATENCY_HMMA), .RESETW (1) ) shift_reg ( .clk (clk), .reset (reset), - .enable (~stall), + .enable (ready_out), .data_in ({valid_in && ready_in, wid /*, result_hmma*/}), .data_out ({valid_out, D_wid/*, D_tile */}) ); - // FIXME: breaks when stall is on! 
`RUNTIME_ASSERT(reset || (&(threadgroup_valids) == valid_out), ("FEDP and metadata queue went out of sync!")) endmodule @@ -146,7 +144,7 @@ module VX_tensor_threadgroup #( .io_in_bits_b_2 (32'h0), .io_in_bits_b_3 (32'h0), .io_in_bits_c (C_frag[D_row][D_col]), - .io_stall (1'b0), // FIXME + .io_stall (stall), .io_out_valid (valids[D_row][D_col]), .io_out_bits_data (D_frag[D_row][D_col]) ); @@ -154,8 +152,6 @@ module VX_tensor_threadgroup #( end assign valid_out = (&(valids[0])) && (&(valids[1])); - - `RUNTIME_ASSERT(reset || !stall, ("stall not supported yet in tensor dpu!")) endmodule `endif