tensor: Connect SMEM addr/rf IO

This commit is contained in:
Hansung Kim
2024-10-28 19:42:02 -07:00
parent 4376bd33a2
commit 8a66b5ed89

View File

@@ -25,13 +25,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
- wb - wb
- rd - rd
*/ */
wire [`UUID_WIDTH-1:0] execute_if_data_uuid; wire [`UUID_WIDTH-1:0] execute_if_data_uuid;
wire [`NW_WIDTH-1:0] execute_if_data_wid; wire [`NW_WIDTH-1:0] execute_if_data_wid;
wire [NUM_LANES-1:0] execute_if_data_tmask; wire [NUM_LANES-1:0] execute_if_data_tmask;
wire [`INST_ALU_BITS-1:0] execute_if_data_op_type; wire [`INST_ALU_BITS-1:0] execute_if_data_op_type;
wire [`XLEN-1:0] execute_if_data_PC; wire [`XLEN-1:0] execute_if_data_PC;
wire execute_if_data_wb; wire execute_if_data_wb;
wire [`NR_BITS-1:0] execute_if_data_rd; wire [`NR_BITS-1:0] execute_if_data_rd;
wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs1;
wire [NUM_LANES-1:0][`XLEN-1:0] execute_if_data_rs2;
wire metadata_queue_full; wire metadata_queue_full;
wire metadata_queue_empty; wire metadata_queue_empty;
@@ -52,7 +54,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
wire enq = operand_enq_fire; wire enq = operand_enq_fire;
wire deq = metadata_deq; wire deq = metadata_deq;
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 + `NR_BITS; localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `INST_ALU_BITS + `XLEN + 1 +
`NR_BITS + (NUM_LANES * `XLEN) + (NUM_LANES * `XLEN);
VX_fifo_queue #( VX_fifo_queue #(
.DATAW(DATAW), .DATAW(DATAW),
.DEPTH(METADATA_QUEUE_DEPTH) .DEPTH(METADATA_QUEUE_DEPTH)
@@ -63,10 +66,12 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.pop(deq), .pop(deq),
.data_in({execute_if.data.uuid, execute_if.data.wid, .data_in({execute_if.data.uuid, execute_if.data.wid,
execute_if.data.tmask, execute_if.data.op_type, execute_if.data.PC, execute_if.data.tmask, execute_if.data.op_type, execute_if.data.PC,
execute_if.data.wb, execute_if.data.rd}), execute_if.data.wb, execute_if.data.rd,
execute_if.data.rs1_data, execute_if.data.rs2_data}),
.data_out({execute_if_data_uuid, execute_if_data_wid, .data_out({execute_if_data_uuid, execute_if_data_wid,
execute_if_data_tmask, execute_if_data_op_type, execute_if_data_PC, execute_if_data_tmask, execute_if_data_op_type, execute_if_data_PC,
execute_if_data_wb, execute_if_data_rd}), execute_if_data_wb, execute_if_data_rd,
execute_if_data_rs1, execute_if_data_rs2}),
.empty(metadata_queue_empty), .empty(metadata_queue_empty),
`UNUSED_PIN(alm_empty), `UNUSED_PIN(alm_empty),
.full(metadata_queue_full), .full(metadata_queue_full),
@@ -94,6 +99,10 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// commit // commit
wire initiate_valid = metadata_valid && commit_if.ready && !hmma_wait; wire initiate_valid = metadata_valid && commit_if.ready && !hmma_wait;
wire [`NW_WIDTH-1:0] initiate_wid = execute_if_data_wid; wire [`NW_WIDTH-1:0] initiate_wid = execute_if_data_wid;
wire [`XLEN-1:0] initiate_addr_a = execute_if_data_rs1[0];
wire [`XLEN-1:0] initiate_addr_b = execute_if_data_rs2[0];
`RUNTIME_ASSERT(!metadata_valid || execute_if_data_tmask[0],
("tmask for HGMMA instruction is invalid"))
// we're recycling execute_if.op_type as operands_if.op_type which might // we're recycling execute_if.op_type as operands_if.op_type which might
// have a different width; let's be safe // have a different width; let's be safe
@@ -107,17 +116,17 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
// /* // /*
// fake fsm driving tc rf port // fake fsm driving tc rf port
reg [11:0] counter; // reg [11:0] counter;
always @(posedge clk) begin // always @(posedge clk) begin
if (reset) begin // if (reset) begin
counter <= 12'd1; // counter <= 12'd1;
end else begin // end else begin
counter <= counter + 12'd1; // counter <= counter + 12'd1;
end // end
end // end
assign regfile_if.req_valid = (counter[3:0] != 4'd0); // assign regfile_if.req_valid = (counter[6:0] == 7'd0);
assign regfile_if.req_data.wis = '0; // assign regfile_if.req_data.wis = '0;
assign regfile_if.req_data.rs = counter[11:7]; // assign regfile_if.req_data.rs = counter[11:7];
// */ // */
TensorCoreDecoupled tensor_hopper_core ( TensorCoreDecoupled tensor_hopper_core (
@@ -127,6 +136,8 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_initiate_ready(initiate_ready), .io_initiate_ready(initiate_ready),
.io_initiate_valid(initiate_valid), .io_initiate_valid(initiate_valid),
.io_initiate_bits_wid(initiate_wid), .io_initiate_bits_wid(initiate_wid),
.io_initiate_bits_addressA(initiate_addr_a),
.io_initiate_bits_addressB(initiate_addr_b),
.io_writeback_ready(writeback_ready), .io_writeback_ready(writeback_ready),
.io_writeback_valid(writeback_valid), .io_writeback_valid(writeback_valid),
@@ -150,6 +161,7 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_respB_valid(smem_B_if.rsp_valid), .io_respB_valid(smem_B_if.rsp_valid),
.io_respB_bits_source(smem_B_if.rsp_data.tag), .io_respB_bits_source(smem_B_if.rsp_data.tag),
.io_respB_bits_data(smem_B_if.rsp_data.data), .io_respB_bits_data(smem_B_if.rsp_data.data),
.io_respC(regfile_if.rsp_data.data),
.io_reqA_ready(smem_A_if.req_ready), .io_reqA_ready(smem_A_if.req_ready),
.io_reqA_valid(smem_A_if.req_valid), .io_reqA_valid(smem_A_if.req_valid),
@@ -158,8 +170,15 @@ module VX_tensor_hopper_core_block import VX_gpu_pkg::*; #(
.io_reqB_ready(smem_B_if.req_ready), .io_reqB_ready(smem_B_if.req_ready),
.io_reqB_valid(smem_B_if.req_valid), .io_reqB_valid(smem_B_if.req_valid),
.io_reqB_bits_source(smem_B_if.req_data.tag), .io_reqB_bits_source(smem_B_if.req_data.tag),
.io_reqB_bits_address(smem_B_if.req_data.addr) .io_reqB_bits_address(smem_B_if.req_data.addr),
.io_reqC_valid(regfile_if.req_valid),
.io_reqC_bits(regfile_if.req_data.rs[4:0])
); );
// add offset of 32 for fp regs
assign regfile_if.req_data.rs[5] = 1'b1;
assign regfile_if.req_data.wis = '0;
`STATIC_ASSERT((`ISSUE_WIDTH == `NUM_WARPS),
("static assertion failed: tensor_hopper_core assumes ISSUE_WIDTH == NUM_WARPS"))
// VX_tensor_hopper_core #( // VX_tensor_hopper_core #(
// ) tensor_hopper_core ( // ) tensor_hopper_core (