Implement WU architecture support

This commit is contained in:
2026-05-25 19:25:05 +08:00
parent 323ed7d7e9
commit 0ad87bde81
35 changed files with 3303 additions and 472 deletions

View File

@@ -39,14 +39,15 @@ module VX_decode #(
// outputs
VX_decode_if.master decode_if,
`ifdef PERF_ENABLE
VX_pipeline_perf_if.decode perf_decode_if,
`endif
VX_decode_sched_if.master decode_sched_if
);
localparam DATAW = `UUID_WIDTH + `NW_WIDTH + `NUM_THREADS + `XLEN + `EX_BITS + `INST_OP_BITS + `INST_MOD_BITS + 1 + (`NR_BITS * 4) + `XLEN + 1 + 1;
`UNUSED_PARAM (CORE_ID)
`UNUSED_VAR (clk)
`UNUSED_VAR (reset)
reg [`EX_BITS-1:0] ex_type;
reg [`INST_OP_BITS-1:0] op_type;
@@ -488,6 +489,17 @@ module VX_decode #(
`USED_IREG (rs1);
`USED_IREG (rs2);
end
3'h6: begin // WSPAWN_MASK
op_type = `INST_OP_BITS'(`INST_SFU_WSPAWN);
op_mod[0] = 1;
`USED_IREG (rs1);
`USED_IREG (rs2);
end
3'h7: begin // BAR_MASK
op_type = `INST_OP_BITS'(`INST_SFU_BAR_MASK);
`USED_IREG (rs1);
`USED_IREG (rs2);
end
default:;
endcase
end
@@ -620,8 +632,93 @@ module VX_decode #(
endcase
end
wire fetch_is_tensor_warp = fetch_if.data.wid >= `NW_WIDTH'(`NUM_SCALAR_WARPS);
wire fetch_is_scalar_warp = fetch_if.data.wid < `NW_WIDTH'(`NUM_SCALAR_WARPS);
wire fetch_fire = fetch_if.valid && fetch_if.ready;
wire decoded_tensor_inst = (ex_type == `EX_BITS'(`EX_TENSOR));
wire t_reg_hi_rd = use_rd && (rd_r[`NRI_BITS-1:3] != '0);
wire t_reg_hi_rs1 = use_rs1 && (rs1_r[`NRI_BITS-1:3] != '0);
wire t_reg_hi_rs2 = use_rs2 && (rs2_r[`NRI_BITS-1:3] != '0);
wire t_reg_hi_rs3 = use_rs3 && (rs3_r[`NRI_BITS-1:3] != '0);
wire tensor_reg_illegal = fetch_is_tensor_warp && (t_reg_hi_rd || t_reg_hi_rs1 || t_reg_hi_rs2 || t_reg_hi_rs3);
wire scalar_tensor_illegal = fetch_is_scalar_warp && decoded_tensor_inst;
wire tensor_fpu_illegal = fetch_is_tensor_warp && (ex_type == `EX_BITS'(`EX_FPU));
wire tensor_read_csr_allowed = (op_type == `INST_OP_BITS'(`INST_SFU_CSRRS))
&& (rs1_r == `NR_BITS'(0))
&& ((u_12 == `VX_CSR_THREAD_ID)
|| (u_12 == `VX_CSR_WARP_ID)
|| (u_12 == `VX_CSR_CORE_ID)
|| (u_12 == `VX_CSR_MHARTID)
|| (u_12 == `VX_CSR_NUM_THREADS)
|| (u_12 == `VX_CSR_NUM_WARPS)
|| (u_12 == `VX_CSR_NUM_CORES));
wire tensor_sfu_barrier_allowed = (op_type == `INST_OP_BITS'(`INST_SFU_BAR))
|| (op_type == `INST_OP_BITS'(`INST_SFU_BAR_MASK));
wire tensor_sfu_allowed = (op_type == `INST_OP_BITS'(`INST_SFU_TMC))
|| tensor_read_csr_allowed
|| tensor_sfu_barrier_allowed;
wire tensor_sfu_illegal = fetch_is_tensor_warp
&& (ex_type == `EX_BITS'(`EX_SFU))
&& !tensor_sfu_allowed;
wire tensor_complex_alu_illegal = fetch_is_tensor_warp
&& (ex_type == `EX_BITS'(`EX_ALU))
&& (`INST_ALU_IS_M(op_mod) || `INST_ALU_IS_RED(op_mod));
wire tensor_scalar_illegal = tensor_fpu_illegal || tensor_sfu_illegal || tensor_complex_alu_illegal;
wire decode_illegal = tensor_reg_illegal || scalar_tensor_illegal || tensor_scalar_illegal;
wire [`EX_BITS-1:0] emit_ex_type = decode_illegal ? `EX_BITS'(`EX_ALU) : ex_type;
wire [`INST_OP_BITS-1:0] emit_op_type = decode_illegal ? `INST_OP_BITS'(`INST_BR_EBREAK) : op_type;
wire [`INST_MOD_BITS-1:0] emit_op_mod = decode_illegal ? `INST_MOD_BITS'(1) : op_mod;
wire emit_use_PC = decode_illegal ? 1'b1 : use_PC;
wire emit_use_imm = decode_illegal ? 1'b1 : use_imm;
wire [`XLEN-1:0] emit_imm = decode_illegal ? `XLEN'(0) : imm;
wire [`NR_BITS-1:0] emit_rd = decode_illegal ? `NR_BITS'(0) : rd_r;
wire [`NR_BITS-1:0] emit_rs1 = decode_illegal ? `NR_BITS'(0) : rs1_r;
wire [`NR_BITS-1:0] emit_rs2 = decode_illegal ? `NR_BITS'(0) : rs2_r;
wire [`NR_BITS-1:0] emit_rs3 = decode_illegal ? `NR_BITS'(0) : rs3_r;
`RUNTIME_ASSERT(
!fetch_if.valid || !tensor_reg_illegal,
("%t: *** core%0d-decode-illegal-tensor-reg: wid=%0d PC=0x%0h instr=0x%0h ex=%0d op=%0d rd=%0d rs1=%0d rs2=%0d rs3=%0d",
$time, CORE_ID, fetch_if.data.wid, fetch_if.data.PC, fetch_if.data.instr, ex_type, op_type, rd, rs1, rs2, rs3)
)
`RUNTIME_ASSERT(
!fetch_if.valid || !scalar_tensor_illegal,
("%t: *** core%0d-decode-illegal-scalar-tensor-op: wid=%0d PC=0x%0h instr=0x%0h ex=%0d op=%0d",
$time, CORE_ID, fetch_if.data.wid, fetch_if.data.PC, fetch_if.data.instr, ex_type, op_type)
)
`RUNTIME_ASSERT(
!fetch_if.valid || !tensor_scalar_illegal,
("%t: *** core%0d-decode-illegal-tensor-scalar-op: wid=%0d PC=0x%0h instr=0x%0h ex=%0d op=%0d mod=%0d",
$time, CORE_ID, fetch_if.data.wid, fetch_if.data.PC, fetch_if.data.instr, ex_type, op_type, op_mod)
)
`ifdef PERF_ENABLE
reg [`PERF_CTR_BITS-1:0] perf_illegal_tensor_reg_access;
reg [`PERF_CTR_BITS-1:0] perf_illegal_tensor_scalar_op;
reg [`PERF_CTR_BITS-1:0] perf_illegal_scalar_tensor_op;
always @(posedge clk) begin
if (reset) begin
perf_illegal_tensor_reg_access <= '0;
perf_illegal_tensor_scalar_op <= '0;
perf_illegal_scalar_tensor_op <= '0;
end else if (fetch_fire) begin
perf_illegal_tensor_reg_access <= perf_illegal_tensor_reg_access + `PERF_CTR_BITS'(tensor_reg_illegal);
perf_illegal_tensor_scalar_op <= perf_illegal_tensor_scalar_op + `PERF_CTR_BITS'(tensor_scalar_illegal);
perf_illegal_scalar_tensor_op <= perf_illegal_scalar_tensor_op + `PERF_CTR_BITS'(scalar_tensor_illegal);
end
end
assign perf_decode_if.illegal_tensor_reg_access = perf_illegal_tensor_reg_access;
assign perf_decode_if.illegal_tensor_scalar_op = perf_illegal_tensor_scalar_op;
assign perf_decode_if.illegal_scalar_tensor_op = perf_illegal_scalar_tensor_op;
`endif
// disable write to integer register r0
wire wb = use_rd && (rd_r != 0);
wire wb = !decode_illegal && use_rd && (rd_r != 0);
VX_elastic_buffer #(
.DATAW (DATAW),
@@ -631,7 +728,7 @@ module VX_decode #(
.reset (reset),
.valid_in (fetch_if.valid),
.ready_in (fetch_if.ready),
.data_in ({fetch_if.data.uuid, fetch_if.data.wid, fetch_if.data.tmask, fetch_if.data.PC, ex_type, op_type, op_mod, use_PC, imm, use_imm, wb, rd_r, rs1_r, rs2_r, rs3_r}),
.data_in ({fetch_if.data.uuid, fetch_if.data.wid, fetch_if.data.tmask, fetch_if.data.PC, emit_ex_type, emit_op_type, emit_op_mod, emit_use_PC, emit_imm, emit_use_imm, wb, emit_rd, emit_rs1, emit_rs2, emit_rs3}),
.data_out ({decode_if.data.uuid, decode_if.data.wid, decode_if.data.tmask, decode_if.data.PC, decode_if.data.ex_type, decode_if.data.op_type, decode_if.data.op_mod, decode_if.data.use_PC, decode_if.data.imm, decode_if.data.use_imm, decode_if.data.wb, decode_if.data.rd, decode_if.data.rs1, decode_if.data.rs2, decode_if.data.rs3}),
.valid_out (decode_if.valid),
.ready_out (decode_if.ready)
@@ -639,11 +736,9 @@ module VX_decode #(
///////////////////////////////////////////////////////////////////////////
wire fetch_fire = fetch_if.valid && fetch_if.ready;
assign decode_sched_if.valid = fetch_fire;
assign decode_sched_if.wid = fetch_if.data.wid;
assign decode_sched_if.is_wstall = is_wstall;
assign decode_sched_if.is_wstall = is_wstall || decode_illegal;
`ifndef L1_ENABLE
assign fetch_if.ibuf_pop = decode_if.ibuf_pop;
`endif