diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index e1781e4c..a1f4c937 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -287,7 +287,7 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #( end VX_tensor_reg #( - .N(1) + .DATAW(1) ) staging_subcommit ( .clk(clk), .reset(reset), @@ -298,15 +298,15 @@ module VX_tensor_core_block import VX_gpu_pkg::*; #( endmodule module VX_tensor_reg #( - parameter N + parameter DATAW ) ( input clk, input reset, - input [N-1:0] d, + input [DATAW-1:0] d, input en, - output [N-1:0] q + output [DATAW-1:0] q ); - logic [N-1:0] data; + logic [DATAW-1:0] data; always @(posedge clk) begin if (reset) begin @@ -436,7 +436,7 @@ module VX_tensor_octet #( // Staging buffer for the A/B/C half-tiles that will later be assembled // with the other half tiles coming in on the input ports. VX_tensor_reg #( - .N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer)) + .DATAW($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer)) ) staging_abc ( .clk(clk), .reset(reset), @@ -464,7 +464,7 @@ module VX_tensor_octet #( end VX_tensor_reg #( - .N($bits(substeps)) + .DATAW($bits(substeps)) ) staging_substeps ( .clk(clk), .reset(reset), @@ -506,7 +506,7 @@ module VX_tensor_octet #( VX_tensor_dpu #( .ISW(ISW), .OCTET(OCTET), - .ISSUE_QUEUE_DEPTH(4 /*@perf: arbtirary*/) + .OPERAND_BUFFER_DEPTH(4 /*@perf: arbitrary*/) ) dpu ( .clk(clk), .reset(reset), diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index 417da2ab..aabe0105 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -7,7 +7,7 @@ module VX_tensor_dpu #( // @perf: has big impact on throughput. 
A rule of thumb is to set it to // the pipeline length of FEDPs in order to make sure there are enough // entries to fully saturate the pipeline, but this is still rough - parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA + parameter OPERAND_BUFFER_DEPTH = `LATENCY_HMMA ) ( input clk, input reset, @@ -51,65 +51,111 @@ module VX_tensor_dpu #( // stalling // assign ready_in = ready_out; - wire synced_fire; - assign synced_fire = valid_in && ready_in; + wire [3:0][1:0][31:0] A_tile_buf; + wire [1:0][3:0][31:0] B_tile_buf; + wire [3:0][3:0][31:0] C_tile_buf; - wire [1:0] threadgroup_valids; - wire [1:0] threadgroup_readys; - // B_tile is shared across the two threadgroups; see Figure 13 - VX_tensor_threadgroup #( - .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) - ) threadgroup_0 ( - .clk (clk), - .reset (reset), - .valid_in (synced_fire), - .ready_in (threadgroup_readys[0]), - .stall (!ready_out), - .A_frag (A_tile[1:0]), - .B_frag (B_tile), - .C_frag (C_tile[1:0]), - .valid_out (threadgroup_valids[0]), - .D_frag (D_tile[1:0]) - ); - VX_tensor_threadgroup #( - .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) - ) threadgroup_1 ( - .clk (clk), - .reset (reset), - .valid_in (synced_fire), - .ready_in (threadgroup_readys[1]), - .stall (!ready_out), - .A_frag (A_tile[3:2]), - .B_frag (B_tile), - .C_frag (C_tile[3:2]), - .valid_out (threadgroup_valids[1]), - .D_frag (D_tile[3:2]) - ); + wire wid_empty; + wire wid_full; wire empty; wire full; - wire enq = valid_in && ready_in; - wire deq = valid_out && ready_out; + // sync between operand buffer and wid buffer + assign ready_in = !full && !wid_full; - assign ready_in = &(threadgroup_readys) && !full; - assign valid_out = &(threadgroup_valids); + wire [1:0] threadgroup_valids_out; + wire [1:0] threadgroup_readys_in; + // sync operand queue and wid queue + wire threadgroup_valid_in = !empty; + wire threadgroup_fire_in = threadgroup_valid_in && &(threadgroup_readys_in); + + wire enq = valid_in && ready_in; + wire deq = threadgroup_fire_in; + + // Operand 
buffer for the dot product units. + // + // This exists to decouple the execution of the dot-product unit from + // the operand arrival. Operands from the upstream execute_if can arrive + // intermittently depending on the frontend's behavior, whereas downstream + // writeback happens at a regular cadence. Therefore to achieve full + // throughput of the dpu, we need to decouple the operand arrival from the + // direct input to the dpu. + VX_fifo_queue #( + .DATAW ($bits(A_tile) + $bits(B_tile) + $bits(C_tile)), + .DEPTH (OPERAND_BUFFER_DEPTH) + ) operand_buffer ( + .clk (clk), + .reset (reset), + .push (enq), + .pop (deq), + .data_in ({A_tile, B_tile, C_tile}), + .data_out ({A_tile_buf, B_tile_buf, C_tile_buf}), + .empty (empty), + `UNUSED_PIN(alm_empty), + .full (full), + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + // Split A_tile and C_tile by rows (0-1, 2-3) and parallelize in two + // threadgroups + // + // B_tile is shared across the two threadgroups; see Figure 13 + VX_tensor_threadgroup #( + .OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH) + ) threadgroup_0 ( + .clk (clk), + .reset (reset), + .valid_in (threadgroup_valid_in), + .ready_in (threadgroup_readys_in[0]), + .stall (!ready_out), + .A_frag (A_tile_buf[1:0]), + .B_frag (B_tile_buf), + .C_frag (C_tile_buf[1:0]), + .valid_out (threadgroup_valids_out[0]), + .D_frag (D_tile[1:0]) + ); + VX_tensor_threadgroup #( + .OPERAND_BUFFER_DEPTH(OPERAND_BUFFER_DEPTH) + ) threadgroup_1 ( + .clk (clk), + .reset (reset), + .valid_in (threadgroup_valid_in), + .ready_in (threadgroup_readys_in[1]), + .stall (!ready_out), + .A_frag (A_tile_buf[3:2]), + .B_frag (B_tile_buf), + .C_frag (C_tile_buf[3:2]), + .valid_out (threadgroup_valids_out[1]), + .D_frag (D_tile[3:2]) + ); + + `RUNTIME_ASSERT(&(threadgroup_valids_out) == |(threadgroup_valids_out), + ("threadgroups went out of sync!")) + `RUNTIME_ASSERT(&(threadgroup_readys_in) == |(threadgroup_readys_in), + ("threadgroups went out of sync!")) + + wire wid_enq = valid_in 
&& ready_in; + wire wid_deq = valid_out && ready_out; + + assign valid_out = &(threadgroup_valids_out); // need to pass along warp id's to do multithreading VX_fifo_queue #( .DATAW ($bits(wid)), // @perf: seems to require deeper depth than the FEDP issue queues to // not cause stalls. - .DEPTH (2 * ISSUE_QUEUE_DEPTH) + .DEPTH (2 * OPERAND_BUFFER_DEPTH) ) wid_queue ( .clk (clk), .reset (reset), - .push (enq), - .pop (deq), + .push (wid_enq), + .pop (wid_deq), .data_in (wid), .data_out (D_wid), - .empty (empty), + .empty (wid_empty), `UNUSED_PIN(alm_empty), - .full (full), + .full (wid_full), `UNUSED_PIN(alm_full), `UNUSED_PIN(size) ); @@ -119,9 +165,9 @@ module VX_tensor_dpu #( endmodule // does (m,n,k) = (2,4,2) matmul compute over 2 cycles. -// matches Figure 10(b) of the paper. +// see Figure 10(b) of the paper. module VX_tensor_threadgroup #( - parameter ISSUE_QUEUE_DEPTH + parameter OPERAND_BUFFER_DEPTH ) ( input clk, input reset, @@ -136,69 +182,37 @@ module VX_tensor_threadgroup #( output valid_out, output [1:0][3:0][31:0] D_frag ); - wire [1:0][1:0][31:0] A_frag_buf; - wire [1:0][3:0][31:0] B_frag_buf; - wire [1:0][3:0][31:0] C_frag_buf; - - wire valid_buf; - wire ready_buf; - - wire enq = valid_in && ready_in; - wire deq = valid_buf && ready_buf; - wire empty; - wire full; - assign ready_in = !full; - assign valid_buf = !empty; - - // 'Issue queue' for the FEDP units. - // This exists to decouple the execution of the dot-product unit from - // the operand arrival. Operands from execute_if can arrive - // intermittently according to the frontend's behavior, and since the dpu - // can also stall for a fixed initiation latency, we need to decouple the - // two to efficiently feed the dpu. - // - // TODO: better queue design possible; e.g. 
B_frag is shared by two - // threadgroups, so we need only 1 queue per octet for B - VX_fifo_queue #( - .DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)), - .DEPTH (ISSUE_QUEUE_DEPTH) - ) input_buffer ( - .clk (clk), - .reset (reset), - .push (enq), - .pop (deq), - .data_in ({A_frag, B_frag, C_frag}), - .data_out ({A_frag_buf, B_frag_buf, C_frag_buf}), - .empty (empty), - `UNUSED_PIN(alm_empty), - .full (full), - `UNUSED_PIN(alm_full), - `UNUSED_PIN(size) - ); + wire fedp_valid_in; + wire fedp_ready_in; + wire fedp_fire_in = fedp_valid_in && fedp_ready_in; wire [3:0] fedp_valids; wire fedp_valid_out = &(fedp_valids); wire fedp_ready_out = !stall; wire fedp_fire_out = fedp_valid_out && fedp_ready_out; - wire fedp_valid_in = valid_buf; - wire fedp_ready_in = fedp_ready_out; // coupled - wire fedp_fire_in = fedp_valid_in && fedp_ready_in; + assign fedp_valid_in = valid_in; + // coupled ready; backpressure immediately reaches input from output + assign fedp_ready_in = fedp_ready_out; - // 0: FEDP uses first half from input_buffer - // 1: FEDP uses last half and pops input_buffer + // The dot product units take 2 cycles to finish computing A_frag * B_frag + // + C_frag. step_in and step_out keep track of which cycle they're at + // and when they have to pop from the operand buffer and push to the result queue. + // + // step_in == 0: FEDP uses first half from operand buffer + // step_in == 1: FEDP uses last half and pops from operand buffer wire step_in; - // 0: FEDP produces first half of D_frag - // 1: FEDP produces last half of D_frag and asserts valid_out + // step_out == 0: FEDP produces first half of D_frag + // step_out == 1: FEDP produces last half of D_frag and asserts valid_out wire step_out; - assign ready_buf = fedp_fire_in && (step_in == 1'b1); + assign ready_in = fedp_fire_in && (step_in == 1'b1); wire [3:0][31:0] D_reg; logic [3:0][31:0] D_reg_n; - // Staging buffer that latches the D half-tile. 
+ // staging buffer that latches the D half-tile VX_tensor_reg #( - .N($bits(D_reg)) + .DATAW($bits(D_reg)) ) staging_d ( .clk(clk), .reset(reset), @@ -220,7 +234,7 @@ module VX_tensor_threadgroup #( // flip step_in/step_out on FEDP in/out fire, respectively VX_tensor_reg #( - .N(1) + .DATAW(1) ) staging_step_in ( .clk(clk), .reset(reset), @@ -229,7 +243,7 @@ module VX_tensor_threadgroup #( .q(step_in) ); VX_tensor_reg #( - .N(1) + .DATAW(1) ) staging_step_out ( .clk(clk), .reset(reset), @@ -259,15 +273,15 @@ module VX_tensor_threadgroup #( .clock (clk), .reset (reset), .io_in_valid (fedp_fire_in), - .io_in_bits_a_0 (A_frag_buf[d_row][0]), - .io_in_bits_a_1 (A_frag_buf[d_row][1]), + .io_in_bits_a_0 (A_frag[d_row][0]), + .io_in_bits_a_1 (A_frag[d_row][1]), .io_in_bits_a_2 (32'h0), .io_in_bits_a_3 (32'h0), - .io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]), - .io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]), + .io_in_bits_b_0 (step_in == 1'b0 ? B_frag[0][d_col] : B_frag[0][d_col + 1]), + .io_in_bits_b_1 (step_in == 1'b0 ? B_frag[1][d_col] : B_frag[1][d_col + 1]), .io_in_bits_b_2 (32'h0), .io_in_bits_b_3 (32'h0), - .io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]), + .io_in_bits_c (step_in == 1'b0 ? C_frag[d_row][d_col] : C_frag[d_row][d_col + 1]), .io_stall (stall), .io_out_valid (fedp_valids[i]), .io_out_bits_data (D_half[i])