diff --git a/hw/dpi/float_dpi.cpp b/hw/dpi/float_dpi.cpp index 29ca22df..570d6bf2 100644 --- a/hw/dpi/float_dpi.cpp +++ b/hw/dpi/float_dpi.cpp @@ -347,7 +347,7 @@ void dpi_fmax(bool enable, int dst_fmt, int64_t a, int64_t b, int64_t* result, s // A is M * K, B is K * M, C is M * M, D is M * M #define M 4 -#define K 2 +#define K 2 // FIXME: 4x4x1 / cycle / octet! // all row major float c_A_tile[M][K]; @@ -358,6 +358,15 @@ float c_D_tile[M][M]; // code assumes that svBitVecVal is basically a uint32_t static_assert(sizeof(svBitVecVal) == 4); +void clear_float_array(float* c_tile, int rows, int cols) { + for (int i = 0; i < rows; i += 1) { + for (int j = 0; j < cols; j += 1) { + int index = i * cols + j; + c_tile[index] = 0.0f; + } + } +} + void fill_float_array(const svBitVecVal* sv_tile, float* c_tile, int rows, int cols) { for (int i = 0; i < rows; i += 1) { @@ -396,6 +405,11 @@ void dpi_hmma(bool enable, const svBitVecVal* A_tile, const svBitVecVal* B_tile, if (!enable) { return; } + clear_float_array(&c_A_tile[0][0], M, K); + clear_float_array(&c_B_tile[0][0], K, M); + clear_float_array(&c_C_tile[0][0], M, M); + clear_float_array(&c_D_tile[0][0], M, M); + // std::cout << "A: " << std::endl; fill_float_array(A_tile, &c_A_tile[0][0], M, K); // std::cout << "B: " << std::endl; @@ -551,7 +565,7 @@ void dpi_print_results(int wid, int octet, const svBitVecVal* A_tile, const svBi } steps[wid] += 1; - if (steps[wid] % 64 == 0) { + if (steps[wid] % 32 == 0) { steps[wid] = 0; std::cout << "warp " << wid << " finished wmma\n"; std::cout << "A tile" << "\n"; diff --git a/hw/rtl/VX_config.vh b/hw/rtl/VX_config.vh index 65d56e8a..8905bd3d 100644 --- a/hw/rtl/VX_config.vh +++ b/hw/rtl/VX_config.vh @@ -391,7 +391,7 @@ // Tensor Core Latency `ifndef LATENCY_HMMA -`define LATENCY_HMMA 8 +`define LATENCY_HMMA 4 `endif // Icache Configurable Knobs ////////////////////////////////////////////////// diff --git a/hw/rtl/VX_platform.vh b/hw/rtl/VX_platform.vh index 9cfb76fb..0f6c0917 100644 --- a/hw/rtl/VX_platform.vh +++ b/hw/rtl/VX_platform.vh @@ -14,7 +14,7 @@ `ifndef VX_PLATFORM_VH `define VX_PLATFORM_VH -// synthesis only +// enable synthesizable build if SIMULATION not explicitly defined `ifndef SIMULATION `define SYNTHESIS `define NDEBUG diff --git a/hw/rtl/core/VX_alu_unit.sv b/hw/rtl/core/VX_alu_unit.sv index 7546f4b3..c1724360 100644 --- a/hw/rtl/core/VX_alu_unit.sv +++ b/hw/rtl/core/VX_alu_unit.sv @@ -42,7 +42,7 @@ module VX_alu_unit #( `RESET_RELAY (dispatch_reset, reset); - VX_dispatch_unit #( + VX_dispatch_unit_sane #( .BLOCK_SIZE (BLOCK_SIZE), .NUM_LANES (NUM_LANES), .OUT_REG (PARTIAL_BW ? 1 : 0) diff --git a/hw/rtl/core/VX_decode.sv b/hw/rtl/core/VX_decode.sv index 6f4539e7..2ca414cd 100644 --- a/hw/rtl/core/VX_decode.sv +++ b/hw/rtl/core/VX_decode.sv @@ -545,6 +545,12 @@ module VX_decode #( `INST_EXT4: begin ex_type = `EX_TENSOR; op_type = `INST_TENSOR_HMMA; + // tensor core macroop is encoded as r-type + use_rd = 1; + `USED_IREG (rd); + `USED_IREG (rs1); + `USED_IREG (rs2); + `USED_IREG (rs3); end `endif default:; diff --git a/hw/rtl/core/VX_dispatch_unit_sane.sv b/hw/rtl/core/VX_dispatch_unit_sane.sv new file mode 100644 index 00000000..3e31ced2 --- /dev/null +++ b/hw/rtl/core/VX_dispatch_unit_sane.sv @@ -0,0 +1,274 @@ +// Copyright © 2019-2023 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +`include "VX_define.vh" + +module VX_dispatch_unit_sane import VX_gpu_pkg::*; #( + parameter BLOCK_SIZE = 1, + parameter NUM_LANES = 1, + parameter OUT_REG = 0, + parameter MAX_FANOUT = `MAX_FANOUT +) ( + input wire clk, + input wire reset, + + // inputs + VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], + + // outputs + VX_execute_if.master execute_if [BLOCK_SIZE] + +); + `STATIC_ASSERT ((`NUM_THREADS == NUM_LANES * (`NUM_THREADS / NUM_LANES)), ("invalid parameter")) + localparam BLOCK_SIZE_W = `LOG2UP(BLOCK_SIZE); + localparam NUM_PACKETS = `NUM_THREADS / NUM_LANES; + localparam PID_BITS = `CLOG2(NUM_PACKETS); + localparam PID_WIDTH = `UP(PID_BITS); + localparam BATCH_COUNT = `ISSUE_WIDTH / BLOCK_SIZE; + localparam BATCH_COUNT_W= `LOG2UP(BATCH_COUNT); + localparam ISSUE_W = `LOG2UP(`ISSUE_WIDTH); + localparam IN_DATAW = `UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * `NUM_THREADS * `XLEN); + localparam OUT_DATAW = `UUID_WIDTH + `NW_WIDTH + NUM_LANES + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + `NR_BITS + `NT_WIDTH + (3 * NUM_LANES * `XLEN) + PID_WIDTH + 1 + 1; + localparam FANOUT_ENABLE= (`NUM_THREADS > (MAX_FANOUT + MAX_FANOUT/2)); + + localparam DATA_TMASK_OFF = IN_DATAW - (`UUID_WIDTH + ISSUE_WIS_W + `NUM_THREADS); + localparam DATA_REGS_OFF = 0; + + wire [`ISSUE_WIDTH-1:0] dispatch_valid; + wire [`ISSUE_WIDTH-1:0][IN_DATAW-1:0] dispatch_data; + wire [`ISSUE_WIDTH-1:0] dispatch_ready; + + for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + assign dispatch_valid[i] = dispatch_if[i].valid; + assign dispatch_data[i] = dispatch_if[i].data; + assign dispatch_if[i].ready = dispatch_ready[i]; + end + + wire [BLOCK_SIZE-1:0][ISSUE_W-1:0] issue_indices; + wire [BLOCK_SIZE-1:0] block_ready; + wire [BLOCK_SIZE-1:0][NUM_LANES-1:0] block_tmask; + wire [BLOCK_SIZE-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] block_regs; + wire [BLOCK_SIZE-1:0][PID_WIDTH-1:0] block_pid; + wire [BLOCK_SIZE-1:0] block_sop; + wire [BLOCK_SIZE-1:0] block_eop; + wire [BLOCK_SIZE-1:0] block_done; + + wire batch_done = (& block_done); + + logic [BATCH_COUNT_W-1:0] batch_idx; + // if (BATCH_COUNT != 1) begin + // always @(posedge clk) begin + // if (reset) begin + // batch_idx <= '0; + // end else begin + // batch_idx <= batch_idx + BATCH_COUNT_W'(batch_done); + // end + // end + // end else begin + // assign batch_idx = 0; + // `UNUSED_VAR(batch_done) + // end + + // group dispatch_valid by blocks + wire [BATCH_COUNT-1:0] batch_valids; + for (genvar i = 0; i < BATCH_COUNT; ++i) begin + assign batch_valids[i] = |(dispatch_valid[(BLOCK_SIZE * i) +: BLOCK_SIZE]); + end + + // elect the leftmost-valid batch for the dispatch + wire dispatch_any_valid; + VX_lzc_rr #( + .N (BATCH_COUNT) + ) batch_select ( + .clk (clk), + .reset (reset), + .data_in (batch_valids), + .data_out (batch_idx), + .valid_out (dispatch_any_valid) + ); + + for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin + + wire [ISSUE_W-1:0] issue_idx = ISSUE_W'(batch_idx * BLOCK_SIZE) + ISSUE_W'(block_idx); + assign issue_indices[block_idx] = issue_idx; + + wire valid_p, ready_p; + + if (`NUM_THREADS != NUM_LANES) begin + reg [NUM_PACKETS-1:0] sent_mask_p; + wire [PID_WIDTH-1:0] start_p_n, start_p, end_p; + wire dispatch_valid_r; + reg is_first_p; + + wire fire_p = valid_p && ready_p; + + wire is_last_p = (start_p == end_p); + + wire fire_eop = fire_p && is_last_p; + + always @(posedge clk) begin + if (reset) begin + sent_mask_p <= '0; + is_first_p <= 1; + end else begin + if ((BATCH_COUNT != 1) ? batch_done : fire_eop) begin + sent_mask_p <= '0; + is_first_p <= 1; + end else if (fire_p) begin + sent_mask_p[start_p] <= 1; + is_first_p <= 0; + end + end + end + + wire [NUM_PACKETS-1:0][NUM_LANES-1:0] per_packet_tmask; + wire [NUM_PACKETS-1:0][2:0][NUM_LANES-1:0][`XLEN-1:0] per_packet_regs; + + wire [`NUM_THREADS-1:0] dispatch_tmask = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS]; + wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs1_data = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs2_data = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + wire [`NUM_THREADS-1:0][`XLEN-1:0] dispatch_rs3_data = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + + for (genvar i = 0; i < NUM_PACKETS; ++i) begin + for (genvar j = 0; j < NUM_LANES; ++j) begin + localparam k = i * NUM_LANES + j; + assign per_packet_tmask[i][j] = dispatch_tmask[k]; + assign per_packet_regs[i][0][j] = dispatch_rs1_data[k]; + assign per_packet_regs[i][1][j] = dispatch_rs2_data[k]; + assign per_packet_regs[i][2][j] = dispatch_rs3_data[k]; + end + end + + wire [NUM_PACKETS-1:0] packet_valids; + wire [NUM_PACKETS-1:0][PID_WIDTH-1:0] packet_ids; + + for (genvar i = 0; i < NUM_PACKETS; ++i) begin + assign packet_valids[i] = (| per_packet_tmask[i]); + assign packet_ids[i] = PID_WIDTH'(i); + end + + VX_find_first #( + .N (NUM_PACKETS), + .DATAW (PID_WIDTH), + .REVERSE (0) + ) find_first ( + .valid_in (packet_valids & ~sent_mask_p), + .data_in (packet_ids), + .data_out (start_p_n), + `UNUSED_PIN (valid_out) + ); + + VX_find_first #( + .N (NUM_PACKETS), + .DATAW (PID_WIDTH), + .REVERSE (1) + ) find_last ( + .valid_in (packet_valids), + .data_in (packet_ids), + .data_out (end_p), + `UNUSED_PIN (valid_out) + ); + + VX_pipe_register #( + .DATAW (1 + PID_WIDTH), + .RESETW (1), + .DEPTH (FANOUT_ENABLE ? 1 : 0) + ) pipe_reg ( + .clk (clk), + .reset (reset || fire_p), // should flush on fire + .enable (1'b1), + .data_in ({dispatch_valid[issue_idx], start_p_n}), + .data_out ({dispatch_valid_r, start_p}) + ); + + wire [NUM_LANES-1:0] tmask_p = per_packet_tmask[start_p]; + wire [2:0][NUM_LANES-1:0][`XLEN-1:0] regs_p = per_packet_regs[start_p]; + + wire block_enable = (BATCH_COUNT == 1 || ~(& sent_mask_p)); + + assign valid_p = dispatch_valid_r && block_enable; + assign block_tmask[block_idx] = tmask_p; + assign block_regs[block_idx] = regs_p; + assign block_pid[block_idx] = start_p; + assign block_sop[block_idx] = is_first_p; + assign block_eop[block_idx] = is_last_p; + if (FANOUT_ENABLE) begin + assign block_ready[block_idx] = dispatch_valid_r && ready_p && block_enable; + end else begin + assign block_ready[block_idx] = ready_p && block_enable; + end + assign block_done[block_idx] = ~dispatch_valid[issue_idx] || fire_eop; + end else begin + assign valid_p = dispatch_valid[issue_idx]; + assign block_tmask[block_idx] = dispatch_data[issue_idx][DATA_TMASK_OFF +: `NUM_THREADS]; + assign block_regs[block_idx][0] = dispatch_data[issue_idx][DATA_REGS_OFF + 2 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + assign block_regs[block_idx][1] = dispatch_data[issue_idx][DATA_REGS_OFF + 1 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + assign block_regs[block_idx][2] = dispatch_data[issue_idx][DATA_REGS_OFF + 0 * `NUM_THREADS * `XLEN +: `NUM_THREADS * `XLEN]; + assign block_pid[block_idx] = '0; + assign block_sop[block_idx] = 1'b1; + assign block_eop[block_idx] = 1'b1; + assign block_ready[block_idx] = ready_p; + assign block_done[block_idx] = ~valid_p || ready_p; + end + + wire [ISSUE_ISW_W-1:0] isw; + if (BATCH_COUNT != 1) begin + if (BLOCK_SIZE != 1) begin + assign isw = {batch_idx, BLOCK_SIZE_W'(block_idx)}; + end else begin + assign isw = batch_idx; + end + end else begin + assign isw = block_idx; + end + + `RESET_RELAY(buf_out_reset, reset); + + wire [`NW_WIDTH-1:0] block_wid = wis_to_wid(dispatch_data[issue_idx][DATA_TMASK_OFF+`NUM_THREADS +: ISSUE_WIS_W], isw); + + VX_elastic_buffer #( + .DATAW (OUT_DATAW), + .SIZE (`OUT_REG_TO_EB_SIZE(OUT_REG)), + .OUT_REG (`OUT_REG_TO_EB_REG(OUT_REG)) + ) buf_out ( + .clk (clk), + .reset (buf_out_reset), + .valid_in (valid_p), + .ready_in (ready_p), + .data_in ({ + dispatch_data[issue_idx][IN_DATAW-1 : DATA_TMASK_OFF+`NUM_THREADS+ISSUE_WIS_W], + block_wid, + block_tmask[block_idx], + dispatch_data[issue_idx][DATA_TMASK_OFF-1 : DATA_REGS_OFF + 3 * `NUM_THREADS * `XLEN], + block_regs[block_idx][0], + block_regs[block_idx][1], + block_regs[block_idx][2], + block_pid[block_idx], + block_sop[block_idx], + block_eop[block_idx]}), + .data_out (execute_if[block_idx].data), + .valid_out (execute_if[block_idx].valid), + .ready_out (execute_if[block_idx].ready) + ); + end + + reg [`ISSUE_WIDTH-1:0] ready_in; + always @(*) begin + ready_in = 0; + for (integer i = 0; i < BLOCK_SIZE; ++i) begin + ready_in[issue_indices[i]] = block_ready[i] && block_eop[i]; + end + end + assign dispatch_ready = ready_in; + +endmodule diff --git a/hw/rtl/core/VX_fpu_unit.sv b/hw/rtl/core/VX_fpu_unit.sv index 26956213..7e0875ba 100644 --- a/hw/rtl/core/VX_fpu_unit.sv +++ b/hw/rtl/core/VX_fpu_unit.sv @@ -39,7 +39,7 @@ module VX_fpu_unit import VX_fpu_pkg::*; #( `RESET_RELAY (dispatch_reset, reset); - VX_dispatch_unit #( + VX_dispatch_unit_sane #( .BLOCK_SIZE (BLOCK_SIZE), .NUM_LANES (NUM_LANES), .OUT_REG (PARTIAL_BW ? 1 : 0) diff --git a/hw/rtl/core/VX_lsu_unit.sv b/hw/rtl/core/VX_lsu_unit.sv index b4fd6ee1..20fac1d1 100644 --- a/hw/rtl/core/VX_lsu_unit.sv +++ b/hw/rtl/core/VX_lsu_unit.sv @@ -49,7 +49,7 @@ module VX_lsu_unit import VX_gpu_pkg::*; #( `RESET_RELAY (dispatch_reset, reset); - VX_dispatch_unit #( + VX_dispatch_unit_sane #( .BLOCK_SIZE (BLOCK_SIZE), .NUM_LANES (NUM_LANES), .OUT_REG (1) @@ -596,6 +596,31 @@ module VX_lsu_unit import VX_gpu_pkg::*; #( .commit_out_if (commit_if) ); +`ifdef PERF_ENABLE + wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_valids_per_cycle; + wire [`CLOG2(NUM_LANES+1)-1:0] perf_rsp_tmask_total_per_cycle; + reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_valids; + reg [`PERF_CTR_BITS-1:0] perf_rsp_tmask_total; + reg [`PERF_CTR_BITS-1:0] perf_rsp_fires; + + `POP_COUNT(perf_rsp_tmask_valids_per_cycle, rsp_tmask); + assign perf_rsp_tmask_total_per_cycle = NUM_LANES; + + always @(posedge clk) begin + if (reset) begin + perf_rsp_tmask_valids <= '0; + perf_rsp_tmask_total <= '0; + perf_rsp_fires <= '0; + end else begin + if (mem_rsp_fire) begin + perf_rsp_tmask_valids <= perf_rsp_tmask_valids + perf_rsp_tmask_valids_per_cycle; + perf_rsp_tmask_total <= perf_rsp_tmask_total + perf_rsp_tmask_total_per_cycle; + perf_rsp_fires <= perf_rsp_fires + 1'b1; + end + end + end +`endif + `ifdef DBG_SCOPE_LSU if (CORE_ID == 0) begin `ifdef SCOPE diff --git a/hw/rtl/core/VX_smem_unit.sv b/hw/rtl/core/VX_smem_unit.sv index 91587b2f..532dba55 100644 --- a/hw/rtl/core/VX_smem_unit.sv +++ b/hw/rtl/core/VX_smem_unit.sv @@ -66,6 +66,7 @@ module VX_smem_unit import VX_gpu_pkg::*; #( .req_valid (smem_req_valid), .req_rw (smem_req_rw), .req_byteen (smem_req_byteen), + // FIXME: synthesis complains undriven when USE_EXTERNAL_SMEM .req_addr (smem_req_addr), .req_data (smem_req_data), .req_tag (smem_req_tag), diff --git a/hw/rtl/core/VX_tensor_core.sv b/hw/rtl/core/VX_tensor_core.sv index 9971d619..efa74afd 100644 --- a/hw/rtl/core/VX_tensor_core.sv +++ b/hw/rtl/core/VX_tensor_core.sv @@ -1,7 +1,7 @@ `ifdef EXT_T_ENABLE `include "VX_fpu_define.vh" -module VX_tensor_core #( +module VX_tensor_core import VX_gpu_pkg::*; #( ) ( input clk, @@ -10,17 +10,54 @@ module VX_tensor_core #( VX_dispatch_if.slave dispatch_if [`ISSUE_WIDTH], VX_commit_if.master commit_if [`ISSUE_WIDTH] ); - `STATIC_ASSERT(`NUM_THREADS == 32, ("tensor core requires # of threads in a warp to be 32 (try running w/ CONFIGS=\"-DNUM_THREADS=32\")")); - - for (genvar i = 0; i < `ISSUE_WIDTH; ++i) begin + localparam BLOCK_SIZE = 1; + localparam NUM_LANES = `NUM_THREADS; + // localparam PARTIAL_BW = (BLOCK_SIZE != `ISSUE_WIDTH) || (NUM_LANES != `NUM_THREADS); + localparam PARTIAL_BW = 1; + + VX_execute_if #( + .NUM_LANES (NUM_LANES) + ) execute_if[BLOCK_SIZE](); + + `RESET_RELAY (dispatch_reset, reset); + + VX_dispatch_unit_sane #( + .BLOCK_SIZE (BLOCK_SIZE), + .NUM_LANES (NUM_LANES), + .OUT_REG (PARTIAL_BW ? 1 : 0) + ) dispatch_unit ( + .clk (clk), + .reset (dispatch_reset), + .dispatch_if(dispatch_if), + .execute_if (execute_if) + ); + + VX_commit_if #( + .NUM_LANES (NUM_LANES) + ) commit_block_if[BLOCK_SIZE](); + + `RESET_RELAY (commit_reset, reset); + + VX_gather_unit #( + .BLOCK_SIZE (BLOCK_SIZE), + .NUM_LANES (NUM_LANES), + .OUT_REG (PARTIAL_BW ? 3 : 0) // FIXME: why 3? + ) gather_unit ( + .clk (clk), + .reset (commit_reset), + .commit_in_if (commit_block_if), + .commit_out_if (commit_if) + ); + + for (genvar block_idx = 0; block_idx < BLOCK_SIZE; ++block_idx) begin VX_tensor_core_warp #( - .ISW(i) + .ISW(1) // FIXME: not block_idx ) tensor_core ( .clk(clk), .reset(reset), - .dispatch_if(dispatch_if[i]), - .commit_if(commit_if[i]) + .execute_if(execute_if[block_idx]), + .commit_if(commit_block_if[block_idx]) ); end @@ -32,37 +69,53 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( input clk, input reset, - VX_dispatch_if.slave dispatch_if, + VX_execute_if.slave execute_if, VX_commit_if.master commit_if ); - wire [1:0] step = 2'(dispatch_if.data.op_type); - logic [3:0] octet_results_valid; - logic [3:0] octet_results_ready; - logic [3:0] octet_operands_ready; + localparam NUM_OCTETS = (`NUM_THREADS / 8); + // offet in the lane numbers that get mapped to the two threadgroups in an + // octet. E.g. two tgs map lane 0-3 and lane 16-19 -> 16 + // FIXME: not sure this is the right logic. just filling in what works + localparam LANE_OFFSET_THREADGROUP = (4 * NUM_OCTETS); + // this is only a rule of thumb + localparam METADATA_QUEUE_DEPTH = 2 * `LATENCY_HMMA; + + wire [1:0] step = 2'(execute_if.data.op_type); + // op_mod is reused to indicate instruction's id in pair + wire last_in_pair = (execute_if.data.op_mod == `INST_MOD_BITS'(1)); + + logic [NUM_OCTETS-1:0] octet_results_valid; + logic [NUM_OCTETS-1:0] octet_results_ready; + logic [NUM_OCTETS-1:0] octet_operands_ready; + // FIXME: should be NUM_LANES? logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_0; logic [`NUM_THREADS-1:0][`XLEN-1:0] wb_data_1; - - assign dispatch_if.ready = &octet_operands_ready; + wire [`NW_WIDTH-1:0] wb_wid; + + // valid signal synced between the functional units (octet) and the + // metadata queue + wire operands_valid_synced; `ifdef EXT_T_ENABLE - for (genvar i = 0; i < 4/*octets*/; ++i) begin + for (genvar i = 0; i < NUM_OCTETS; ++i) begin `else for (genvar i = 0; i < 0; ++i) begin `endif // lane-to-octet mapping; see figure 13 of the paper wire [7:0][31:0] octet_A = { - dispatch_if.data.rs1_data[16+4*i +: 4], dispatch_if.data.rs1_data[4*i +: 4] + execute_if.data.rs1_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs1_data[4*i +: 4] }; wire [7:0][31:0] octet_B = { - dispatch_if.data.rs2_data[16+4*i +: 4], dispatch_if.data.rs2_data[4*i +: 4] + execute_if.data.rs2_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs2_data[4*i +: 4] }; wire [7:0][31:0] octet_C = { - dispatch_if.data.rs3_data[16+4*i +: 4], dispatch_if.data.rs3_data[4*i +: 4] + execute_if.data.rs3_data[LANE_OFFSET_THREADGROUP + 4*i +: 4], execute_if.data.rs3_data[4*i +: 4] }; logic [3:0][3:0][31:0] octet_D; logic result_valid; logic result_ready; + VX_tensor_octet #( .ISW(ISW), .OCTET(i) @@ -73,12 +126,14 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( .A_in(octet_A), .B_in(octet_B), .C_in(octet_C), - .operands_valid(dispatch_if.valid), + .operands_valid(operands_valid_synced), + .operands_wid(execute_if.data.wid), + .operands_last_in_pair(last_in_pair), + .operands_step(step), .operands_ready(octet_operands_ready[i]), - .step(step), - .D_out(octet_D), + .D_wid(wb_wid), .result_valid(result_valid), .result_ready(result_ready) ); @@ -100,15 +155,15 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( assign wb_data_1[4*i+2] = octet_D[0][3]; assign wb_data_1[4*i+3] = octet_D[1][3]; - assign wb_data_0[4*i+16+0] = octet_D[2][0]; - assign wb_data_0[4*i+16+1] = octet_D[3][0]; - assign wb_data_0[4*i+16+2] = octet_D[2][2]; - assign wb_data_0[4*i+16+3] = octet_D[3][2]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][0]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][0]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][2]; + assign wb_data_0[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][2]; - assign wb_data_1[4*i+16+0] = octet_D[2][1]; - assign wb_data_1[4*i+16+1] = octet_D[3][1]; - assign wb_data_1[4*i+16+2] = octet_D[2][3]; - assign wb_data_1[4*i+16+3] = octet_D[3][3]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+0] = octet_D[2][1]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+1] = octet_D[3][1]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+2] = octet_D[2][3]; + assign wb_data_1[4*i+LANE_OFFSET_THREADGROUP+3] = octet_D[3][3]; end /* commit_if.data_t parts that we need to keep around: @@ -122,44 +177,95 @@ module VX_tensor_core_warp import VX_gpu_pkg::*; #( localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS; - wire dispatch_if_fire = dispatch_if.valid && dispatch_if.ready; - wire commit_if_fire = commit_if.valid && commit_if.ready; - wire [DATAW-1:0] dispatch_if_data_enq = { - dispatch_if.data.uuid, - wis_to_wid(dispatch_if.data.wis, ISW), - dispatch_if.data.tmask, - dispatch_if.data.PC, - dispatch_if.data.wb, - dispatch_if.data.rd + wire operand_enq_fire = operands_valid_synced && execute_if.ready; + wire commit_if_ready_override; + wire commit_if_fire = commit_if.valid && commit_if_ready_override; + wire [DATAW-1:0] execute_if_data_enq = { + execute_if.data.uuid, + execute_if.data.wid, + execute_if.data.tmask, + execute_if.data.PC, + execute_if.data.wb, + execute_if.data.rd + // pid/sop/eop set later }; - wire [DATAW-1:0] dispatch_if_data_deq; + wire [`NUM_WARPS-1:0][DATAW-1:0] execute_if_data_deq; - // this is probably a little oversized - VX_fifo_queue #( - .DATAW(DATAW), - .DEPTH(16) - ) pending_uops ( - .clk(clk), - .reset(reset), - .push(dispatch_if_fire), - .pop(commit_if_fire), - .data_in(dispatch_if_data_enq), - .data_out(dispatch_if_data_deq), - `UNUSED_PIN(empty), - `UNUSED_PIN(alm_empty), - `UNUSED_PIN(full), // should be impossible to overflow - `UNUSED_PIN(alm_full), - `UNUSED_PIN(size) - ); + wire [`NUM_WARPS-1:0] metadata_queue_fulls; + // OR not AND, we don't want any warp full + wire metadata_queue_full = |(metadata_queue_fulls); + // need to make sure both metadata and octet issue queues are in sync + assign operands_valid_synced = execute_if.valid && !metadata_queue_full; + assign execute_if.ready = &(octet_operands_ready) && !metadata_queue_full; + + for (genvar i = 0; i < `NUM_WARPS; i++) begin + // Metadata queue for commit_if. This simply copies execute_if's + // metadata and pops them in conjunction with commit fire. + // + // This has to be separated per-warp, as otherwise requests from + // multiple warps can be enqueued interleaved, which makes it hard to + // ensure two consecutive dequeues are associated with the same warp for + // commit. (FIXME: this is not strictly necessary though.) + + wire enq = operand_enq_fire && (execute_if.data.wid == `NW_WIDTH'(i)); + wire deq = commit_if_fire && ( wb_wid == `NW_WIDTH'(i)); + + VX_fifo_queue #( + .DATAW(DATAW), + .DEPTH(METADATA_QUEUE_DEPTH) + ) pending_uops ( + .clk(clk), + .reset(reset), + .push(enq), + .pop(deq), + .data_in(execute_if_data_enq), + .data_out(execute_if_data_deq[i]), + `UNUSED_PIN(empty), + `UNUSED_PIN(alm_empty), + .full(metadata_queue_fulls[i]), + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + end + + // this shouldn't really happen unless there's a big contention over + // the commit stage + `RUNTIME_ASSERT(!(!reset && metadata_queue_full), ("tensor core uop queue is full!")) + + // unlike execute which can be interleaved between warps, commit is + // serialized and completed one-warp-by-warp, therefore we only need to + // keep one subcommit state bit unlike for `substeps` logic subcommit, subcommit_n; + wire all_valid = (& octet_results_valid); + +// define this to inject artificial commit backpressure for debugging +// `define TENSOR_INJECT_COMMIT_BACKPRESSURE +`ifndef TENSOR_INJECT_COMMIT_BACKPRESSURE assign commit_if.valid = all_valid; + assign commit_if_ready_override = commit_if.ready; +`else + logic [1:0] counter; + always @(posedge clk) begin + if (reset) begin + counter <= '0; + end else begin + if (all_valid) begin + counter <= counter + 1'b1; + end + end + end + + assign commit_if.valid = all_valid && (counter == 2'b0); + assign commit_if_ready_override = commit_if.ready && (counter == 2'b0); +`endif localparam COMMIT_DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + 1 + `NR_BITS + (`NUM_THREADS * `XLEN) + 1 + 1 + 1; wire [COMMIT_DATAW-1:0] commit_if_data = { - dispatch_if_data_deq, /* uuid ~ rd */ + execute_if_data_deq[wb_wid], /* uuid ~ rd */ + // execute_if_data_deq, /* uuid ~ rd */ subcommit == 1'b0 ? wb_data_0 : wb_data_1, /* data */ 1'b0, /* pid */ 1'b1, /* sop */ @@ -196,22 +302,25 @@ module VX_tensor_octet #( input clk, input reset, - input [7:0][31:0] A_in, - input [7:0][31:0] B_in, - input [7:0][31:0] C_in, - input operands_valid, // we have to backpressure due to there potentially being contention over commit - output operands_ready, - - input [1:0] step, + input [7:0][31:0] A_in, + input [7:0][31:0] B_in, + input [7:0][31:0] C_in, + input operands_valid, + input [`NW_WIDTH-1:0] operands_wid, + input operands_last_in_pair, + input [1:0] operands_step, + // we have to backpressure due to there potentially being contention over commit + output operands_ready, output [3:0][3:0][31:0] D_out, + output [`NW_WIDTH-1:0] D_wid, output result_valid, input result_ready ); // 512 bits/octet * 4 octets per warp - logic [3:0][31:0] A_buffer, A_buffer_n; - logic [3:0][31:0] B_buffer, B_buffer_n; - logic [7:0][31:0] C_buffer, C_buffer_n; + logic [`NUM_WARPS-1:0][3:0][31:0] A_buffer, A_buffer_n; + logic [`NUM_WARPS-1:0][3:0][31:0] B_buffer, B_buffer_n; + logic [`NUM_WARPS-1:0][7:0][31:0] C_buffer, C_buffer_n; // half the inputs are buffered, half are not (instead coming straight // from operand bus) unlike the real tensor core. @@ -219,41 +328,95 @@ module VX_tensor_octet #( logic [3:0][31:0] A_half; logic [3:0][31:0] B_half; logic [7:0][31:0] C_half; - always @(*) begin + logic [3:0][31:0] A_half_buf; + logic [3:0][31:0] B_half_buf; + logic [7:0][31:0] C_half_buf; + + + logic [`NUM_WARPS-1:0] substeps; + logic [`NUM_WARPS-1:0] substeps_n; + + wire [7:0][31:0] A_in_buf; + wire [7:0][31:0] B_in_buf; + wire [7:0][31:0] C_in_buf; + wire operands_valid_buf; + wire operands_ready_buf; + wire [`NW_WIDTH-1:0] operands_wid_buf; + wire operands_last_in_pair_buf; + wire [1:0] operands_step_buf; + + assign A_in_buf = A_in; + assign B_in_buf = B_in; + assign C_in_buf = C_in; + assign operands_step_buf = operands_step; + assign operands_wid_buf = operands_wid; + assign operands_last_in_pair_buf = operands_last_in_pair; + assign operands_valid_buf = operands_valid; + assign operands_ready = operands_ready_buf; + + typedef struct { + logic [3:0][31:0] A_half; + logic [3:0][31:0] B_half; + logic [7:0][31:0] C_half; + } half_t; + + function half_t get_operand_half( + input logic [1:0] step, + input logic [7:0][31:0] A_in, + input logic [7:0][31:0] B_in, + input logic [7:0][31:0] C_in + ); + half_t half; // note that not all lanes participate at every step case (step) 2'b00: begin - A_half = { A_in[5:4], A_in[1:0] }; - B_half = B_in[3:0]; + // Two A_in segments correspond to two 2x2 subtiles of A read + // by two threadgroups: [0:2,0:2] and [4:6,0:2] in Step 0 of + // Figure 10(b). B_in OTOH is shared by two threadgroups. + // Note k-dimension is shrunk from 4 to 2. + half.A_half = { A_in[5:4], A_in[1:0] }; + half.B_half = B_in[3:0]; end 2'b01: begin - A_half = { A_in[7:6], A_in[3:2] }; - B_half = B_in[3:0]; + half.A_half = { A_in[7:6], A_in[3:2] }; + half.B_half = B_in[3:0]; end 2'b10: begin - A_half = { A_in[5:4], A_in[1:0] }; - B_half = B_in[7:4]; + half.A_half = { A_in[5:4], A_in[1:0] }; + half.B_half = B_in[7:4]; end 2'b11: begin - A_half = { A_in[7:6], A_in[3:2] }; - B_half = B_in[7:4]; + half.A_half = { A_in[7:6], A_in[3:2] }; + half.B_half = B_in[7:4]; end endcase - C_half = C_in; - end + half.C_half = C_in; + return half; + endfunction - logic substep; - wire substep_n = (operands_ready && operands_valid) ? ~substep : substep; + half_t halves; + half_t halves_buf; + assign halves = get_operand_half(operands_step, A_in, B_in, C_in); + assign halves_buf = get_operand_half(operands_step_buf, A_in_buf, B_in_buf, C_in_buf); + + wire do_hmma = operands_ready_buf && operands_valid_buf && operands_last_in_pair_buf; + // wire operands_first_in_pair_fire = operands_ready && operands_valid && (!operands_last_in_pair); + wire operands_first_in_pair_fire = operands_ready_buf && operands_valid_buf && (!operands_last_in_pair_buf); always @(*) begin A_buffer_n = A_buffer; B_buffer_n = B_buffer; C_buffer_n = C_buffer; + substeps_n = substeps; - if (substep == 1'b0) begin - A_buffer_n = A_half; - B_buffer_n = B_half; - C_buffer_n = C_half; + if (operands_first_in_pair_fire) begin + substeps_n[operands_wid_buf] = 1'b1; // ready for hmma + A_buffer_n[operands_wid_buf] = halves_buf.A_half; + B_buffer_n[operands_wid_buf] = halves_buf.B_half; + C_buffer_n[operands_wid_buf] = halves_buf.C_half; + end + if (do_hmma) begin + substeps_n[operands_wid_buf] = 1'b0; // finished hmma, ready for next operand end end @@ -262,61 +425,113 @@ module VX_tensor_octet #( A_buffer <= '0; B_buffer <= '0; C_buffer <= '0; - substep <= '0; + substeps <= '0; end else begin A_buffer <= A_buffer_n; B_buffer <= B_buffer_n; C_buffer <= C_buffer_n; - substep <= substep_n; + substeps <= substeps_n; end end - wire stall = result_valid && ~result_ready; - assign operands_ready = ~stall; + wire outbuf_ready_in; + wire hmma_ready; + assign operands_ready_buf = hmma_ready; // A is 4x2 fp32 matrix wire [3:0][1:0][31:0] A_tile = { - { A_half[3], A_buffer[3] }, - { A_half[2], A_buffer[2] }, - { A_half[1], A_buffer[1] }, - { A_half[0], A_buffer[0] } + { halves_buf.A_half[3], A_buffer[operands_wid_buf][3] }, + { halves_buf.A_half[2], A_buffer[operands_wid_buf][2] }, + { halves_buf.A_half[1], A_buffer[operands_wid_buf][1] }, + { halves_buf.A_half[0], A_buffer[operands_wid_buf][0] } }; // B is 2x4 fp32 matrix wire [1:0][3:0][31:0] B_tile = { - B_half, B_buffer + halves_buf.B_half, B_buffer[operands_wid_buf] }; // C is 4x4 fp32 matrix logic [3:0][3:0][31:0] C_tile; + logic [3:0][3:0][31:0] D_tile; + logic [`NW_WIDTH-1:0] D_wid_dpu; always @(*) begin - C_tile = { - C_half[7], C_buffer[7], C_half[5], C_buffer[5], - C_half[6], C_buffer[6], C_half[4], C_buffer[4], - C_half[3], C_buffer[3], C_half[1], C_buffer[1], - C_half[2], C_buffer[2], C_half[0], C_buffer[0] - }; + C_tile[3] = { halves_buf.C_half[7], C_buffer[operands_wid_buf][7], halves_buf.C_half[5], C_buffer[operands_wid_buf][5] }; + C_tile[2] = { halves_buf.C_half[6], C_buffer[operands_wid_buf][6], halves_buf.C_half[4], C_buffer[operands_wid_buf][4] }; + C_tile[1] = { halves_buf.C_half[3], C_buffer[operands_wid_buf][3], halves_buf.C_half[1], C_buffer[operands_wid_buf][1] }; + C_tile[0] = { halves_buf.C_half[2], C_buffer[operands_wid_buf][2], halves_buf.C_half[0], C_buffer[operands_wid_buf][0] }; end - wire do_hmma = (substep == 1'b1 && operands_valid && operands_ready); + wire dpu_valid; // this does (m,n,k)=(4,4,2) matmul, modeling compute of a single octet VX_tensor_dpu #( .ISW(ISW), - .OCTET(OCTET) + .OCTET(OCTET), + .ISSUE_QUEUE_DEPTH(4 /*@perf: arbtirary*/) ) dpu ( .clk(clk), .reset(reset), - .stall(stall), - .valid_in(do_hmma), + .ready_in(hmma_ready), .A_tile(A_tile), .B_tile(B_tile), .C_tile(C_tile), + .wid(operands_wid_buf), - .valid_out(result_valid), - .D_tile(D_out) + .valid_out(dpu_valid), + .ready_out(outbuf_ready_in), + .D_tile(D_tile), + .D_wid(D_wid_dpu) ); + + wire outbuf_empty; + wire outbuf_full; + // backpressure from commit + assign outbuf_ready_in = ~outbuf_full; + assign result_valid = ~outbuf_empty; + + wire outbuf_enq = outbuf_ready_in && dpu_valid; + wire outbuf_deq = result_valid && result_ready; + + // buffer to stage the result D tile for 2 cycles until commit/writeback + // is complete. This decouples the irregular dpu output traffic from the + // regular, every-2-cycle commit traffic to ensure the commit pipeline is + // used more efficiently. + // FIXME: unnecessary? + VX_fifo_queue #( + .DATAW ($bits(D_wid) + $bits(D_out)), + .DEPTH (2 /* arbitrary */) + ) output_buffer ( + .clk (clk), + .reset (reset), + .push (outbuf_enq), + .pop (outbuf_deq), + .data_in ({D_wid_dpu, D_tile}), + .data_out ({D_wid, D_out}), + .empty (outbuf_empty), + `UNUSED_PIN(alm_empty), + .full (outbuf_full), // should be impossible to overflow + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + // FIXME: this shouldn't be necessary + `RUNTIME_ASSERT(reset || !outbuf_full, ("dpu result queue is full!")) + +`ifdef PERF_ENABLE + logic [`PERF_CTR_BITS-1:0] perf_tensor_dpu_total; + + always @(posedge clk) begin + if (reset) begin + perf_tensor_dpu_total <= '0; + end else begin + if (do_hmma) begin + perf_tensor_dpu_total <= perf_tensor_dpu_total + 2'd2; + end + end + end +`endif endmodule `endif diff --git a/hw/rtl/core/VX_tensor_ucode_8lanes.vh b/hw/rtl/core/VX_tensor_ucode_8lanes.vh new file mode 100644 index 00000000..41ec857e --- /dev/null +++ b/hw/rtl/core/VX_tensor_ucode_8lanes.vh @@ -0,0 +1,49 @@ +// uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3 +HMMA_SET0_STEP0_0: begin + uop = {NEXT, HMMA_SET0_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(0), `FREG(8), `FREG(16)}; +end +HMMA_SET0_STEP0_1: begin + uop = {NEXT, HMMA_SET0_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(1), `FREG(9), `FREG(17)}; +end +HMMA_SET0_STEP1_0: begin + uop = {NEXT, HMMA_SET0_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(0), `FREG(8), `FREG(18)}; +end +HMMA_SET0_STEP1_1: begin + uop = {NEXT, HMMA_SET0_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(1), `FREG(9), `FREG(19)}; +end +HMMA_SET0_STEP2_0: begin + uop = {NEXT, HMMA_SET0_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(0), `FREG(8), `FREG(20)}; +end +HMMA_SET0_STEP2_1: begin + uop = {NEXT, HMMA_SET0_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(1), `FREG(9), `FREG(21)}; +end +HMMA_SET0_STEP3_0: begin + uop = {NEXT, HMMA_SET0_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(0), `FREG(8), `FREG(22)}; +end +HMMA_SET0_STEP3_1: begin + uop = {NEXT, HMMA_SET1_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(23), `FREG(1), `FREG(9), `FREG(23)}; +end +HMMA_SET1_STEP0_0: begin + uop = {NEXT, HMMA_SET1_STEP0_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(16), `FREG(2), `FREG(10), `FREG(16)}; +end +HMMA_SET1_STEP0_1: begin + uop = {NEXT, HMMA_SET1_STEP1_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(0), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(17), `FREG(3), `FREG(11), `FREG(17)}; +end +HMMA_SET1_STEP1_0: begin + uop = {NEXT, HMMA_SET1_STEP1_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(18), `FREG(2), `FREG(10), `FREG(18)}; +end +HMMA_SET1_STEP1_1: begin + uop = {NEXT, HMMA_SET1_STEP2_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(1), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(19), `FREG(3), `FREG(11), `FREG(19)}; +end +HMMA_SET1_STEP2_0: begin + uop = {NEXT, HMMA_SET1_STEP2_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(20), `FREG(2), `FREG(10), `FREG(20)}; +end +HMMA_SET1_STEP2_1: begin + uop = {NEXT, HMMA_SET1_STEP3_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(2), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(21), `FREG(3), `FREG(11), `FREG(21)}; +end +HMMA_SET1_STEP3_0: begin + uop = {NEXT, HMMA_SET1_STEP3_1, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(0), 1'b1, 1'b0, 1'b0, 32'b0, 32'b0, `FREG(22), `FREG(2), `FREG(10), `FREG(22)}; +end +HMMA_SET1_STEP3_1: begin + uop = {FINISH, HMMA_SET0_STEP0_0, `EX_BITS'(`EX_TENSOR), `INST_OP_BITS'(3), `INST_MOD_BITS'(1), 1'b1, 1'b0, 1'b0, 32'b1, 32'b1, `FREG(23), `FREG(3), `FREG(11), `FREG(23)}; +end diff --git a/hw/rtl/core/VX_uop_sequencer.sv b/hw/rtl/core/VX_uop_sequencer.sv index 24b5af3c..26817b8d 100644 --- a/hw/rtl/core/VX_uop_sequencer.sv +++ b/hw/rtl/core/VX_uop_sequencer.sv @@ -14,10 +14,9 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( localparam UOP_TABLE_SIZE = 64; localparam UPC_BITS = `CLOG2(UOP_TABLE_SIZE); - localparam NEXT = 2'b00; - localparam FINISH = 2'b01; - localparam UBR_BITS = 2; + localparam NEXT = UBR_BITS'(2'b00); + localparam FINISH = UBR_BITS'(2'b01); // uop metadata (sequencing, next state), execution metadata (EX_TYPE, OP_TYPE, OP_MOD), wb, use pc, use imm, pc, imm, rd, rs1, rs2, rs3 localparam UOP_TABLE_WIDTH = UBR_BITS + UPC_BITS + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + 1 + 1 + `XLEN + `XLEN + (`NR_BITS * 4); @@ -122,7 +121,17 @@ module VX_uop_sequencer import VX_gpu_pkg::*; ( // passthrough when !use_uop assign ibuffer_if.valid = use_uop ? 1'b1 : uop_sequencer_if.valid; assign uop_sequencer_if.ready = use_uop ? (uop_fire && ubr == FINISH) : ibuffer_if.ready; - assign ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data; + + always @(*) begin + ibuffer_if.data = use_uop ? ibuffer_output : uop_sequencer_if.data; + + if (uop_sequencer_if.valid && use_uop && + uop_sequencer_if.data.rd == `NR_BITS'(1)) begin + // a little sketchy? but shouldn't create any loop + ibuffer_if.data.rd = ibuffer_if.data.rd + `NR_BITS'(8); // FIXME: 8 is hardcoded + ibuffer_if.data.rs3 = ibuffer_if.data.rs3 + `NR_BITS'(8); + end + end always @(posedge clk) begin if (uop_start) begin diff --git a/hw/rtl/fpu/VX_tensor_dpu.sv b/hw/rtl/fpu/VX_tensor_dpu.sv index cfc5f507..8b7a1c26 100644 --- a/hw/rtl/fpu/VX_tensor_dpu.sv +++ b/hw/rtl/fpu/VX_tensor_dpu.sv @@ -3,44 +3,274 @@ module VX_tensor_dpu #( parameter ISW, - parameter OCTET + parameter OCTET, + // @perf: has big impact on throughput. A rule of thumb is to set it to + // the pipeline length of FEDPs in order to make sure there are enough + // entries to fully saturate the pipeline, but this is still rough + parameter ISSUE_QUEUE_DEPTH = `LATENCY_HMMA ) ( input clk, input reset, - input stall, - input valid_in, + output ready_in, input [3:0][1:0][31:0] A_tile, input [1:0][3:0][31:0] B_tile, input [3:0][3:0][31:0] C_tile, + input [`NW_WIDTH-1:0] wid, output valid_out, - output [3:0][3:0][31:0] D_tile + input ready_out, + output [3:0][3:0][31:0] D_tile, + output [`NW_WIDTH-1:0] D_wid ); - logic [3:0][3:0][31:0] result_hmma; + // logic [3:0][3:0][31:0] result_hmma; + // always @(*) begin + // dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); + // end + + // logic ready_reg; + // always @(posedge clk) begin + // if (reset) begin + // ready_reg <= '1; + // end else if (valid_in && ready_in) begin + // ready_reg <= '0; + // dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma); + // end else if (valid_out && ready_out) begin + // ready_reg <= '1; + // end + // end + + // // fixed-latency queue + // VX_shift_register #( + // .DATAW (1 + $bits(wid)/* + $bits(D_tile)*/), + // .DEPTH (`LATENCY_HMMA + 1), + // .RESETW (1) + // ) shift_reg ( + // .clk (clk), + // .reset (reset), + // .enable (ready_out), + // .data_in ({valid_in && ready_in, wid /*, result_hmma*/}), + // .data_out ({valid_out, D_wid/*, D_tile */}) + // ); + + // ready as soon as valid_out + // assign ready_in = ready_reg || valid_out; + + // fully pipelined; ready_in is coupled to ready_out by immediately + // stalling + // assign ready_in = ready_out; + + logic synced_fire; + assign synced_fire = valid_in && ready_in; + + logic [1:0] threadgroup_valids; + logic [1:0] threadgroup_readys; + // B_tile is shared across the two threadgroups; see Figure 13 + VX_tensor_threadgroup #( + .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) + ) threadgroup_0 ( + .clk (clk), + .reset (reset), + .valid_in (synced_fire), + .ready_in (threadgroup_readys[0]), + .stall (!ready_out), + .A_frag (A_tile[1:0]), + .B_frag (B_tile), + .C_frag (C_tile[1:0]), + .valid_out (threadgroup_valids[0]), + .D_frag (D_tile[1:0]) + ); + VX_tensor_threadgroup #( + .ISSUE_QUEUE_DEPTH(ISSUE_QUEUE_DEPTH) + ) threadgroup_1 ( + .clk (clk), + .reset (reset), + .valid_in (synced_fire), + .ready_in (threadgroup_readys[1]), + .stall (!ready_out), + .A_frag (A_tile[3:2]), + .B_frag (B_tile), + .C_frag (C_tile[3:2]), + .valid_out (threadgroup_valids[1]), + .D_frag (D_tile[3:2]) + ); + + wire empty; + wire full; + wire enq = valid_in && ready_in; + wire deq = valid_out && ready_out; + + assign ready_in = &(threadgroup_readys) && !full; + assign valid_out = &(threadgroup_valids); + + // need to pass along warp id's to do multithreading + VX_fifo_queue #( + .DATAW ($bits(wid)), + // @perf: seems to require deeper depth than the FEDP issue queues to + // not cause stalls. + .DEPTH (2 * ISSUE_QUEUE_DEPTH) + ) wid_queue ( + .clk (clk), + .reset (reset), + .push (enq), + .pop (deq), + .data_in (wid), + .data_out (D_wid), + .empty (empty), + `UNUSED_PIN(alm_empty), + .full (full), + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + `RUNTIME_ASSERT(reset || !(deq && empty), + ("dequeueing from empty warp id queue!")) +endmodule + +// does (m,n,k) = (2,4,2) matmul compute over 2 cycles. +// matches Figure 10(b) of the paper. +module VX_tensor_threadgroup #( + parameter ISSUE_QUEUE_DEPTH +) ( + input clk, + input reset, + + input valid_in, + output ready_in, + input stall, + input [1:0][1:0][31:0] A_frag, + input [1:0][3:0][31:0] B_frag, + input [1:0][3:0][31:0] C_frag, + + output valid_out, + output [1:0][3:0][31:0] D_frag +); + wire [1:0][1:0][31:0] A_frag_buf; + wire [1:0][3:0][31:0] B_frag_buf; + wire [1:0][3:0][31:0] C_frag_buf; + + wire valid_buf; + wire ready_buf; + + wire enq = valid_in && ready_in; + wire deq = valid_buf && ready_buf; + wire empty; + wire full; + assign ready_in = !full; + assign valid_buf = !empty; + + // 'Issue queue' for the FEDP units. + // This exists to decouple the execution of the dot-product unit from + // the operand arrival. Operands from execute_if can arrive + // intermittently according to the frontend's behavior, and since the dpu + // can also stall for a fixed initiation latency, we need to decouple the + // two to efficiently feed the dpu. + // + // TODO: better queue design possible; e.g. B_frag is shared by two + // threadgroups, so we need only 1 queue per octet for B + VX_fifo_queue #( + .DATAW ($bits(A_frag) + $bits(B_frag) + $bits(C_frag)), + .DEPTH (ISSUE_QUEUE_DEPTH) + ) input_buffer ( + .clk (clk), + .reset (reset), + .push (enq), + .pop (deq), + .data_in ({A_frag, B_frag, C_frag}), + .data_out ({A_frag_buf, B_frag_buf, C_frag_buf}), + .empty (empty), + `UNUSED_PIN(alm_empty), + .full (full), + `UNUSED_PIN(alm_full), + `UNUSED_PIN(size) + ); + + logic [3:0] fedp_valids; + wire fedp_valid_out = &(fedp_valids); + wire fedp_ready_out = !stall; + wire fedp_fire_out = fedp_valid_out && fedp_ready_out; + + wire fedp_valid_in = valid_buf; + wire fedp_ready_in = fedp_ready_out; // coupled + wire fedp_fire_in = fedp_valid_in && fedp_ready_in; + + // 0: FEDP uses first half from input_buffer + // 1: FEDP uses last half and pops input_buffer + logic step_in; + // 0: FEDP produces first half of D_frag + // 1: FEDP produces last half of D_frag and asserts valid_out + logic step_out; + assign ready_buf = fedp_fire_in && (step_in == 1'b1); + + // latch the first-half result of D_frag + logic [3:0][31:0] D_reg, D_reg_n; + wire [3:0][31:0] D_half; always @(*) begin - dpi_hmma(valid_in, A_tile, B_tile, C_tile, result_hmma); + D_reg_n = D_reg; + if (fedp_fire_out) begin + if (step_out == 1'b0) begin + D_reg_n = D_half; + end + end end always @(posedge clk) begin - if (~reset && valid_in) begin - dpi_print_results(int'(ISW), int'(OCTET), A_tile, B_tile, C_tile, result_hmma); + if (reset) begin + step_in <= '0; + step_out <= '0; + + D_reg <= '0; + end else begin + if (fedp_fire_in) begin + step_in <= ~step_in; + end + if (fedp_fire_out) begin + step_out <= ~step_out; + end + + D_reg <= D_reg_n; end end - - VX_shift_register #( - .DATAW (1 + $bits(D_tile)), - .DEPTH (`LATENCY_HMMA), - .RESETW (1) - ) shift_reg ( - .clk (clk), - .reset (reset), - .enable (~stall), - .data_in ({valid_in, result_hmma}), - .data_out ({valid_out, D_tile}) - ); + // TODO: Instead of latching half-result and constructing a full D tile, + // we should be able to send these half fragments down to commit stage + // immediately, saving flop space + assign D_frag[0][0] = D_reg[0]; + assign D_frag[0][2] = D_reg[1]; + assign D_frag[1][0] = D_reg[2]; + assign D_frag[1][2] = D_reg[3]; + assign D_frag[0][1] = D_half[0]; + assign D_frag[0][3] = D_half[1]; + assign D_frag[1][1] = D_half[2]; + assign D_frag[1][3] = D_half[3]; + + // 4 FEDPs per threadgroup + for (genvar i = 0; i < 4; ++i) begin + localparam int d_row = i / 2; + localparam int d_col = (i % 2) * 2; + // four-element dot product (FEDP) unit + TensorDotProductUnit fedp ( + .clock (clk), + .reset (reset), + .io_in_valid (fedp_fire_in), + .io_in_bits_a_0 (A_frag_buf[d_row][0]), + .io_in_bits_a_1 (A_frag_buf[d_row][1]), + .io_in_bits_a_2 (32'h0), + .io_in_bits_a_3 (32'h0), + .io_in_bits_b_0 (step_in == 1'b0 ? B_frag_buf[0][d_col] : B_frag_buf[0][d_col + 1]), + .io_in_bits_b_1 (step_in == 1'b0 ? B_frag_buf[1][d_col] : B_frag_buf[1][d_col + 1]), + .io_in_bits_b_2 (32'h0), + .io_in_bits_b_3 (32'h0), + .io_in_bits_c (step_in == 1'b0 ? C_frag_buf[d_row][d_col] : C_frag_buf[d_row][d_col + 1]), + .io_stall (stall), + .io_out_valid (fedp_valids[i]), + .io_out_bits_data (D_half[i]) + ); + end + + assign valid_out = fedp_valid_out && (step_out == 1'b1); endmodule + `endif