diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index b9cc89b1..e1781e4c 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -51,7 +51,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #( ); for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin - VX_tensor_core_warp #( + VX_tensor_core_block #( .ISW(1) // FIXME: not block_idx ) tensor_core ( .clk(clk), @@ -64,7 +64,7 @@ module VX_tensor_core import VX_gpu_pkg::*; #( endmodule -module VX_tensor_core_warp import VX_gpu_pkg::*; #( +module VX_tensor_core_block import VX_gpu_pkg::*; #( parameter ISW ) ( input clk, @@ -82,15 +82,16 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA; wire [1:0] step = 2'(execute_if.data.op_type); - // op_mod is reused to indicate instruction's id in pair + // op_mod is reused to indicate if instruction is the last substep inside + // a step (pair of substeps) wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); - logic [NUM_OCTETS-1:0] octet_results_valid; + wire [NUM_OCTETS-1:0] octet_results_valid; logic [NUM_OCTETS-1:0] octet_results_ready; - logic [NUM_OCTETS-1:0] octet_operands_ready; + wire [NUM_OCTETS-1:0] octet_operands_ready; // FIXME: should be NUM_LANES? - logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; - logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; + wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; + wire [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; wire [`NW_WIDTH-1:0] wb_wid; // valid signal synced between the functional units (octet) and the @@ -113,9 +114,9 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4] }; - logic [3:0][3:0][31:0] octet_D; - logic result_valid; - logic result_ready; + wire [3:0][3:0][31:0] octet_D; + wire result_valid; + wire result_ready; VX_tensor_octet #( .ISW(ISW), @@ -285,15 +286,39 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( end end + VX_tensor_reg #( + .N(1) + ) staging_subcommit ( + .clk(clk), + .reset(reset), + .d(subcommit_n), + .en(1'b1), + .q(subcommit) + ); +endmodule + +module VX_tensor_reg #( + parameter N +) ( + input clk, + input reset, + input [N-1:0] d, + input en, + output [N-1:0] q +); + logic [N-1:0] data; + always @(posedge clk) begin if (reset) begin - subcommit <= '0; - end - else begin - subcommit <= subcommit_n; + data <= '0; + end else begin + if (en) begin + data <= d; + end end end - + + assign q = data; endmodule module VX_tensor_octet #( @@ -337,7 +362,6 @@ module VX_tensor_octet #( logic [3:0][31:0] B_half_buf; logic [7:0][31:0] C_half_buf; - logic [`NUM_WARPS-1:0] substeps; logic [`NUM_WARPS-1:0] substeps_n; @@ -353,6 +377,7 @@ module VX_tensor_octet #( assign A_in_buf = A_in; assign B_in_buf = B_in; assign C_in_buf = C_in; + // TODO: merge *_buf/* assign operands_step_buf = operands_step; assign operands_wid_buf = operands_wid; assign operands_last_in_pair_buf = operands_last_in_pair; @@ -408,6 +433,18 @@ module VX_tensor_octet #( // wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf); + // Staging buffer for the A/B/C half-tiles that will later be assembled + // with the other half tiles coming in on the input ports. + VX_tensor_reg #( + .N($bits(A_buffer) + $bits(B_buffer) + $bits(C_buffer)) + ) staging_abc ( + .clk(clk), + .reset(reset), + .d({A_buffer_n, B_buffer_n, C_buffer_n}), + .en(1'b1), + .q({A_buffer, B_buffer, C_buffer}) + ); + always @(*) begin A_buffer_n = A_buffer; B_buffer_n = B_buffer; @@ -426,20 +463,15 @@ module VX_tensor_octet #( end end - always @(posedge clk) begin - if (reset) begin - A_buffer <= '0; - B_buffer <= '0; - C_buffer <= '0; - substeps <= '0; - end - else begin - A_buffer <= A_buffer_n; - B_buffer <= B_buffer_n; - C_buffer <= C_buffer_n; - substeps <= substeps_n; - end - end + VX_tensor_reg #( + .N($bits(substeps)) + ) staging_substeps ( + .clk(clk), + .reset(reset), + .d(substeps_n), + .en(1'b1), + .q(substeps) + ); wire outbuf_ready_in; wire hmma_ready; @@ -458,8 +490,8 @@ module VX_tensor_octet #( }; // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; - logic [3:0][3:0][31:0] D_tile; - logic [`NW_WIDTH-1:0] D_wid_dpu; + wire [3:0][3:0][31:0] D_tile; + wire [`NW_WIDTH-1:0] D_wid_dpu; always @(*) begin C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] }; diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index ebe752b5..88f3cbc1 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -62,11 +62,11 @@ module VX_tensor_dpu #( // stalling // assign ready_in = ready_out; - logic synced_fire; + wire synced_fire; assign synced_fire = valid_in && ready_in; - logic [1:0] threadgroup_valids; - logic [1:0] threadgroup_readys; + wire [1:0] threadgroup_valids; + wire [1:0] threadgroup_readys; // B_tile is shared across the two threadgroups; see Figure 13 VX_tensor_threadgroup #( .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) @@ -187,7 +187,7 @@ module VX_tensor_threadgroup #( `UNUSED_PIN(size) ); - logic [3:0] fedp_valids; + wire [3:0] fedp_valids; wire fedp_valid_out = &(fedp_valids); wire fedp_ready_out = !stall; wire fedp_fire_out = fedp_valid_out && fedp_ready_out; @@ -198,14 +198,27 @@ module VX_tensor_threadgroup #( // 0: FEDP uses first half from input_buffer // 1: FEDP uses last half and pops input_buffer - logic step_in; + wire step_in; // 0: FEDP produces first half of D_frag // 1: FEDP produces last half of D_frag and asserts valid_out - logic step_out; + wire step_out; assign ready_buf = fedp_fire_in && (step_in == 1'b1); + wire [3:0][31:0] D_reg; + logic [3:0][31:0] D_reg_n; + + // Staging buffer that latches the D half-tile. + VX_tensor_reg #( + .N($bits(D_reg)) + ) staging_d ( + .clk(clk), + .reset(reset), + .d(D_reg_n), + .en(1'b1), + .q(D_reg) + ); + // latch the first-half result of D_frag - logic [3:0][31:0] D_reg, D_reg_n; wire [3:0][31:0] D_half; always @(*) begin D_reg_n = D_reg; @@ -216,23 +229,25 @@ module VX_tensor_threadgroup #( end end - always @(posedge clk) begin - if (reset) begin - step_in <= '0; - step_out <= '0; - - D_reg <= '0; - end else begin - if (fedp_fire_in) begin - step_in <= ~step_in; - end - if (fedp_fire_out) begin - step_out <= ~step_out; - end - - D_reg <= D_reg_n; - end - end + // flip step_in/step_out on FEDP in/out fire, respectively + VX_tensor_reg #( + .N(1) + ) staging_step_in ( + .clk(clk), + .reset(reset), + .d(~step_in), + .en(fedp_fire_in), + .q(step_in) + ); + VX_tensor_reg #( + .N(1) + ) staging_step_out ( + .clk(clk), + .reset(reset), + .d(~step_out), + .en(fedp_fire_out), + .q(step_out) + ); // TODO: Instead of latching half-result and constructing a full D tile, // we should be able to send these half fragments down to commit stage