diff --git a/hw/rtl/core/VX_tensor_hopper_core.sv b/hw/rtl/core/VX_tensor_hopper_core.sv index c6f9d4dd..e50c0bc9 100644 --- a/hw/rtl/core/VX_tensor_hopper_core.sv +++ b/hw/rtl/core/VX_tensor_hopper_core.sv @@ -25,13 +25,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( - wb - rd */ - wire [`UUID_WIDTH-1:0] execute_if_data_uuid; - wire [`NW_WIDTH-1:0] execute_if_data_wid; - wire [NUM_LANES-1:0] execute_if_data_tmask; - wire [`INST_ALU_BITS-1:0] execute_if_data_op_type; - wire [`XLEN-1:0] execute_if_data_PC; - wire execute_if_data_wb; - wire [`NR_BITS-1:0] execute_if_data_rd; + wire [`UUID_WIDTH-1:0] execute_if_data_uuid; + wire [`NW_WIDTH-1:0] execute_if_data_wid; + wire [NUM_LANES-1:0] execute_if_data_tmask; + wire [`INST_ALU_BITS-1:0] execute_if_data_op_type; + wire [`XLEN-1:0] execute_if_data_PC; + wire execute_if_data_wb; + wire [`NR_BITS-1:0] execute_if_data_rd; + wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs1; + wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs2; wire metadata_queue_full; wire metadata_queue_empty; @@ -52,7 +54,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( wire enq = operand_enq_fire; wire deq = metadata_deq; - localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 + `NR_BITS; + localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 + + `NR_BITS + (NUM_LANES * `XLEN) + (NUM_LANES * `XLEN); VX_fifo_queue #( .DATAW(DATAW), .DEPTH(METADATA_QUEUE_DEPTH) @@ -63,10 +66,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( .pop(deq), .data_in({execute_if.data.uuid, execute_if.data.wid, execute_if.data.tmask, execute_if.data.op_type, execute_if.data.PC, - execute_if.data.wb, execute_if.data.rd}), + execute_if.data.wb, execute_if.data.rd, + execute_if.data.rs1_data, execute_if.data.rs2_data}), .data_out({execute_if_data_uuid, execute_if_data_wid, execute_if_data_tmask, execute_if_data_op_type, execute_if_data_PC, - execute_if_data_wb, execute_if_data_rd}), + execute_if_data_wb, execute_if_data_rd, + execute_if_data_rs1, execute_if_data_rs2}), .empty(metadata_queue_empty), `UNUSED_PIN(alm_empty), .full(metadata_queue_full), @@ -94,6 +99,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( // commit wire initiate_valid = metadata_valid && commit_if.ready && !hmma_wait; wire [`NW_WIDTH-1:0] initiate_wid = execute_if_data_wid; + wire [`XLEN-1:0] initiate_addr_a = execute_if_data_rs1[0]; + wire [`XLEN-1:0] initiate_addr_b = execute_if_data_rs2[0]; + `RUNTIME_ASSERT(!metadata_valid || execute_if_data_tmask[0], + ("tmask for HGMMA instruction is invalid")) // we're recycling execute_if.op_type as operands_if.op_type which might // have a different width; let's be safe @@ -107,17 +116,17 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( // /* // fake fsm driving tc rf port - reg [11:0] counter; - always @(posedge clk) begin - if (reset) begin - counter <= 12'd1; - end else begin - counter <= counter + 12'd1; - end - end - assign regfile_if.req_valid = (counter[3:0] != 4'd0); - assign regfile_if.req_data.wis = '0; - assign regfile_if.req_data.rs = counter[11:7]; + // reg [11:0] counter; + // always @(posedge clk) begin + // if (reset) begin + // counter <= 12'd1; + // end else begin + // counter <= counter + 12'd1; + // end + // end + // assign regfile_if.req_valid = (counter[6:0] == 7'd0); + // assign regfile_if.req_data.wis = '0; + // assign regfile_if.req_data.rs = counter[11:7]; // */ TensorCoreDecoupled tensor_hopper_core ( @@ -127,6 +136,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( .io_initiate_ready(initiate_ready), .io_initiate_valid(initiate_valid), .io_initiate_bits_wid(initiate_wid), + .io_initiate_bits_addressA(initiate_addr_a), + .io_initiate_bits_addressB(initiate_addr_b), .io_writeback_ready(writeback_ready), .io_writeback_valid(writeback_valid), @@ -150,6 +161,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( .io_respB_valid(smem_B_if.rsp_valid), .io_respB_bits_source(smem_B_if.rsp_data.tag), .io_respB_bits_data(smem_B_if.rsp_data.data), + .io_respC(regfile_if.rsp_data.data), .io_reqA_ready(smem_A_if.req_ready), .io_reqA_valid(smem_A_if.req_valid), @@ -158,8 +170,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #( .io_reqB_ready(smem_B_if.req_ready), .io_reqB_valid(smem_B_if.req_valid), .io_reqB_bits_source(smem_B_if.req_data.tag), - .io_reqB_bits_address(smem_B_if.req_data.addr) + .io_reqB_bits_address(smem_B_if.req_data.addr), + .io_reqC_valid(regfile_if.req_valid), + .io_reqC_bits(regfile_if.req_data.rs[4:0]) ); + // add offset of 32 for fp regs + assign regfile_if.req_data.rs[5] = 1'b1; + assign regfile_if.req_data.wis = '0; + `STATIC_ASSERT((`ISSUE_WIDTH == `NUM_WARPS), + ("static assertion failed: tensor_hopper_core assumes ISSUE_WIDTH == NUM_WARPS")) // VX_tensor_hopper_core #( // ) tensor_hopper_core (