From 9528335a046c2024590e9e454aace01fc8fca8c8 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 19 Jul 2025 17:50:14 +0800 Subject: [PATCH] =?UTF-8?q?[backend-llir]=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E8=AE=B8=E5=A4=9A=E9=87=8D=E6=9E=84=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/RISCv64AsmPrinter.cpp | 281 ++++++++++++++--------------- src/RISCv64Backend.cpp | 33 +--- src/RISCv64ISel.cpp | 306 ++++++++++++++++++-------------- src/RISCv64Passes.cpp | 8 + src/RISCv64RegAlloc.cpp | 179 ++++++++++++------- src/include/RISCv64AsmPrinter.h | 26 ++- src/include/RISCv64Backend.h | 8 +- src/include/RISCv64ISel.h | 34 ++-- src/include/RISCv64LLIR.h | 104 ++++------- src/include/RISCv64Passes.h | 18 ++ src/include/RISCv64RegAlloc.h | 13 +- 11 files changed, 513 insertions(+), 497 deletions(-) create mode 100644 src/RISCv64Passes.cpp create mode 100644 src/include/RISCv64Passes.h diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp index 0688095..0ad1c81 100644 --- a/src/RISCv64AsmPrinter.cpp +++ b/src/RISCv64AsmPrinter.cpp @@ -1,44 +1,71 @@ #include "RISCv64AsmPrinter.h" +#include "RISCv64ISel.h" #include namespace sysy { -void RISCv64AsmPrinter::runOnMachineFunction(MachineFunction* mfunc, std::ostream& os) { +// 检查是否为内存加载/存储指令,以处理特殊的打印格式 +bool isMemoryOp(RVOpcodes opcode) { + switch (opcode) { + case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD: + case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU: + case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD: + return true; + default: + return false; + } +} + +RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {} + +void RISCv64AsmPrinter::run(std::ostream& os) { OS = &os; - // 打印函数声明和全局符号 - *OS << ".text\n"; - *OS << ".globl " << mfunc->getName() << "\n"; - *OS << mfunc->getName() << ":\n"; + *OS << ".globl " << MFunc->getName() << "\n"; + *OS << MFunc->getName() << ":\n"; - // 打印函数序言 - printPrologue(mfunc); + printPrologue(); - // 遍历并打印所有基本块 - for (auto& mbb : mfunc->getBlocks()) { + for (auto& mbb : MFunc->getBlocks()) { printBasicBlock(mbb.get()); } } -void RISCv64AsmPrinter::printPrologue(MachineFunction* mfunc) { - int stack_size = mfunc->getFrameInfo().frame_size; - - // 确保栈大小是16字节对齐 - int aligned_stack_size = (stack_size + 15) & ~15; +void RISCv64AsmPrinter::printPrologue() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + // 序言需要为保存ra和s0预留16字节 + int total_stack_size = frame_info.locals_size + frame_info.spill_size + 16; + int aligned_stack_size = (total_stack_size + 15) & ~15; + frame_info.total_size = aligned_stack_size; if (aligned_stack_size > 0) { *OS << " addi sp, sp, -" << aligned_stack_size << "\n"; - // RV64中ra和s0都是8字节 *OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; *OS << " mv s0, sp\n"; } + + // 忠实还原保存函数入口参数的逻辑 + Function* F = MFunc->getFunc(); + if (F && F->getEntryBlock()) { + int arg_idx = 0; + RISCv64ISel* isel = MFunc->getISel(); + for (AllocaInst* alloca_for_param : F->getEntryBlock()->getArguments()) { + if (arg_idx >= 8) break; + + unsigned vreg = isel->getVReg(alloca_for_param); + if (frame_info.alloca_offsets.count(vreg)) { + int offset = frame_info.alloca_offsets.at(vreg); + auto arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); + *OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n"; + } + arg_idx++; + } + } } -void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { - int stack_size = mfunc->getFrameInfo().frame_size; - int aligned_stack_size = (stack_size + 15) & ~15; - +void RISCv64AsmPrinter::printEpilogue() { + int aligned_stack_size = MFunc->getFrameInfo().total_size; if (aligned_stack_size > 0) { *OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n"; @@ -46,129 +73,85 @@ void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { } } - void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) { - // 打印基本块标签 if (!mbb->getName().empty()) { *OS << mbb->getName() << ":\n"; } - - // 打印指令 for (auto& instr : mbb->getInstructions()) { - printInstruction(instr.get(), mbb); + printInstruction(instr.get()); } } -void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb) { - *OS << " "; // 指令缩进 - +void RISCv64AsmPrinter::printInstruction(MachineInstr* instr) { auto opcode = instr->getOpcode(); - - // RET指令需要特殊处理,在打印ret之前先打印函数尾声 if (opcode == RVOpcodes::RET) { - printEpilogue(parent_bb->getParent()); + printEpilogue(); } - - // 使用switch将Opcode转换为汇编助记符 + if (opcode != RVOpcodes::LABEL) { + *OS << " "; + } + switch (opcode) { - // Arithmatic - case RVOpcodes::ADD: *OS << "add "; break; - case RVOpcodes::ADDI: *OS << "addi "; break; - case RVOpcodes::ADDW: *OS << "addw "; break; - case RVOpcodes::ADDIW: *OS << "addiw "; break; - case RVOpcodes::SUB: *OS << "sub "; break; - case RVOpcodes::SUBW: *OS << "subw "; break; - case RVOpcodes::MUL: *OS << "mul "; break; - case RVOpcodes::MULW: *OS << "mulw "; break; - case RVOpcodes::DIV: *OS << "div "; break; - case RVOpcodes::DIVW: *OS << "divw "; break; - case RVOpcodes::REM: *OS << "rem "; break; - case RVOpcodes::REMW: *OS << "remw "; break; - // Logical - case RVOpcodes::XOR: *OS << "xor "; break; - case RVOpcodes::XORI: *OS << "xori "; break; - case RVOpcodes::OR: *OS << "or "; break; - case RVOpcodes::ORI: *OS << "ori "; break; - case RVOpcodes::AND: *OS << "and "; break; - case RVOpcodes::ANDI: *OS << "andi "; break; - // Shift - case RVOpcodes::SLL: *OS << "sll "; break; - case RVOpcodes::SLLI: *OS << "slli "; break; - case RVOpcodes::SLLW: *OS << "sllw "; break; - case RVOpcodes::SLLIW: *OS << "slliw "; break; - case RVOpcodes::SRL: *OS << "srl "; break; - case RVOpcodes::SRLI: *OS << "srli "; break; - case RVOpcodes::SRLW: *OS << "srlw "; break; - case RVOpcodes::SRLIW: *OS << "srliw "; break; - case RVOpcodes::SRA: *OS << "sra "; break; - case RVOpcodes::SRAI: *OS << "srai "; break; - case RVOpcodes::SRAW: *OS << "sraw "; break; - case RVOpcodes::SRAIW: *OS << "sraiw "; break; - // Compare - case RVOpcodes::SLT: *OS << "slt "; break; - case RVOpcodes::SLTI: *OS << "slti "; break; - case RVOpcodes::SLTU: *OS << "sltu "; break; - case RVOpcodes::SLTIU: *OS << "sltiu "; break; - // Memory - case RVOpcodes::LW: *OS << "lw "; break; - case RVOpcodes::LH: *OS << "lh "; break; - case RVOpcodes::LB: *OS << "lb "; break; - case RVOpcodes::LWU: *OS << "lwu "; break; - case RVOpcodes::LHU: *OS << "lhu "; break; - case RVOpcodes::LBU: *OS << "lbu "; break; - case RVOpcodes::SW: *OS << "sw "; break; - case RVOpcodes::SH: *OS << "sh "; break; - case RVOpcodes::SB: *OS << "sb "; break; - case RVOpcodes::LD: *OS << "ld "; break; + case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break; + case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break; + case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break; + case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break; + case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break; + case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::ORI: *OS << "ori "; break; + case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::ANDI: *OS << "andi "; break; + case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SLLI: *OS << "slli "; break; + case RVOpcodes::SLLW: *OS << "sllw "; break; case RVOpcodes::SLLIW: *OS << "slliw "; break; + case RVOpcodes::SRL: *OS << "srl "; break; case RVOpcodes::SRLI: *OS << "srli "; break; + case RVOpcodes::SRLW: *OS << "srlw "; break; case RVOpcodes::SRLIW: *OS << "srliw "; break; + case RVOpcodes::SRA: *OS << "sra "; break; case RVOpcodes::SRAI: *OS << "srai "; break; + case RVOpcodes::SRAW: *OS << "sraw "; break; case RVOpcodes::SRAIW: *OS << "sraiw "; break; + case RVOpcodes::SLT: *OS << "slt "; break; case RVOpcodes::SLTI: *OS << "slti "; break; + case RVOpcodes::SLTU: *OS << "sltu "; break; case RVOpcodes::SLTIU: *OS << "sltiu "; break; + case RVOpcodes::LW: *OS << "lw "; break; case RVOpcodes::LH: *OS << "lh "; break; + case RVOpcodes::LB: *OS << "lb "; break; case RVOpcodes::LWU: *OS << "lwu "; break; + case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break; + case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break; + case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break; case RVOpcodes::SD: *OS << "sd "; break; - // Control Flow - case RVOpcodes::J: *OS << "j "; break; - case RVOpcodes::JAL: *OS << "jal "; break; - case RVOpcodes::JALR: *OS << "jalr "; break; - case RVOpcodes::RET: *OS << "ret"; break; - case RVOpcodes::BEQ: *OS << "beq "; break; - case RVOpcodes::BNE: *OS << "bne "; break; - case RVOpcodes::BLT: *OS << "blt "; break; - case RVOpcodes::BGE: *OS << "bge "; break; - case RVOpcodes::BLTU: *OS << "bltu "; break; - case RVOpcodes::BGEU: *OS << "bgeu "; break; - // Pseudo-Instructions - case RVOpcodes::LI: *OS << "li "; break; - case RVOpcodes::LA: *OS << "la "; break; - case RVOpcodes::MV: *OS << "mv "; break; - case RVOpcodes::NEG: *OS << "neg "; break; - case RVOpcodes::NEGW: *OS << "negw "; break; - case RVOpcodes::SEQZ: *OS << "seqz "; break; + case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break; + case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break; + case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break; + case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::BGE: *OS << "bge "; break; + case RVOpcodes::BLTU: *OS << "bltu "; break; case RVOpcodes::BGEU: *OS << "bgeu "; break; + case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break; + case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break; + case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break; case RVOpcodes::SNEZ: *OS << "snez "; break; - // Call case RVOpcodes::CALL: *OS << "call "; break; - // Special case RVOpcodes::LABEL: - *OS << "\b\b\b\b"; printOperand(instr->getOperands()[0].get()); *OS << ":"; break; + case RVOpcodes::FRAME_LOAD: + case RVOpcodes::FRAME_STORE: + // These should have been eliminated by RegAlloc + throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); default: throw std::runtime_error("Unknown opcode in AsmPrinter"); } - // 打印操作数 const auto& operands = instr->getOperands(); - for (size_t i = 0; i < operands.size(); ++i) { - // 对于LW/SW, 操作数格式是 rd, offset(rs1) - if (opcode == RVOpcodes::LW || opcode == RVOpcodes::SW || opcode == RVOpcodes::LD || opcode == RVOpcodes::SD) { + if (!operands.empty()) { + if (isMemoryOp(opcode)) { printOperand(operands[0].get()); *OS << ", "; printOperand(operands[1].get()); - break; // LW/SW只有两个操作数部分 - } - - printOperand(operands[i].get()); - if (i < operands.size() - 1) { - *OS << ", "; + } else { + for (size_t i = 0; i < operands.size(); ++i) { + printOperand(operands[i].get()); + if (i < operands.size() - 1) { + *OS << ", "; + } + } } } - *OS << "\n"; } @@ -178,21 +161,18 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) { case MachineOperand::KIND_REG: { auto reg_op = static_cast(op); if (reg_op->isVirtual()) { - // 在这个阶段不应该再有虚拟寄存器了 *OS << "%vreg" << reg_op->getVRegNum(); } else { *OS << regToString(reg_op->getPReg()); } break; } - case MachineOperand::KIND_IMM: { + case MachineOperand::KIND_IMM: *OS << static_cast(op)->getValue(); break; - } - case MachineOperand::KIND_LABEL: { + case MachineOperand::KIND_LABEL: *OS << static_cast(op)->getName(); break; - } case MachineOperand::KIND_MEM: { auto mem_op = static_cast(op); printOperand(mem_op->getOffset()); @@ -204,41 +184,40 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) { } } -// 物理寄存器到字符串的转换 (从原RISCv64Backend.cpp迁移) std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) { switch (reg) { - case PhysicalReg::ZERO: return "x0"; - case PhysicalReg::RA: return "ra"; - case PhysicalReg::SP: return "sp"; - case PhysicalReg::GP: return "gp"; - case PhysicalReg::TP: return "tp"; - case PhysicalReg::T0: return "t0"; - case PhysicalReg::T1: return "t1"; - case PhysicalReg::T2: return "t2"; - case PhysicalReg::S0: return "s0"; - case PhysicalReg::S1: return "s1"; - case PhysicalReg::A0: return "a0"; - case PhysicalReg::A1: return "a1"; - case PhysicalReg::A2: return "a2"; - case PhysicalReg::A3: return "a3"; - case PhysicalReg::A4: return "a4"; - case PhysicalReg::A5: return "a5"; - case PhysicalReg::A6: return "a6"; - case PhysicalReg::A7: return "a7"; - case PhysicalReg::S2: return "s2"; - case PhysicalReg::S3: return "s3"; - case PhysicalReg::S4: return "s4"; - case PhysicalReg::S5: return "s5"; - case PhysicalReg::S6: return "s6"; - case PhysicalReg::S7: return "s7"; - case PhysicalReg::S8: return "s8"; - case PhysicalReg::S9: return "s9"; - case PhysicalReg::S10: return "s10"; - case PhysicalReg::S11: return "s11"; - case PhysicalReg::T3: return "t3"; - case PhysicalReg::T4: return "t4"; - case PhysicalReg::T5: return "t5"; - case PhysicalReg::T6: return "t6"; + case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra"; + case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp"; + case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0"; + case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2"; + case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1"; + case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1"; + case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3"; + case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5"; + case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7"; + case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3"; + case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5"; + case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7"; + case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9"; + case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11"; + case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4"; + case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6"; + case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1"; + case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3"; + case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5"; + case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7"; + case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9"; + case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11"; + case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13"; + case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15"; + case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17"; + case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19"; + case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21"; + case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23"; + case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25"; + case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27"; + case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29"; + case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31"; default: return "UNKNOWN_REG"; } } diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index c9fdead..b2a7da0 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -3,7 +3,6 @@ #include "RISCv64RegAlloc.h" #include "RISCv64AsmPrinter.h" #include -#include namespace sysy { @@ -12,13 +11,12 @@ std::string RISCv64CodeGen::code_gen() { return module_gen(); } -// module_gen 的逻辑基本不变,它负责处理.data段和驱动每个函数的生成 +// 模块级代码生成 (移植自原文件,处理.data段和驱动函数生成) std::string RISCv64CodeGen::module_gen() { std::stringstream ss; // 1. 处理全局变量 (.data段) - bool has_globals = !module->getGlobals().empty(); - if (has_globals) { + if (!module->getGlobals().empty()) { ss << ".data\n"; for (const auto& global : module->getGlobals()) { ss << ".globl " << global->getName() << "\n"; @@ -45,9 +43,9 @@ std::string RISCv64CodeGen::module_gen() { // 2. 处理函数 (.text段) if (!module->getFunctions().empty()) { ss << ".text\n"; - for (const auto& func : module->getFunctions()) { - if (func.second.get()) { - ss << function_gen(func.second.get()); + for (const auto& func_pair : module->getFunctions()) { + if (func_pair.second.get()) { + ss << function_gen(func_pair.second.get()); } } } @@ -56,31 +54,18 @@ std::string RISCv64CodeGen::module_gen() { // function_gen 现在是新的、模块化的处理流水线 std::string RISCv64CodeGen::function_gen(Function* func) { - - // === 新的、完整的流水线 === - // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) RISCv64ISel isel; std::unique_ptr mfunc = isel.runOnFunction(func); - // 阶段 2: 寄存器分配前优化 (未来扩展点) - // 例如: - // auto pre_ra_scheduler = std::make_unique(); - // pre_ra_scheduler->runOnMachineFunction(mfunc.get()); - - // 阶段 3: 物理寄存器分配 (virtual regs -> physical regs + spill code) + // 阶段 2: 寄存器分配 (包含栈帧布局, 活跃性分析, 图着色, spill/rewrite) RISCv64RegAlloc reg_alloc(mfunc.get()); reg_alloc.run(); - // 阶段 4: 寄存器分配后优化 (未来扩展点) - // 例如: - // auto post_ra_peephole = std::make_unique(); - // post_ra_peephole->runOnMachineFunction(mfunc.get()); - - // 阶段 5: 代码发射 (LLIR with physical regs -> Assembly Text) + // 阶段 3: 代码发射 (LLIR with physical regs -> Assembly Text) std::stringstream ss; - RISCv64AsmPrinter printer; - printer.runOnMachineFunction(mfunc.get(), ss); + RISCv64AsmPrinter printer(mfunc.get()); + printer.run(ss); return ss.str(); } diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp index bf56027..005ba96 100644 --- a/src/RISCv64ISel.cpp +++ b/src/RISCv64ISel.cpp @@ -1,22 +1,32 @@ #include "RISCv64ISel.h" #include -#include -#include #include +#include +#include // For std::fabs +#include // For std::numeric_limits namespace sysy { +// DAG节点定义 (内部实现) +struct RISCv64ISel::DAGNode { + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; + NodeKind kind; + Value* value = nullptr; + std::vector operands; + std::vector users; + DAGNode(NodeKind k) : kind(k) {} +}; + RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {} // 为一个IR Value获取或分配一个新的虚拟寄存器 unsigned RISCv64ISel::getVReg(Value* val) { - if (!val) { // 安全检查 + if (!val) { throw std::runtime_error("Cannot get vreg for a null Value."); } if (vreg_map.find(val) == vreg_map.end()) { if (vreg_counter == 0) { - // vreg 0 通常保留给物理寄存器x0(zero),我们从1开始分配 - vreg_counter = 1; + vreg_counter = 1; // vreg 0 保留 } vreg_map[val] = vreg_counter++; } @@ -27,7 +37,7 @@ unsigned RISCv64ISel::getVReg(Value* val) { std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { F = func; if (!F) return nullptr; - MFunc = std::make_unique(F->getName()); + MFunc = std::make_unique(F, this); vreg_map.clear(); bb_map.clear(); vreg_counter = 0; @@ -40,37 +50,28 @@ std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { // 指令选择主流程 void RISCv64ISel::select() { - // 1. 为所有基本块创建对应的MachineBasicBlock for (const auto& bb_ptr : F->getBasicBlocks()) { - BasicBlock* bb = bb_ptr.get(); - auto mbb = std::make_unique(bb->getName(), MFunc.get()); - bb_map[bb] = mbb.get(); + auto mbb = std::make_unique(bb_ptr->getName(), MFunc.get()); + bb_map[bb_ptr.get()] = mbb.get(); MFunc->addBlock(std::move(mbb)); } - // 2. 为函数参数创建虚拟寄存器 - // ====================== 已修正 ====================== - // 根据 IR.h, 参数列表存储在入口基本块中 if (F->getEntryBlock()) { for (auto* arg_alloca : F->getEntryBlock()->getArguments()) { getVReg(arg_alloca); } } - // ===================================================== - // 3. 遍历每个基本块,生成指令 for (const auto& bb_ptr : F->getBasicBlocks()) { selectBasicBlock(bb_ptr.get()); } - // 4. 设置基本块的前驱后继关系 for (const auto& bb_ptr : F->getBasicBlocks()) { - BasicBlock* bb = bb_ptr.get(); - CurMBB = bb_map.at(bb); - for (auto succ : bb->getSuccessors()) { + CurMBB = bb_map.at(bb_ptr.get()); + for (auto succ : bb_ptr->getSuccessors()) { CurMBB->successors.push_back(bb_map.at(succ)); } - for (auto pred : bb->getPredecessors()) { + for (auto pred : bb_ptr->getPredecessors()) { CurMBB->predecessors.push_back(bb_map.at(pred)); } } @@ -87,29 +88,23 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { value_to_node[node->value] = node.get(); } } - + std::set selected_nodes; std::function select_recursive = [&](DAGNode* node) { if (!node || selected_nodes.count(node)) return; - for (auto operand : node->operands) { select_recursive(operand); } - - // 只有当所有操作数都选择完毕后,才选择当前节点 selectNode(node); selected_nodes.insert(node); }; - // 按照IR指令的原始顺序来驱动指令选择 for (const auto& inst_ptr : bb->getInstructions()) { DAGNode* node_to_select = nullptr; - // 查找当前IR指令对应的DAG节点 if (value_to_node.count(inst_ptr.get())) { node_to_select = value_to_node.at(inst_ptr.get()); } else { - // 对于没有返回值的指令或某些特殊情况 for(const auto& node : dag) { if(node->value == inst_ptr.get()) { node_to_select = node.get(); @@ -123,88 +118,105 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { } } +// 核心函数:为DAG节点选择并生成MachineInstr (忠实移植版) void RISCv64ISel::selectNode(DAGNode* node) { - // 注意:不再生成字符串,而是创建MachineInstr对象并加入到CurMBB switch (node->kind) { case DAGNode::CONSTANT: case DAGNode::ALLOCA_ADDR: - // 这些节点本身不生成指令。使用它们的指令会按需处理。 - // 为Alloca地址分配一个vreg是必要的,代表地址。 if (node->value) getVReg(node->value); break; case DAGNode::LOAD: { - // lw rd, offset(base) auto dest_vreg = getVReg(node->value); - auto ptr_vreg = getVReg(node->operands[0]->value); + Value* ptr_val = node->operands[0]->value; - auto instr = std::make_unique(RVOpcodes::LW); - instr->addOperand(std::make_unique(dest_vreg)); - // 暂时生成0(ptr),后续pass会将其优化为 offset(s0) - instr->addOperand(std::make_unique( - std::make_unique(ptr_vreg), - std::make_unique(0) - )); - CurMBB->addInstruction(std::move(instr)); + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_LOAD); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } break; } case DAGNode::STORE: { - // sw rs2, offset(rs1) - // 先加载常量 - if (auto val_const = dynamic_cast(node->operands[0]->value)) { + Value* val_to_store = node->operands[0]->value; + Value* ptr_val = node->operands[1]->value; + + if (auto val_const = dynamic_cast(val_to_store)) { auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(getVReg(val_const))); li->addOperand(std::make_unique(val_const->getInt())); CurMBB->addInstruction(std::move(li)); } + auto val_vreg = getVReg(val_to_store); - auto val_vreg = getVReg(node->operands[0]->value); - auto ptr_vreg = getVReg(node->operands[1]->value); + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_STORE); + instr->addOperand(std::make_unique(val_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); - auto instr = std::make_unique(RVOpcodes::SW); - instr->addOperand(std::make_unique(val_vreg)); // value to store - instr->addOperand(std::make_unique( - std::make_unique(ptr_vreg), // base address - std::make_unique(0) // offset - )); - CurMBB->addInstruction(std::move(instr)); + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } break; } case DAGNode::BINARY: { auto bin = dynamic_cast(node->value); - if (!bin) break; - Value* lhs = bin->getLhs(); Value* rhs = bin->getRhs(); - // 检查是否为 addi 优化 - if (bin->getKind() == BinaryInst::kAdd) { - if (auto rhs_const = dynamic_cast(rhs)) { - if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { - auto instr = std::make_unique(RVOpcodes::ADDIW); - instr->addOperand(std::make_unique(getVReg(bin))); - instr->addOperand(std::make_unique(getVReg(lhs))); - instr->addOperand(std::make_unique(rhs_const->getInt())); - CurMBB->addInstruction(std::move(instr)); - return; // 指令已生成,提前返回 - } - } - } - - // 为操作数加载立即数或地址 auto load_val_if_const = [&](Value* val) { if (auto c = dynamic_cast(val)) { auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(getVReg(c))); li->addOperand(std::make_unique(c->getInt())); CurMBB->addInstruction(std::move(li)); - } else if (auto g = dynamic_cast(val)) { - auto la = std::make_unique(RVOpcodes::LA); - la->addOperand(std::make_unique(getVReg(g))); - la->addOperand(std::make_unique(g->getName())); - CurMBB->addInstruction(std::move(la)); } }; load_val_if_const(lhs); @@ -214,7 +226,19 @@ void RISCv64ISel::selectNode(DAGNode* node) { auto lhs_vreg = getVReg(lhs); auto rhs_vreg = getVReg(rhs); - // 生成二元运算指令 + if (bin->getKind() == BinaryInst::kAdd) { + if (auto rhs_const = dynamic_cast(rhs)) { + if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { + auto instr = std::make_unique(RVOpcodes::ADDIW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_const->getInt())); + CurMBB->addInstruction(std::move(instr)); + return; + } + } + } + switch (bin->getKind()) { case BinaryInst::kAdd: { RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW; @@ -294,16 +318,16 @@ void RISCv64ISel::selectNode(DAGNode* node) { case BinaryInst::kICmpGT: { auto instr = std::make_unique(RVOpcodes::SLT); instr->addOperand(std::make_unique(dest_vreg)); - instr->addOperand(std::make_unique(rhs_vreg)); // Swapped - instr->addOperand(std::make_unique(lhs_vreg)); // Swapped + instr->addOperand(std::make_unique(rhs_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); CurMBB->addInstruction(std::move(instr)); break; } case BinaryInst::kICmpLE: { auto slt = std::make_unique(RVOpcodes::SLT); slt->addOperand(std::make_unique(dest_vreg)); - slt->addOperand(std::make_unique(rhs_vreg)); // Swapped - slt->addOperand(std::make_unique(lhs_vreg)); // Swapped + slt->addOperand(std::make_unique(rhs_vreg)); + slt->addOperand(std::make_unique(lhs_vreg)); CurMBB->addInstruction(std::move(slt)); auto xori = std::make_unique(RVOpcodes::XORI); @@ -335,8 +359,6 @@ void RISCv64ISel::selectNode(DAGNode* node) { case DAGNode::UNARY: { auto unary = dynamic_cast(node->value); - if (!unary) break; - auto dest_vreg = getVReg(unary); auto src_vreg = getVReg(unary->getOperand()); @@ -344,7 +366,7 @@ void RISCv64ISel::selectNode(DAGNode* node) { case UnaryInst::kNeg: { auto instr = std::make_unique(RVOpcodes::SUBW); instr->addOperand(std::make_unique(dest_vreg)); - instr->addOperand(std::make_unique(PhysicalReg::ZERO)); // x0 + instr->addOperand(std::make_unique(PhysicalReg::ZERO)); instr->addOperand(std::make_unique(src_vreg)); CurMBB->addInstruction(std::move(instr)); break; @@ -364,51 +386,67 @@ void RISCv64ISel::selectNode(DAGNode* node) { case DAGNode::CALL: { auto call = dynamic_cast(node->value); - if (!call) break; + for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { + DAGNode* arg_node = node->operands[i]; + auto arg_preg = static_cast(static_cast(PhysicalReg::A0) + i); + + if (arg_node->kind == DAGNode::CONSTANT) { + if (auto const_val = dynamic_cast(arg_node->value)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(arg_preg)); + li->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li)); + } + } else { + auto src_vreg = getVReg(arg_node->value); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(arg_preg)); + mv->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(mv)); + } + } - // 在此阶段,我们只处理函数调用本身和返回值的移动 - // 参数的传递将在一个专门的 Calling Convention Pass 中处理 - auto call_instr = std::make_unique(RVOpcodes::CALL); call_instr->addOperand(std::make_unique(call->getCallee()->getName())); CurMBB->addInstruction(std::move(call_instr)); if (!call->getType()->isVoid()) { auto mv_instr = std::make_unique(RVOpcodes::MV); - mv_instr->addOperand(std::make_unique(getVReg(call))); // dest - mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); // src + mv_instr->addOperand(std::make_unique(getVReg(call))); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); CurMBB->addInstruction(std::move(mv_instr)); } break; } case DAGNode::RETURN: { - auto ret_inst = dynamic_cast(node->value); - if (ret_inst && ret_inst->hasReturnValue()) { - // 如果有返回值,生成一条mv指令将其放入a0 - auto mv_instr = std::make_unique(RVOpcodes::MV); - mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); - mv_instr->addOperand(std::make_unique(getVReg(ret_inst->getReturnValue()))); - CurMBB->addInstruction(std::move(mv_instr)); + auto ret_inst_ir = dynamic_cast(node->value); + if (ret_inst_ir && ret_inst_ir->hasReturnValue()) { + Value* ret_val = ret_inst_ir->getReturnValue(); + if (auto const_val = dynamic_cast(ret_val)) { + auto li_instr = std::make_unique(RVOpcodes::LI); + li_instr->addOperand(std::make_unique(PhysicalReg::A0)); + li_instr->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li_instr)); + } else { + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + mv_instr->addOperand(std::make_unique(getVReg(ret_val))); + CurMBB->addInstruction(std::move(mv_instr)); + } } - // 生成ret伪指令 - auto instr = std::make_unique(RVOpcodes::RET); - CurMBB->addInstruction(std::move(instr)); + auto ret_mi = std::make_unique(RVOpcodes::RET); + CurMBB->addInstruction(std::move(ret_mi)); break; } case DAGNode::BRANCH: { if (auto cond_br = dynamic_cast(node->value)) { - // bne cond, x0, then_block auto br_instr = std::make_unique(RVOpcodes::BNE); br_instr->addOperand(std::make_unique(getVReg(cond_br->getCondition()))); br_instr->addOperand(std::make_unique(PhysicalReg::ZERO)); br_instr->addOperand(std::make_unique(cond_br->getThenBlock()->getName())); CurMBB->addInstruction(std::move(br_instr)); - - // j else_block - // 注意:这里会产生一个fallthrough问题,后续的分支优化pass会解决它 - // 一个更健壮的生成方式是 bne -> j else; then: ...; else: ... } else if (auto uncond_br = dynamic_cast(node->value)) { auto j_instr = std::make_unique(RVOpcodes::J); j_instr->addOperand(std::make_unique(uncond_br->getBlock()->getName())); @@ -416,20 +454,16 @@ void RISCv64ISel::selectNode(DAGNode* node) { } break; } - case DAGNode::MEMSET: { - // 这是对原memset逻辑的完整LLIR翻译 - auto memset = dynamic_cast(node->value); - if (!memset) break; + case DAGNode::MEMSET: { + auto memset = dynamic_cast(node->value); auto r_dest_addr = getVReg(memset->getPointer()); auto r_num_bytes = getVReg(memset->getSize()); auto r_value_byte = getVReg(memset->getValue()); - - // 为临时值创建虚拟寄存器 - auto r_counter = vreg_counter++; - auto r_end_addr = vreg_counter++; - auto r_current_addr = vreg_counter++; - auto r_temp_val = vreg_counter++; + auto r_counter = getNewVReg(); + auto r_end_addr = getNewVReg(); + auto r_current_addr = getNewVReg(); + auto r_temp_val = getNewVReg(); auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { auto i = std::make_unique(op); @@ -470,12 +504,11 @@ void RISCv64ISel::selectNode(DAGNode* node) { }; int unique_id = this->local_label_counter++; - std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); - std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); - std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); - std::string done_label = "memset_done_" + std::to_string(unique_id); - - // 构造64位的填充值 + std::string loop_start_label = MFunc->getName() + "_memset_loop_start_" + std::to_string(unique_id); + std::string loop_end_label = MFunc->getName() + "_memset_loop_end_" + std::to_string(unique_id); + std::string remainder_label = MFunc->getName() + "_memset_remainder_" + std::to_string(unique_id); + std::string done_label = MFunc->getName() + "_memset_done_" + std::to_string(unique_id); + addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); @@ -483,8 +516,6 @@ void RISCv64ISel::selectNode(DAGNode* node) { add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); - - // 设置循环变量 add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes); auto mv = std::make_unique(RVOpcodes::MV); mv->addOperand(std::make_unique(r_current_addr)); @@ -492,16 +523,12 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(mv)); addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8); add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter); - - // 64位写入循环 label_instr(loop_start_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label); store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0); addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8); jump_instr(loop_start_label); label_instr(loop_end_label); - - // 剩余字节写入循环 label_instr(remainder_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label); store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0); @@ -512,13 +539,13 @@ void RISCv64ISel::selectNode(DAGNode* node) { } default: - throw std::runtime_error("Unsupported DAGNode kind in ISel: " + std::to_string(node->kind)); + throw std::runtime_error("Unsupported DAGNode kind in ISel"); } } - -// --- DAG构建函数 (从原RISCv64Backend.cpp几乎原样迁移, 保持不变) --- -RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { +// 以下是忠实移植的DAG构建函数 +RISCv64ISel::DAGNode* RISCv64ISel::create_node(int kind_int, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { + auto kind = static_cast(kind_int); if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) { return value_to_node[val]; } @@ -526,10 +553,7 @@ RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* va node->value = val; DAGNode* raw_node_ptr = node.get(); nodes_storage.push_back(std::move(node)); - // 只有产生值的节点才应该被记录,以备复用 - if (val && !val->getType()->isVoid() && dynamic_cast(val)) { - value_to_node[val] = raw_node_ptr; - } else if (val && dynamic_cast(val)) { + if (val && !val->getType()->isVoid() && (dynamic_cast(val) || dynamic_cast(val))) { value_to_node[val] = raw_node_ptr; } return raw_node_ptr; @@ -545,7 +569,6 @@ RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map(val_ir)) { return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage); } - // Fallback: Assume it needs to be loaded if not found (might be a parameter or a value from another block) return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); } @@ -567,12 +590,20 @@ std::vector> RISCv64ISel::build_dag(BasicB memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage)); - } - else if (auto load = dynamic_cast(inst)) { + } else if (auto load = dynamic_cast(inst)) { auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); } else if (auto bin = dynamic_cast(inst)) { if(value_to_node.count(bin)) continue; + if (bin->getKind() == BinaryInst::kSub) { + if (auto const_lhs = dynamic_cast(bin->getLhs())) { + if (const_lhs->getInt() == 0) { + auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage); + unary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); + continue; + } + } + } auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); @@ -580,8 +611,7 @@ std::vector> RISCv64ISel::build_dag(BasicB if(value_to_node.count(un)) continue; auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage); unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage)); - } - else if (auto call = dynamic_cast(inst)) { + } else if (auto call = dynamic_cast(inst)) { if(value_to_node.count(call)) continue; auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); for (auto arg : call->getArguments()) { diff --git a/src/RISCv64Passes.cpp b/src/RISCv64Passes.cpp new file mode 100644 index 0000000..40aff21 --- /dev/null +++ b/src/RISCv64Passes.cpp @@ -0,0 +1,8 @@ +// RISCv64Passes.cpp +#include "RISCv64Passes.h" + +namespace sysy { + +// 此处为未来优化Pass的实现 + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp index 4471868..2695f3b 100644 --- a/src/RISCv64RegAlloc.cpp +++ b/src/RISCv64RegAlloc.cpp @@ -1,11 +1,11 @@ #include "RISCv64RegAlloc.h" +#include "RISCv64ISel.h" #include #include namespace sysy { RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { - // 初始化可分配的整数寄存器池 (排除特殊用途的) allocable_int_regs = { PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, @@ -18,23 +18,113 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { } void RISCv64RegAlloc::run() { + eliminateFrameIndices(); analyzeLiveness(); buildInterferenceGraph(); colorGraph(); rewriteFunction(); } +void RISCv64RegAlloc::eliminateFrameIndices() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int current_offset = 0; + Function* F = MFunc->getFunc(); + RISCv64ISel* isel = MFunc->getISel(); + + for (auto& bb : F->getBasicBlocks()) { + for (auto& inst : bb->getInstructions()) { + if (auto alloca = dynamic_cast(inst.get())) { + int size = 4; + if (!alloca->getDims().empty()) { + int num_elements = 1; + for (const auto& dim_use : alloca->getDims()) { + if (auto const_dim = dynamic_cast(dim_use->getValue())) { + num_elements *= const_dim->getInt(); + } + } + size *= num_elements; + } + current_offset += size; + unsigned alloca_vreg = isel->getVReg(alloca); + frame_info.alloca_offsets[alloca_vreg] = -current_offset; + } + } + } + frame_info.locals_size = current_offset; + + for (auto& mbb : MFunc->getBlocks()) { + std::vector> new_instructions; + for (auto& instr_ptr : mbb->getInstructions()) { + if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) { + auto& operands = instr_ptr->getOperands(); + unsigned dest_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(lw)); + + } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) { + auto& operands = instr_ptr->getOperands(); + unsigned src_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(src_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(sw)); + } else { + new_instructions.push_back(std::move(instr_ptr)); + } + } + mbb->getInstructions() = std::move(new_instructions); + } +} + void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) { - // 这是一个简化的版本,实际需要根据RVOpcodes精确定义 - // 通常第一个RegOperand是def,其余是use bool is_def = true; + auto opcode = instr->getOpcode(); + + // 预定义def和use规则 + if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || + opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE || + opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE || + opcode == RVOpcodes::RET || opcode == RVOpcodes::J) { + is_def = false; + } + if (opcode == RVOpcodes::CALL) { + // CALL会杀死所有调用者保存寄存器,这是一个简化处理 + // 同时也使用了传入a0-a7的参数 + } + for (const auto& op : instr->getOperands()) { if (op->getKind() == MachineOperand::KIND_REG) { auto reg_op = static_cast(op.get()); if (reg_op->isVirtual()) { if (is_def) { def.insert(reg_op->getVRegNum()); - is_def = false; // 假设每条指令最多一个def + is_def = false; } else { use.insert(reg_op->getVRegNum()); } @@ -46,35 +136,16 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& } } } - - // 特殊处理store和branch指令,它们没有显式的def - auto opcode = instr->getOpcode(); - if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::BNE || opcode == RVOpcodes::BEQ) { - def.clear(); // 清空错误的def - use.clear(); - for (const auto& op : instr->getOperands()) { - if (op->getKind() == MachineOperand::KIND_REG) { - auto reg_op = static_cast(op.get()); - if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum()); - } else if (op->getKind() == MachineOperand::KIND_MEM) { - auto mem_op = static_cast(op.get()); - if(mem_op->getBase()->isVirtual()) use.insert(mem_op->getBase()->getVRegNum()); - } - } - } } - void RISCv64RegAlloc::analyzeLiveness() { bool changed = true; while (changed) { changed = false; - // 逆序遍历基本块 for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { auto& mbb = *it; LiveSet live_out; for (auto succ : mbb->successors) { - // live_out[B] = Union(live_in[S]) for all S in succ(B) if (!succ->getInstructions().empty()) { auto first_instr = succ->getInstructions().front().get(); if (live_in_map.count(first_instr)) { @@ -83,19 +154,14 @@ void RISCv64RegAlloc::analyzeLiveness() { } } - // 逆序遍历指令 for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) { MachineInstr* instr = instr_it->get(); LiveSet old_live_in = live_in_map[instr]; - LiveSet old_live_out = live_out_map[instr]; - - // 更新 live_out live_out_map[instr] = live_out; LiveSet use, def; getInstrUseDef(instr, use, def); - // live_in[i] = use[i] U (live_out[i] - def[i]) LiveSet live_in = use; LiveSet diff = live_out; for (auto vreg : def) { @@ -104,10 +170,9 @@ void RISCv64RegAlloc::analyzeLiveness() { live_in.insert(diff.begin(), diff.end()); live_in_map[instr] = live_in; - // 为下一次迭代准备live_out live_out = live_in; - if (live_in_map[instr] != old_live_in || live_out_map[instr] != old_live_out) { + if (live_in_map[instr] != old_live_in) { changed = true; } } @@ -117,21 +182,21 @@ void RISCv64RegAlloc::analyzeLiveness() { void RISCv64RegAlloc::buildInterferenceGraph() { std::set all_vregs; - // 收集所有虚拟寄存器 - for (auto const& [instr, live_set] : live_out_map) { - all_vregs.insert(live_set.begin(), live_set.end()); + for (auto& mbb : MFunc->getBlocks()) { + for(auto& instr : mbb->getInstructions()) { + LiveSet use, def; + getInstrUseDef(instr.get(), use, def); + for(auto u : use) all_vregs.insert(u); + for(auto d : def) all_vregs.insert(d); + } } - // 初始化图 - for (auto vreg : all_vregs) { - interference_graph[vreg] = {}; - } + for (auto vreg : all_vregs) { interference_graph[vreg] = {}; } for (auto& mbb : MFunc->getBlocks()) { for (auto& instr : mbb->getInstructions()) { LiveSet def, use; getInstrUseDef(instr.get(), use, def); - const LiveSet& live_out = live_out_map.at(instr.get()); for (unsigned d : def) { @@ -152,21 +217,18 @@ void RISCv64RegAlloc::colorGraph() { sorted_vregs.push_back(vreg); } - // 按度数降序排序 (简单贪心策略) std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) { return interference_graph[a].size() > interference_graph[b].size(); }); for (unsigned vreg : sorted_vregs) { std::set used_colors; - // 查找邻居已用的颜色 for (unsigned neighbor : interference_graph.at(vreg)) { if (color_map.count(neighbor)) { used_colors.insert(color_map.at(neighbor)); } } - // 寻找一个可用的颜色 bool colored = false; for (PhysicalReg preg : allocable_int_regs) { if (used_colors.find(preg) == used_colors.end()) { @@ -175,54 +237,47 @@ void RISCv64RegAlloc::colorGraph() { break; } } - if (!colored) { - // 无法分配,需要溢出 spilled_vregs.insert(vreg); } } } void RISCv64RegAlloc::rewriteFunction() { - // 1. 为所有溢出的vreg分配栈槽 StackFrameInfo& frame_info = MFunc->getFrameInfo(); - int current_offset = frame_info.frame_size; // 假设从现有栈大小后开始分配 + int current_offset = frame_info.locals_size; for (unsigned vreg : spilled_vregs) { - current_offset += 4; // 假设所有溢出变量都占4字节 - frame_info.spill_slots[vreg] = -current_offset; // 栈向下增长,所以是负偏移 + current_offset += 4; + frame_info.spill_offsets[vreg] = -current_offset; } - frame_info.frame_size = current_offset; + frame_info.spill_size = current_offset - frame_info.locals_size; - // 2. 遍历所有指令,替换vreg并插入spill代码 for (auto& mbb : MFunc->getBlocks()) { std::vector> new_instructions; for (auto& instr_ptr : mbb->getInstructions()) { LiveSet use, def; getInstrUseDef(instr_ptr.get(), use, def); - // 为use的溢出变量插入LOAD for (unsigned vreg : use) { if (spilled_vregs.count(vreg)) { - int offset = frame_info.spill_slots.at(vreg); + int offset = frame_info.spill_offsets.at(vreg); auto load = std::make_unique(RVOpcodes::LW); - load->addOperand(std::make_unique(vreg)); // 临时用vreg号代表,稍后替换 + load->addOperand(std::make_unique(vreg)); load->addOperand(std::make_unique( - std::make_unique(PhysicalReg::S0), // 基址用帧指针s0 + std::make_unique(PhysicalReg::S0), std::make_unique(offset) )); new_instructions.push_back(std::move(load)); } } - // 添加原始指令 new_instructions.push_back(std::move(instr_ptr)); - // 为def的溢出变量插入STORE for (unsigned vreg : def) { if (spilled_vregs.count(vreg)) { - int offset = frame_info.spill_slots.at(vreg); + int offset = frame_info.spill_offsets.at(vreg); auto store = std::make_unique(RVOpcodes::SW); - store->addOperand(std::make_unique(vreg)); // 临时用vreg号代表 + store->addOperand(std::make_unique(vreg)); store->addOperand(std::make_unique( std::make_unique(PhysicalReg::S0), std::make_unique(offset) @@ -234,7 +289,6 @@ void RISCv64RegAlloc::rewriteFunction() { mbb->getInstructions() = std::move(new_instructions); } - // 3. 最后一遍扫描,将所有RegOperand从vreg替换为preg for (auto& mbb : MFunc->getBlocks()) { for (auto& instr_ptr : mbb->getInstructions()) { for (auto& op_ptr : instr_ptr->getOperands()) { @@ -245,8 +299,7 @@ void RISCv64RegAlloc::rewriteFunction() { if (color_map.count(vreg)) { reg_op->setPReg(color_map.at(vreg)); } else if (spilled_vregs.count(vreg)) { - // 对于spill的vreg, 使用一个固定的临时寄存器, 比如t6 - reg_op->setPReg(PhysicalReg::T6); + reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6 } } } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { @@ -254,7 +307,11 @@ void RISCv64RegAlloc::rewriteFunction() { auto base_reg_op = mem_op->getBase(); if(base_reg_op->isVirtual()){ unsigned vreg = base_reg_op->getVRegNum(); - if(color_map.count(vreg)) base_reg_op->setPReg(color_map.at(vreg)); + if(color_map.count(vreg)) { + base_reg_op->setPReg(color_map.at(vreg)); + } else if (spilled_vregs.count(vreg)) { + base_reg_op->setPReg(PhysicalReg::T6); + } } } } diff --git a/src/include/RISCv64AsmPrinter.h b/src/include/RISCv64AsmPrinter.h index 7df35f6..3ea71f6 100644 --- a/src/include/RISCv64AsmPrinter.h +++ b/src/include/RISCv64AsmPrinter.h @@ -8,29 +8,23 @@ namespace sysy { class RISCv64AsmPrinter { public: - // 主入口,将整个MachineFunction打印到指定的输出流 - void runOnMachineFunction(MachineFunction* mfunc, std::ostream& os); + RISCv64AsmPrinter(MachineFunction* mfunc); + // 主入口 + void run(std::ostream& os); private: - // 打印单个基本块 + // 打印各个部分 + void printPrologue(); + void printEpilogue(); void printBasicBlock(MachineBasicBlock* mbb); - - // 打印单条指令 - void printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb); - - // 打印函数序言 - void printPrologue(MachineFunction* mfunc); + void printInstruction(MachineInstr* instr); - // 打印函数尾声 - void printEpilogue(MachineFunction* mfunc); - - // 将物理寄存器枚举转换为字符串 (从原RISCv64Backend迁移) + // 辅助函数 std::string regToString(PhysicalReg reg); - - // 打印单个操作数 void printOperand(MachineOperand* op); - std::ostream* OS; // 指向当前输出流 + MachineFunction* MFunc; + std::ostream* OS; }; } // namespace sysy diff --git a/src/include/RISCv64Backend.h b/src/include/RISCv64Backend.h index f929007..33f7831 100644 --- a/src/include/RISCv64Backend.h +++ b/src/include/RISCv64Backend.h @@ -1,10 +1,8 @@ #ifndef RISCV64_BACKEND_H #define RISCV64_BACKEND_H -#include "IR.h" // 只需包含高层IR定义 +#include "IR.h" #include -#include -#include namespace sysy { @@ -12,14 +10,12 @@ namespace sysy { class RISCv64CodeGen { public: RISCv64CodeGen(Module* mod) : module(mod) {} - // 唯一的公共入口点 std::string code_gen(); private: - // 模块级代码生成 (处理全局变量和驱动函数生成) + // 模块级代码生成 std::string module_gen(); - // 函数级代码生成 (实现新的流水线) std::string function_gen(Function* func); diff --git a/src/include/RISCv64ISel.h b/src/include/RISCv64ISel.h index e122ee0..795b2b8 100644 --- a/src/include/RISCv64ISel.h +++ b/src/include/RISCv64ISel.h @@ -1,10 +1,7 @@ #ifndef RISCV64_ISEL_H #define RISCV64_ISEL_H -#include "IR.h" #include "RISCv64LLIR.h" -#include -#include namespace sysy { @@ -14,43 +11,34 @@ public: // 模块主入口:将一个高层IR函数转换为底层LLIR函数 std::unique_ptr runOnFunction(Function* func); + // 公开接口,以便后续模块(如RegAlloc)可以查询或创建vreg + unsigned getVReg(Value* val); + unsigned getNewVReg() { return vreg_counter++; } + private: // DAG节点定义,作为ISel的内部实现细节 - struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; - NodeKind kind; - Value* value = nullptr; - std::vector operands; - DAGNode(NodeKind k) : kind(k) {} - }; - - // 为当前函数生成LLIR + struct DAGNode; + + // 指令选择主流程 void select(); - // 为单个基本块生成指令 void selectBasicBlock(BasicBlock* bb); - // 核心函数:为DAG节点选择并生成MachineInstr void selectNode(DAGNode* node); - // --- DAG 构建相关函数 (从原RISCv64Backend迁移) --- + // DAG 构建相关函数 (从原RISCv64Backend迁移) std::vector> build_dag(BasicBlock* bb); - DAGNode* get_operand_node(Value* val_ir, std::map& value_to_node, std::vector>& nodes_storage); - DAGNode* create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage); - - // --- 辅助函数 --- - // 为一个IR Value获取/分配一个虚拟寄存器号 - unsigned getVReg(Value* val); + DAGNode* get_operand_node(Value* val_ir, std::map&, std::vector>&); + DAGNode* create_node(int kind, Value* val, std::map&, std::vector>&); + // 状态 Function* F; // 当前处理的高层IR函数 std::unique_ptr MFunc; // 正在构建的底层LLIR函数 - MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块 // 映射关系 std::map vreg_map; std::map bb_map; - std::map value_to_node_map; // 用于selectNode中查找 unsigned vreg_counter; int local_label_counter; diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h index c1df3bf..6310741 100644 --- a/src/include/RISCv64LLIR.h +++ b/src/include/RISCv64LLIR.h @@ -1,66 +1,51 @@ #ifndef RISCV64_LLIR_H #define RISCV64_LLIR_H +#include "IR.h" // 确保包含了您自己的IR头文件 #include #include #include #include #include +// 前向声明,避免循环引用 +namespace sysy { +class Function; +class RISCv64ISel; +} + namespace sysy { -// 物理寄存器定义 (从 RISCv64Backend.h 移至此) +// 物理寄存器定义 enum class PhysicalReg { ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 }; - // RISC-V 指令操作码枚举 enum class RVOpcodes { // 算术指令 - ADD, ADDI, ADDW, ADDIW, - SUB, SUBW, - MUL, MULW, - DIV, DIVW, - REM, REMW, - + ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW, // 逻辑指令 - XOR, XORI, - OR, ORI, - AND, ANDI, - + XOR, XORI, OR, ORI, AND, ANDI, // 移位指令 - SLL, SLLI, SLLW, SLLIW, - SRL, SRLI, SRLW, SRLIW, - SRA, SRAI, SRAW, SRAIW, - + SLL, SLLI, SLLW, SLLIW, SRL, SRLI, SRLW, SRLIW, SRA, SRAI, SRAW, SRAIW, // 比较指令 SLT, SLTI, SLTU, SLTIU, - // 内存访问指令 - LW, LH, LB, LWU, LHU, LBU, - SW, SH, SB, - LD, SD, // 64位 - + LW, LH, LB, LWU, LHU, LBU, SW, SH, SB, LD, SD, // 控制流指令 - J, JAL, JALR, RET, // RET 是 JALR x0, 0(ra) 的伪指令 + J, JAL, JALR, RET, BEQ, BNE, BLT, BGE, BLTU, BGEU, - - // 伪指令 (方便指令选择) - LI, // Load Immediate - LA, // Load Address - MV, // Move register - NEG, // Negate - NEGW, // Negate Word - SEQZ, // Set if Equal to Zero - SNEZ, // Set if Not Equal to Zero - + // 伪指令 + LI, LA, MV, NEG, NEGW, SEQZ, SNEZ, // 函数调用 CALL, - // 特殊标记,非指令 - LABEL, // 用于表示一个标签位置 + LABEL, + // 新增伪指令,用于解耦栈帧处理 + FRAME_LOAD, // 从栈帧加载 (AllocaInst) + FRAME_STORE, // 保存到栈帧 (AllocaInst) }; class MachineOperand; @@ -72,22 +57,13 @@ class MachineInstr; class MachineBasicBlock; class MachineFunction; -// --- 操作数定义 --- - // 操作数基类 class MachineOperand { public: - enum OperandKind { - KIND_REG, - KIND_IMM, - KIND_LABEL, - KIND_MEM - }; - + enum OperandKind { KIND_REG, KIND_IMM, KIND_LABEL, KIND_MEM }; MachineOperand(OperandKind kind) : kind(kind) {} virtual ~MachineOperand() = default; OperandKind getKind() const { return kind; } - private: OperandKind kind; }; @@ -111,7 +87,6 @@ public: preg = new_preg; is_virtual = false; } - private: unsigned vreg_num = 0; PhysicalReg preg = PhysicalReg::ZERO; @@ -121,9 +96,7 @@ private: // 立即数操作数 class ImmOperand : public MachineOperand { public: - ImmOperand(int64_t value) - : MachineOperand(KIND_IMM), value(value) {} - + ImmOperand(int64_t value) : MachineOperand(KIND_IMM), value(value) {} int64_t getValue() const { return value; } private: int64_t value; @@ -132,9 +105,7 @@ private: // 标签操作数 class LabelOperand : public MachineOperand { public: - LabelOperand(const std::string& name) - : MachineOperand(KIND_LABEL), name(name) {} - + LabelOperand(const std::string& name) : MachineOperand(KIND_LABEL), name(name) {} const std::string& getName() const { return name; } private: std::string name; @@ -145,33 +116,25 @@ class MemOperand : public MachineOperand { public: MemOperand(std::unique_ptr base, std::unique_ptr offset) : MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {} - RegOperand* getBase() const { return base.get(); } ImmOperand* getOffset() const { return offset.get(); } - private: std::unique_ptr base; std::unique_ptr offset; }; - -// --- 组织结构定义 --- - // 机器指令 class MachineInstr { public: MachineInstr(RVOpcodes opcode) : opcode(opcode) {} RVOpcodes getOpcode() const { return opcode; } - // 注意:返回const引用,因为通常不直接修改指令的操作数列表 const std::vector>& getOperands() const { return operands; } - // 提供一个非const版本,用于内部修改 std::vector>& getOperands() { return operands; } void addOperand(std::unique_ptr operand) { operands.push_back(std::move(operand)); } - private: RVOpcodes opcode; std::vector> operands; @@ -185,8 +148,6 @@ public: const std::string& getName() const { return name; } MachineFunction* getParent() const { return parent; } - - // 同时提供 const 和 non-const 版本 const std::vector>& getInstructions() const { return instructions; } std::vector>& getInstructions() { return instructions; } @@ -196,43 +157,44 @@ public: std::vector successors; std::vector predecessors; - private: std::string name; std::vector> instructions; - MachineFunction* parent; // 指向所属函数 + MachineFunction* parent; }; // 栈帧信息 struct StackFrameInfo { - int frame_size = 0; - std::map spill_slots; // <虚拟寄存器号, 栈偏移> - // ... 未来可以添加更多信息 + int locals_size = 0; // 仅为AllocaInst分配的大小 + int spill_size = 0; // 仅为溢出分配的大小 + int total_size = 0; // 总大小 + std::map alloca_offsets; // + std::map spill_offsets; // <溢出vreg, 栈偏移> }; // 机器函数 class MachineFunction { public: - MachineFunction(const std::string& name) : name(name) {} + MachineFunction(Function* func, RISCv64ISel* isel) : F(func), name(func->getName()), isel(isel) {} + Function* getFunc() const { return F; } + RISCv64ISel* getISel() const { return isel; } const std::string& getName() const { return name; } StackFrameInfo& getFrameInfo() { return frame_info; } - - // 同时提供 const 和 non-const 版本 const std::vector>& getBlocks() const { return blocks; } std::vector>& getBlocks() { return blocks; } void addBlock(std::unique_ptr block) { blocks.push_back(std::move(block)); } - private: + Function* F; + RISCv64ISel* isel; // 指向创建它的ISel,用于获取vreg映射等信息 std::string name; std::vector> blocks; StackFrameInfo frame_info; }; - } // namespace sysy #endif // RISCV64_LLIR_H \ No newline at end of file diff --git a/src/include/RISCv64Passes.h b/src/include/RISCv64Passes.h new file mode 100644 index 0000000..3a4bcd1 --- /dev/null +++ b/src/include/RISCv64Passes.h @@ -0,0 +1,18 @@ +// RISCv64Passes.h +#ifndef RISCV64_PASSES_H +#define RISCV64_PASSES_H + +#include "RISCv64LLIR.h" + +namespace sysy { + +// 此处为未来优化Pass的基类或独立类定义 +// 例如: +// class PeepholeOptimizer { +// public: +// void runOnMachineFunction(MachineFunction* mfunc); +// }; + +} // namespace sysy + +#endif // RISCV64_PASSES_H \ No newline at end of file diff --git a/src/include/RISCv64RegAlloc.h b/src/include/RISCv64RegAlloc.h index ee8ab16..c786bde 100644 --- a/src/include/RISCv64RegAlloc.h +++ b/src/include/RISCv64RegAlloc.h @@ -2,9 +2,6 @@ #define RISCV64_REGALLOC_H #include "RISCv64LLIR.h" -#include -#include -#include namespace sysy { @@ -19,6 +16,9 @@ private: using LiveSet = std::set; // 活跃虚拟寄存器集合 using InterferenceGraph = std::map>; + // 栈帧管理 + void eliminateFrameIndices(); + // 活跃性分析 void analyzeLiveness(); @@ -28,7 +28,7 @@ private: // 图着色分配寄存器 void colorGraph(); - // 重写函数,将虚拟寄存器替换为物理寄存器,并插入溢出代码 + // 重写函数,替换vreg并插入溢出代码 void rewriteFunction(); // 辅助函数,获取指令的Use/Def集合 @@ -37,8 +37,8 @@ private: MachineFunction* MFunc; // 活跃性分析结果 - std::map live_in_map; - std::map live_out_map; + std::map live_in_map; + std::map live_out_map; // 干扰图 InterferenceGraph interference_graph; @@ -49,7 +49,6 @@ private: // 可用的物理寄存器池 std::vector allocable_int_regs; - std::vector allocable_float_regs; // (为未来浮点支持预留) }; } // namespace sysy