Implement WU architecture support

This commit is contained in:
2026-05-25 19:25:05 +08:00
parent 323ed7d7e9
commit 0ad87bde81
35 changed files with 3303 additions and 472 deletions

View File

@@ -7,6 +7,7 @@ module Vortex import VX_gpu_pkg::*; #(
parameter TENSOR_FP16 = 0,
parameter logic [63:0] STARTUP_ADDR = 64'h0000_0000_0001_0100,
parameter NUM_THREADS = 0,
parameter NUM_TENSOR_CORES = 1,
parameter TC_DATA_WIDTH = 256,
parameter TC_TAG_WIDTH = 4
) (
@@ -77,26 +78,32 @@ module Vortex import VX_gpu_pkg::*; #(
output [(DCACHE_NUM_REQS * 32) - 1:0] smem_a_bits_data,
// tc --------------------------------------------------
input [2:0] tc_a_ready,
output [2:0] tc_a_valid,
output [2:0] tc_a_bits_write,
output [95:0] tc_a_bits_address,
output [3 * TC_TAG_WIDTH - 1:0] tc_a_bits_tag,
output [3 * 32 - 1:0] tc_a_bits_mask,
output [3 * TC_DATA_WIDTH - 1:0] tc_a_bits_data,
output [2:0] tc_d_ready,
input [2:0] tc_d_valid,
input [3 * TC_DATA_WIDTH - 1:0] tc_d_bits_data,
input [3 * TC_TAG_WIDTH - 1:0] tc_d_bits_tag,
input [NUM_TENSOR_CORES * 3 - 1:0] tc_a_ready,
output [NUM_TENSOR_CORES * 3 - 1:0] tc_a_valid,
output [NUM_TENSOR_CORES * 3 - 1:0] tc_a_bits_write,
output [NUM_TENSOR_CORES * 3 * 32 - 1:0] tc_a_bits_address,
output [NUM_TENSOR_CORES * 3 * TC_TAG_WIDTH - 1:0] tc_a_bits_tag,
output [NUM_TENSOR_CORES * 3 * 32 - 1:0] tc_a_bits_mask,
output [NUM_TENSOR_CORES * 3 * TC_DATA_WIDTH - 1:0] tc_a_bits_data,
output [NUM_TENSOR_CORES * 3 - 1:0] tc_d_ready,
input [NUM_TENSOR_CORES * 3 - 1:0] tc_d_valid,
input [NUM_TENSOR_CORES * 3 * TC_DATA_WIDTH - 1:0] tc_d_bits_data,
input [NUM_TENSOR_CORES * 3 * TC_TAG_WIDTH - 1:0] tc_d_bits_tag,
// tmem_C direct SRAM port
output tc_tmem_C_wen,
output tc_tmem_C_ren,
output [8:0] tc_tmem_C_waddr,
output [8:0] tc_tmem_C_raddr,
output [`NUM_THREADS*`XLEN-1:0] tc_tmem_C_wdata,
output [`NUM_THREADS*`XLEN/8-1:0] tc_tmem_C_mask,
input [`NUM_THREADS*`XLEN-1:0] tc_tmem_C_rdata,
// shared tmem direct SRAM ports
output [NUM_TENSOR_CORES-1:0] tc_tmem_A_ren,
input [NUM_TENSOR_CORES-1:0] tc_tmem_A_rready,
output [NUM_TENSOR_CORES*9-1:0] tc_tmem_A_raddr,
input [NUM_TENSOR_CORES*`NUM_THREADS*`XLEN-1:0] tc_tmem_A_rdata,
output [NUM_TENSOR_CORES-1:0] tc_tmem_C_ren,
input [NUM_TENSOR_CORES-1:0] tc_tmem_C_rready,
output [NUM_TENSOR_CORES*9-1:0] tc_tmem_C_raddr,
input [NUM_TENSOR_CORES*`NUM_THREADS*`XLEN-1:0] tc_tmem_C_rdata,
output [NUM_TENSOR_CORES-1:0] tc_tmem_C_wen,
input [NUM_TENSOR_CORES-1:0] tc_tmem_C_wready,
output [NUM_TENSOR_CORES*9-1:0] tc_tmem_C_waddr,
output [NUM_TENSOR_CORES*`NUM_THREADS*`XLEN-1:0] tc_tmem_C_wdata,
output [NUM_TENSOR_CORES*`NUM_THREADS*`XLEN/8-1:0] tc_tmem_C_mask,
// gbar ------------------------------------------------
@@ -314,24 +321,52 @@ module Vortex import VX_gpu_pkg::*; #(
endgenerate
// tc ---------------------------------------------------------------------
VX_tc_bus_if #(.TAG_WIDTH(TC_TAG_WIDTH)) tc_p0_bus_if();
VX_tc_bus_if #(.TAG_WIDTH(TC_TAG_WIDTH)) tc_p2_bus_if();
// tc_p1 (tmem_C) is now a direct SRAM port exposed as top-level ports tc_tmem_C_*
assign tc_a_valid = {tc_p2_bus_if.req_valid, 1'b0, tc_p0_bus_if.req_valid};
assign tc_a_bits_write = {tc_p2_bus_if.req_data.rw, 1'b0, tc_p0_bus_if.req_data.rw};
assign tc_a_bits_address = {tc_p2_bus_if.req_data.addr, 32'b0, tc_p0_bus_if.req_data.addr};
assign tc_a_bits_tag = {tc_p2_bus_if.req_data.tag, {TC_TAG_WIDTH{1'b0}}, tc_p0_bus_if.req_data.tag};
assign tc_a_bits_mask = {tc_p2_bus_if.req_data.byteen, {(TC_DATA_WIDTH/8){1'b0}},tc_p0_bus_if.req_data.byteen};
assign tc_a_bits_data = {tc_p2_bus_if.req_data.data, {TC_DATA_WIDTH{1'b0}}, tc_p0_bus_if.req_data.data};
assign tc_p0_bus_if.req_ready = tc_a_ready[0];
assign tc_p0_bus_if.rsp_valid = tc_d_valid[0];
assign tc_p0_bus_if.rsp_data.data = tc_d_bits_data[0 * TC_DATA_WIDTH +: TC_DATA_WIDTH];
assign tc_p0_bus_if.rsp_data.tag = tc_d_bits_tag[0 * TC_TAG_WIDTH +: TC_TAG_WIDTH];
assign tc_p2_bus_if.req_ready = tc_a_ready[2];
assign tc_p2_bus_if.rsp_valid = tc_d_valid[2];
assign tc_p2_bus_if.rsp_data.data = tc_d_bits_data[2 * TC_DATA_WIDTH +: TC_DATA_WIDTH];
assign tc_p2_bus_if.rsp_data.tag = tc_d_bits_tag[2 * TC_TAG_WIDTH +: TC_TAG_WIDTH];
assign tc_d_ready = {tc_p2_bus_if.rsp_ready, 1'b0, tc_p0_bus_if.rsp_ready};
VX_tc_bus_if #(.TAG_WIDTH(TC_TAG_WIDTH)) tc_p0_bus_if[NUM_TENSOR_CORES]();
VX_tc_bus_if #(.TAG_WIDTH(TC_TAG_WIDTH)) tc_p2_bus_if[NUM_TENSOR_CORES]();
// Flatten the per-tensor-core TC bus interfaces onto the packed top-level
// tc_a_* (request) and tc_d_* (response) vectors.
//
// Every tensor core owns three consecutive slots in those vectors:
//   slot 0 -> tc_p0_bus_if[tc]  (bound to the core's tensor_smem_A_if)
//   slot 1 -> unused: request outputs are driven to zero and the response
//             ready is held low, so nothing is ever issued or accepted there
//   slot 2 -> tc_p2_bus_if[tc]  (bound to the core's tensor_smem_B_if)
//
// Address and byte-mask lanes are 32 bits per slot, matching the widths of the
// tc_a_bits_address / tc_a_bits_mask port declarations; tag and data lanes are
// TC_TAG_WIDTH and TC_DATA_WIDTH bits per slot respectively.
for (genvar tc = 0; tc < NUM_TENSOR_CORES; ++tc) begin : g_tc_ports
    // Slot indices of this tensor core inside the flattened vectors.
    localparam SLOT_BASE = 3 * tc;
    localparam SLOT_P0   = SLOT_BASE + 0;
    localparam SLOT_P1   = SLOT_BASE + 1;
    localparam SLOT_P2   = SLOT_BASE + 2;

    // ---- slot 0: request channel (interface -> top-level outputs) ----------
    assign tc_a_valid       [SLOT_P0]                                  = tc_p0_bus_if[tc].req_valid;
    assign tc_a_bits_write  [SLOT_P0]                                  = tc_p0_bus_if[tc].req_data.rw;
    assign tc_a_bits_address[SLOT_P0 * 32 +: 32]                       = tc_p0_bus_if[tc].req_data.addr;
    assign tc_a_bits_tag    [SLOT_P0 * TC_TAG_WIDTH +: TC_TAG_WIDTH]   = tc_p0_bus_if[tc].req_data.tag;
    assign tc_a_bits_mask   [SLOT_P0 * 32 +: 32]                       = tc_p0_bus_if[tc].req_data.byteen;
    assign tc_a_bits_data   [SLOT_P0 * TC_DATA_WIDTH +: TC_DATA_WIDTH] = tc_p0_bus_if[tc].req_data.data;
    assign tc_p0_bus_if[tc].req_ready                                  = tc_a_ready[SLOT_P0];

    // ---- slot 0: response channel (top-level inputs -> interface) ----------
    assign tc_p0_bus_if[tc].rsp_valid     = tc_d_valid[SLOT_P0];
    assign tc_p0_bus_if[tc].rsp_data.data = tc_d_bits_data[SLOT_P0 * TC_DATA_WIDTH +: TC_DATA_WIDTH];
    assign tc_p0_bus_if[tc].rsp_data.tag  = tc_d_bits_tag[SLOT_P0 * TC_TAG_WIDTH +: TC_TAG_WIDTH];
    assign tc_d_ready[SLOT_P0]            = tc_p0_bus_if[tc].rsp_ready;

    // ---- slot 1: tied off ---------------------------------------------------
    assign tc_a_valid       [SLOT_P1]                                  = 1'b0;
    assign tc_a_bits_write  [SLOT_P1]                                  = 1'b0;
    assign tc_a_bits_address[SLOT_P1 * 32 +: 32]                       = 32'b0;
    assign tc_a_bits_tag    [SLOT_P1 * TC_TAG_WIDTH +: TC_TAG_WIDTH]   = '0;
    assign tc_a_bits_mask   [SLOT_P1 * 32 +: 32]                       = '0;
    assign tc_a_bits_data   [SLOT_P1 * TC_DATA_WIDTH +: TC_DATA_WIDTH] = '0;
    assign tc_d_ready       [SLOT_P1]                                  = 1'b0;

    // ---- slot 2: request channel (interface -> top-level outputs) ----------
    assign tc_a_valid       [SLOT_P2]                                  = tc_p2_bus_if[tc].req_valid;
    assign tc_a_bits_write  [SLOT_P2]                                  = tc_p2_bus_if[tc].req_data.rw;
    assign tc_a_bits_address[SLOT_P2 * 32 +: 32]                       = tc_p2_bus_if[tc].req_data.addr;
    assign tc_a_bits_tag    [SLOT_P2 * TC_TAG_WIDTH +: TC_TAG_WIDTH]   = tc_p2_bus_if[tc].req_data.tag;
    assign tc_a_bits_mask   [SLOT_P2 * 32 +: 32]                       = tc_p2_bus_if[tc].req_data.byteen;
    assign tc_a_bits_data   [SLOT_P2 * TC_DATA_WIDTH +: TC_DATA_WIDTH] = tc_p2_bus_if[tc].req_data.data;
    assign tc_p2_bus_if[tc].req_ready                                  = tc_a_ready[SLOT_P2];

    // ---- slot 2: response channel (top-level inputs -> interface) ----------
    assign tc_p2_bus_if[tc].rsp_valid     = tc_d_valid[SLOT_P2];
    assign tc_p2_bus_if[tc].rsp_data.data = tc_d_bits_data[SLOT_P2 * TC_DATA_WIDTH +: TC_DATA_WIDTH];
    assign tc_p2_bus_if[tc].rsp_data.tag  = tc_d_bits_tag[SLOT_P2 * TC_TAG_WIDTH +: TC_TAG_WIDTH];
    assign tc_d_ready[SLOT_P2]            = tc_p2_bus_if[tc].rsp_ready;
end
// gbar -------------------------------------------------------------------
`ifdef GBAR_ENABLE
@@ -439,7 +474,8 @@ module Vortex import VX_gpu_pkg::*; #(
// TODO: SCOPE_IO_BIND should be socket id
VX_core #(
.CORE_ID (CORE_ID),
.TENSOR_FP16 (TENSOR_FP16)
.TENSOR_FP16 (TENSOR_FP16),
.NUM_TENSOR_CORES (NUM_TENSOR_CORES)
) core (
`SCOPE_IO_BIND (0)
@@ -465,22 +501,34 @@ module Vortex import VX_gpu_pkg::*; #(
.tensor_smem_A_if (tc_p0_bus_if),
`ifdef EXT_T_BLACKWELL
.tensor_tmem_C_wen(tc_tmem_C_wen),
.tensor_tmem_A_ren(tc_tmem_A_ren),
.tensor_tmem_A_rready(tc_tmem_A_rready),
.tensor_tmem_A_raddr(tc_tmem_A_raddr),
.tensor_tmem_A_rdata(tc_tmem_A_rdata),
.tensor_tmem_C_ren(tc_tmem_C_ren),
.tensor_tmem_C_waddr(tc_tmem_C_waddr),
.tensor_tmem_C_rready(tc_tmem_C_rready),
.tensor_tmem_C_raddr(tc_tmem_C_raddr),
.tensor_tmem_C_rdata(tc_tmem_C_rdata),
.tensor_tmem_C_wen(tc_tmem_C_wen),
.tensor_tmem_C_wready(tc_tmem_C_wready),
.tensor_tmem_C_waddr(tc_tmem_C_waddr),
.tensor_tmem_C_wdata(tc_tmem_C_wdata),
.tensor_tmem_C_mask(tc_tmem_C_mask),
.tensor_tmem_C_rdata(tc_tmem_C_rdata),
.tensor_smem_B_if (tc_p2_bus_if),
`else
.tensor_tmem_C_wen(tc_tmem_C_wen),
.tensor_tmem_A_ren(tc_tmem_A_ren),
.tensor_tmem_A_rready(tc_tmem_A_rready),
.tensor_tmem_A_raddr(tc_tmem_A_raddr),
.tensor_tmem_A_rdata(tc_tmem_A_rdata),
.tensor_tmem_C_ren(tc_tmem_C_ren),
.tensor_tmem_C_waddr(tc_tmem_C_waddr),
.tensor_tmem_C_rready(tc_tmem_C_rready),
.tensor_tmem_C_raddr(tc_tmem_C_raddr),
.tensor_tmem_C_rdata(tc_tmem_C_rdata),
.tensor_tmem_C_wen(tc_tmem_C_wen),
.tensor_tmem_C_wready(tc_tmem_C_wready),
.tensor_tmem_C_waddr(tc_tmem_C_waddr),
.tensor_tmem_C_wdata(tc_tmem_C_wdata),
.tensor_tmem_C_mask(tc_tmem_C_mask),
.tensor_tmem_C_rdata(tc_tmem_C_rdata),
.tensor_smem_B_if (tc_p2_bus_if),
`endif
@@ -583,7 +631,7 @@ module Vortex import VX_gpu_pkg::*; #(
$display("simulation has probably ended. exiting");
$finish();
end
if (finished) begin
if (busy_prev && !busy) begin
$display("---------------- core%2d has no more active warps ----------------", CORE_ID);
// TODO: lane assumed to be 4
// `ifndef SYNTHESIS