From 75e61bf27478d3778d7eda8f1d012aa6d7f1bcb3 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 19 Jul 2025 14:29:57 +0800 Subject: [PATCH 1/3] =?UTF-8?q?[backend-llir]=E5=BC=95=E5=85=A5=E4=BA=86LL?= =?UTF-8?q?IR=E5=AE=9A=E4=B9=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/RISCv64Backend.cpp | 1 + src/include/RISCv64LLIR.h | 229 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) create mode 100644 src/include/RISCv64LLIR.h diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index 29b0bfd..a726bf0 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -1,4 +1,5 @@ #include "RISCv64Backend.h" +#include "RISCv64LLIR.h" #include #include #include diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h new file mode 100644 index 0000000..acbeec5 --- /dev/null +++ b/src/include/RISCv64LLIR.h @@ -0,0 +1,229 @@ +#ifndef RISCV64_LLIR_H +#define RISCV64_LLIR_H + +#include +#include +#include +#include +#include + +namespace sysy { + +// 物理寄存器定义 (从 RISCv64Backend.h 移至此) +enum class PhysicalReg { + ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, + F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 +}; + + +// RISC-V 指令操作码枚举 +enum class RVOpcodes { + // 算术指令 + ADD, ADDI, ADDW, ADDIW, + SUB, SUBW, + MUL, MULW, + DIV, DIVW, + REM, REMW, + + // 逻辑指令 + XOR, XORI, + OR, ORI, + AND, ANDI, + + // 移位指令 + SLL, SLLI, SLLW, SLLIW, + SRL, SRLI, SRLW, SRLIW, + SRA, SRAI, SRAW, SRAIW, + + // 比较指令 + SLT, SLTI, SLTU, SLTIU, + + // 内存访问指令 + LW, LH, LB, LWU, LHU, LBU, + SW, SH, SB, + LD, SD, // 64位 + + // 控制流指令 + J, JAL, JALR, RET, // RET 是 JALR x0, 0(ra) 的伪指令 + BEQ, BNE, BLT, BGE, BLTU, BGEU, + + // 伪指令 (方便指令选择) + LI, // Load Immediate + LA, // Load Address + MV, // Move register + NEG, // Negate + NEGW, // Negate Word + SEQZ, // Set if Equal to Zero + SNEZ, // Set if Not Equal to Zero + + // 函数调用 + CALL, + + // 特殊标记,非指令 + LABEL, // 用于表示一个标签位置 +}; + +class MachineOperand; +class RegOperand; +class ImmOperand; +class LabelOperand; +class MemOperand; +class MachineInstr; +class MachineBasicBlock; +class MachineFunction; + +// --- 操作数定义 --- + +// 操作数基类 +class MachineOperand { +public: + enum OperandKind { + KIND_REG, + KIND_IMM, + KIND_LABEL, + KIND_MEM + }; + + MachineOperand(OperandKind kind) : kind(kind) {} + virtual ~MachineOperand() = default; + OperandKind getKind() const { return kind; } + +private: + OperandKind kind; +}; + +// 寄存器操作数 +class RegOperand : public MachineOperand { +public: + // 构造虚拟寄存器 + RegOperand(unsigned vreg_num) + : MachineOperand(KIND_REG), vreg_num(vreg_num), is_virtual(true) {} + + // 构造物理寄存器 + RegOperand(PhysicalReg preg) + : MachineOperand(KIND_REG), preg(preg), is_virtual(false) {} + + bool isVirtual() const { return is_virtual; } + unsigned getVRegNum() const { return vreg_num; } + PhysicalReg getPReg() const { return preg; } + + void setPReg(PhysicalReg new_preg) { + preg = new_preg; + is_virtual = false; + } + +private: + unsigned vreg_num = 0; + PhysicalReg preg = PhysicalReg::ZERO; + bool is_virtual; +}; + +// 立即数操作数 +class ImmOperand : public MachineOperand { +public: + ImmOperand(int64_t value) + : MachineOperand(KIND_IMM), value(value) {} + + int64_t getValue() const { return value; } +private: + int64_t value; +}; + +// 标签操作数 +class LabelOperand : public MachineOperand { +public: + LabelOperand(const std::string& name) + : MachineOperand(KIND_LABEL), name(name) {} + + const std::string& getName() const { return name; } +private: + std::string name; +}; + +// 内存操作数, 表示 offset(base_reg) +class MemOperand : public MachineOperand { +public: + MemOperand(std::unique_ptr base, std::unique_ptr offset) + : MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {} + + RegOperand* getBase() const { return base.get(); } + ImmOperand* getOffset() const { return offset.get(); } + +private: + std::unique_ptr base; + std::unique_ptr offset; +}; + + +// --- 组织结构定义 --- + +// 机器指令 +class MachineInstr { +public: + MachineInstr(RVOpcodes opcode) : opcode(opcode) {} + + RVOpcodes getOpcode() const { return opcode; } + const std::vector>& getOperands() const { return operands; } + + void addOperand(std::unique_ptr operand) { + operands.push_back(std::move(operand)); + } + +private: + RVOpcodes opcode; + std::vector> operands; +}; + +// 机器基本块 +class MachineBasicBlock { +public: + MachineBasicBlock(const std::string& name, MachineFunction* parent) + : name(name), parent(parent) {} + + const std::string& getName() const { return name; } + const std::vector>& getInstructions() const { return instructions; } + MachineFunction* getParent() const { return parent; } + + void addInstruction(std::unique_ptr instr) { + instructions.push_back(std::move(instr)); + } + + std::vector successors; + std::vector predecessors; + +private: + std::string name; + std::vector> instructions; + MachineFunction* parent; // 指向所属函数 +}; + +// 栈帧信息 +struct StackFrameInfo { + int frame_size = 0; + std::map spill_slots; // <虚拟寄存器号, 栈偏移> + // ... 未来可以添加更多信息 +}; + +// 机器函数 +class MachineFunction { +public: + MachineFunction(const std::string& name) : name(name) {} + + const std::string& getName() const { return name; } + const std::vector>& getBlocks() const { return blocks; } + StackFrameInfo& getFrameInfo() { return frame_info; } + + void addBlock(std::unique_ptr block) { + blocks.push_back(std::move(block)); + } + +private: + std::string name; + std::vector> blocks; + StackFrameInfo frame_info; +}; + + +} // namespace sysy + +#endif // RISCV64_LLIR_H \ No newline at end of file From d4a6996d74930bb5c0c12b41d6db33367b01bf83 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 19 Jul 2025 16:06:35 +0800 Subject: [PATCH 2/3] =?UTF-8?q?[backend]=E9=87=8D=E6=9E=84=E4=BA=86?= =?UTF-8?q?=E5=90=8E=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/CMakeLists.txt | 3 + src/RISCv64AsmPrinter.cpp | 246 +++++ src/RISCv64Backend.cpp | 1491 +------------------------------ src/RISCv64ISel.cpp | 605 +++++++++++++ src/RISCv64RegAlloc.cpp | 265 ++++++ src/include/RISCv64AsmPrinter.h | 38 + src/include/RISCv64Backend.h | 119 +-- src/include/RISCv64ISel.h | 61 ++ src/include/RISCv64LLIR.h | 15 +- src/include/RISCv64RegAlloc.h | 57 ++ 10 files changed, 1336 insertions(+), 1564 deletions(-) create mode 100644 src/RISCv64AsmPrinter.cpp create mode 100644 src/RISCv64ISel.cpp create mode 100644 src/RISCv64RegAlloc.cpp create mode 100644 src/include/RISCv64AsmPrinter.h create mode 100644 src/include/RISCv64ISel.h create mode 100644 src/include/RISCv64RegAlloc.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 050b63c..a8a4249 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,6 +28,9 @@ add_executable(sysyc Mem2Reg.cpp Reg2Mem.cpp RISCv64Backend.cpp + RISCv64ISel.cpp + RISCv64RegAlloc.cpp + RISCv64AsmPrinter.cpp ) # 设置 include 路径,包含 ANTLR 运行时库和项目头文件 diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp new file mode 100644 index 0000000..0688095 --- /dev/null +++ b/src/RISCv64AsmPrinter.cpp @@ -0,0 +1,246 @@ +#include "RISCv64AsmPrinter.h" +#include + +namespace sysy { + +void RISCv64AsmPrinter::runOnMachineFunction(MachineFunction* mfunc, std::ostream& os) { + OS = &os; + + // 打印函数声明和全局符号 + *OS << ".text\n"; + *OS << ".globl " << mfunc->getName() << "\n"; + *OS << mfunc->getName() << ":\n"; + + // 打印函数序言 + printPrologue(mfunc); + + // 遍历并打印所有基本块 + for (auto& mbb : mfunc->getBlocks()) { + printBasicBlock(mbb.get()); + } +} + +void RISCv64AsmPrinter::printPrologue(MachineFunction* mfunc) { + int stack_size = mfunc->getFrameInfo().frame_size; + + // 确保栈大小是16字节对齐 + int aligned_stack_size = (stack_size + 15) & ~15; + + if (aligned_stack_size > 0) { + *OS << " addi sp, sp, -" << aligned_stack_size << "\n"; + // RV64中ra和s0都是8字节 + *OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; + *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; + *OS << " mv s0, sp\n"; + } +} + +void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { + int stack_size = mfunc->getFrameInfo().frame_size; + int aligned_stack_size = (stack_size + 15) & ~15; + + if (aligned_stack_size > 0) { + *OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n"; + *OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n"; + *OS << " addi sp, sp, " << aligned_stack_size << "\n"; + } +} + + +void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) { + // 打印基本块标签 + if (!mbb->getName().empty()) { + *OS << mbb->getName() << ":\n"; + } + + // 打印指令 + for (auto& instr : mbb->getInstructions()) { + printInstruction(instr.get(), mbb); + } +} + +void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb) { + *OS << " "; // 指令缩进 + + auto opcode = instr->getOpcode(); + + // RET指令需要特殊处理,在打印ret之前先打印函数尾声 + if (opcode == RVOpcodes::RET) { + printEpilogue(parent_bb->getParent()); + } + + // 使用switch将Opcode转换为汇编助记符 + switch (opcode) { + // Arithmatic + case RVOpcodes::ADD: *OS << "add "; break; + case RVOpcodes::ADDI: *OS << "addi "; break; + case RVOpcodes::ADDW: *OS << "addw "; break; + case RVOpcodes::ADDIW: *OS << "addiw "; break; + case RVOpcodes::SUB: *OS << "sub "; break; + case RVOpcodes::SUBW: *OS << "subw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; + case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::DIV: *OS << "div "; break; + case RVOpcodes::DIVW: *OS << "divw "; break; + case RVOpcodes::REM: *OS << "rem "; break; + case RVOpcodes::REMW: *OS << "remw "; break; + // Logical + case RVOpcodes::XOR: *OS << "xor "; break; + case RVOpcodes::XORI: *OS << "xori "; break; + case RVOpcodes::OR: *OS << "or "; break; + case RVOpcodes::ORI: *OS << "ori "; break; + case RVOpcodes::AND: *OS << "and "; break; + case RVOpcodes::ANDI: *OS << "andi "; break; + // Shift + case RVOpcodes::SLL: *OS << "sll "; break; + case RVOpcodes::SLLI: *OS << "slli "; break; + case RVOpcodes::SLLW: *OS << "sllw "; break; + case RVOpcodes::SLLIW: *OS << "slliw "; break; + case RVOpcodes::SRL: *OS << "srl "; break; + case RVOpcodes::SRLI: *OS << "srli "; break; + case RVOpcodes::SRLW: *OS << "srlw "; break; + case RVOpcodes::SRLIW: *OS << "srliw "; break; + case RVOpcodes::SRA: *OS << "sra "; break; + case RVOpcodes::SRAI: *OS << "srai "; break; + case RVOpcodes::SRAW: *OS << "sraw "; break; + case RVOpcodes::SRAIW: *OS << "sraiw "; break; + // Compare + case RVOpcodes::SLT: *OS << "slt "; break; + case RVOpcodes::SLTI: *OS << "slti "; break; + case RVOpcodes::SLTU: *OS << "sltu "; break; + case RVOpcodes::SLTIU: *OS << "sltiu "; break; + // Memory + case RVOpcodes::LW: *OS << "lw "; break; + case RVOpcodes::LH: *OS << "lh "; break; + case RVOpcodes::LB: *OS << "lb "; break; + case RVOpcodes::LWU: *OS << "lwu "; break; + case RVOpcodes::LHU: *OS << "lhu "; break; + case RVOpcodes::LBU: *OS << "lbu "; break; + case RVOpcodes::SW: *OS << "sw "; break; + case RVOpcodes::SH: *OS << "sh "; break; + case RVOpcodes::SB: *OS << "sb "; break; + case RVOpcodes::LD: *OS << "ld "; break; + case RVOpcodes::SD: *OS << "sd "; break; + // Control Flow + case RVOpcodes::J: *OS << "j "; break; + case RVOpcodes::JAL: *OS << "jal "; break; + case RVOpcodes::JALR: *OS << "jalr "; break; + case RVOpcodes::RET: *OS << "ret"; break; + case RVOpcodes::BEQ: *OS << "beq "; break; + case RVOpcodes::BNE: *OS << "bne "; break; + case RVOpcodes::BLT: *OS << "blt "; break; + case RVOpcodes::BGE: *OS << "bge "; break; + case RVOpcodes::BLTU: *OS << "bltu "; break; + case RVOpcodes::BGEU: *OS << "bgeu "; break; + // Pseudo-Instructions + case RVOpcodes::LI: *OS << "li "; break; + case RVOpcodes::LA: *OS << "la "; break; + case RVOpcodes::MV: *OS << "mv "; break; + case RVOpcodes::NEG: *OS << "neg "; break; + case RVOpcodes::NEGW: *OS << "negw "; break; + case RVOpcodes::SEQZ: *OS << "seqz "; break; + case RVOpcodes::SNEZ: *OS << "snez "; break; + // Call + case RVOpcodes::CALL: *OS << "call "; break; + // Special + case RVOpcodes::LABEL: + *OS << "\b\b\b\b"; + printOperand(instr->getOperands()[0].get()); + *OS << ":"; + break; + default: + throw std::runtime_error("Unknown opcode in AsmPrinter"); + } + + // 打印操作数 + const auto& operands = instr->getOperands(); + for (size_t i = 0; i < operands.size(); ++i) { + // 对于LW/SW, 操作数格式是 rd, offset(rs1) + if (opcode == RVOpcodes::LW || opcode == RVOpcodes::SW || opcode == RVOpcodes::LD || opcode == RVOpcodes::SD) { + printOperand(operands[0].get()); + *OS << ", "; + printOperand(operands[1].get()); + break; // LW/SW只有两个操作数部分 + } + + printOperand(operands[i].get()); + if (i < operands.size() - 1) { + *OS << ", "; + } + } + + *OS << "\n"; +} + +void RISCv64AsmPrinter::printOperand(MachineOperand* op) { + if (!op) return; + switch(op->getKind()) { + case MachineOperand::KIND_REG: { + auto reg_op = static_cast(op); + if (reg_op->isVirtual()) { + // 在这个阶段不应该再有虚拟寄存器了 + *OS << "%vreg" << reg_op->getVRegNum(); + } else { + *OS << regToString(reg_op->getPReg()); + } + break; + } + case MachineOperand::KIND_IMM: { + *OS << static_cast(op)->getValue(); + break; + } + case MachineOperand::KIND_LABEL: { + *OS << static_cast(op)->getName(); + break; + } + case MachineOperand::KIND_MEM: { + auto mem_op = static_cast(op); + printOperand(mem_op->getOffset()); + *OS << "("; + printOperand(mem_op->getBase()); + *OS << ")"; + break; + } + } +} + +// 物理寄存器到字符串的转换 (从原RISCv64Backend.cpp迁移) +std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) { + switch (reg) { + case PhysicalReg::ZERO: return "x0"; + case PhysicalReg::RA: return "ra"; + case PhysicalReg::SP: return "sp"; + case PhysicalReg::GP: return "gp"; + case PhysicalReg::TP: return "tp"; + case PhysicalReg::T0: return "t0"; + case PhysicalReg::T1: return "t1"; + case PhysicalReg::T2: return "t2"; + case PhysicalReg::S0: return "s0"; + case PhysicalReg::S1: return "s1"; + case PhysicalReg::A0: return "a0"; + case PhysicalReg::A1: return "a1"; + case PhysicalReg::A2: return "a2"; + case PhysicalReg::A3: return "a3"; + case PhysicalReg::A4: return "a4"; + case PhysicalReg::A5: return "a5"; + case PhysicalReg::A6: return "a6"; + case PhysicalReg::A7: return "a7"; + case PhysicalReg::S2: return "s2"; + case PhysicalReg::S3: return "s3"; + case PhysicalReg::S4: return "s4"; + case PhysicalReg::S5: return "s5"; + case PhysicalReg::S6: return "s6"; + case PhysicalReg::S7: return "s7"; + case PhysicalReg::S8: return "s8"; + case PhysicalReg::S9: return "s9"; + case PhysicalReg::S10: return "s10"; + case PhysicalReg::S11: return "s11"; + case PhysicalReg::T3: return "t3"; + case PhysicalReg::T4: return "t4"; + case PhysicalReg::T5: return "t5"; + case PhysicalReg::T6: return "t6"; + default: return "UNKNOWN_REG"; + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index a726bf0..c9fdead 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -1,123 +1,28 @@ #include "RISCv64Backend.h" -#include "RISCv64LLIR.h" +#include "RISCv64ISel.h" +#include "RISCv64RegAlloc.h" +#include "RISCv64AsmPrinter.h" #include -#include #include -#include -#include -#include -namespace sysy { +namespace sysy { -// 可用于分配的寄存器 -const std::vector RISCv64CodeGen::allocable_regs = { - // 整数寄存器 - PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, - PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, - PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, - PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, - PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, - PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, - PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11, - // 浮点寄存器 - PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, - PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7, - PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F10, PhysicalReg::F11, - PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, - PhysicalReg::F16, PhysicalReg::F17, PhysicalReg::F18, PhysicalReg::F19, - PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22, PhysicalReg::F23, - PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27, - PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31 -}; - -// 将物理寄存器枚举转换为字符串 -std::string RISCv64CodeGen::reg_to_string(PhysicalReg reg) { - switch (reg) { - case PhysicalReg::ZERO: return "x0"; - case PhysicalReg::RA: return "ra"; - case PhysicalReg::SP: return "sp"; - case PhysicalReg::GP: return "gp"; - case PhysicalReg::TP: return "tp"; - case PhysicalReg::T0: return "t0"; - case PhysicalReg::T1: return "t1"; - case PhysicalReg::T2: return "t2"; - case PhysicalReg::S0: return "s0"; - case PhysicalReg::S1: return "s1"; - case PhysicalReg::A0: return "a0"; - case PhysicalReg::A1: return "a1"; - case PhysicalReg::A2: return "a2"; - case PhysicalReg::A3: return "a3"; - case PhysicalReg::A4: return "a4"; - case PhysicalReg::A5: return "a5"; - case PhysicalReg::A6: return "a6"; - case PhysicalReg::A7: return "a7"; - case PhysicalReg::S2: return "s2"; - case PhysicalReg::S3: return "s3"; - case PhysicalReg::S4: return "s4"; - case PhysicalReg::S5: return "s5"; - case PhysicalReg::S6: return "s6"; - case PhysicalReg::S7: return "s7"; - case PhysicalReg::S8: return "s8"; - case PhysicalReg::S9: return "s9"; - case PhysicalReg::S10: return "s10"; - case PhysicalReg::S11: return "s11"; - case PhysicalReg::T3: return "t3"; - case PhysicalReg::T4: return "t4"; - case PhysicalReg::T5: return "t5"; - case PhysicalReg::T6: return "t6"; - // 浮点寄存器 - case PhysicalReg::F0: return "f0"; - case PhysicalReg::F1: return "f1"; - case PhysicalReg::F2: return "f2"; - case PhysicalReg::F3: return "f3"; - case PhysicalReg::F4: return "f4"; - case PhysicalReg::F5: return "f5"; - case PhysicalReg::F6: return "f6"; - case PhysicalReg::F7: return "f7"; - case PhysicalReg::F8: return "f8"; - case PhysicalReg::F9: return "f9"; - case PhysicalReg::F10: return "f10"; - case PhysicalReg::F11: return "f11"; - case PhysicalReg::F12: return "f12"; - case PhysicalReg::F13: return "f13"; - case PhysicalReg::F14: return "f14"; - case PhysicalReg::F15: return "f15"; - case PhysicalReg::F16: return "f16"; - case PhysicalReg::F17: return "f17"; - case PhysicalReg::F18: return "f18"; - case PhysicalReg::F19: return "f19"; - case PhysicalReg::F20: return "f20"; - case PhysicalReg::F21: return "f21"; - case PhysicalReg::F22: return "f22"; - case PhysicalReg::F23: return "f23"; - case PhysicalReg::F24: return "f24"; - case PhysicalReg::F25: return "f25"; - case PhysicalReg::F26: return "f26"; - case PhysicalReg::F27: return "f27"; - case PhysicalReg::F28: return "f28"; - case PhysicalReg::F29: return "f29"; - case PhysicalReg::F30: return "f30"; - case PhysicalReg::F31: return "f31"; - default: return "UNKNOWN_REG"; - } -} - -// 总体代码生成入口 +// 顶层入口 std::string RISCv64CodeGen::code_gen() { - std::stringstream ss; - ss << module_gen(); - return ss.str(); + return module_gen(); } -// 模块级代码生成 (处理全局变量和函数) +// module_gen 的逻辑基本不变,它负责处理.data段和驱动每个函数的生成 std::string RISCv64CodeGen::module_gen() { std::stringstream ss; + + // 1. 处理全局变量 (.data段) bool has_globals = !module->getGlobals().empty(); if (has_globals) { - ss << ".data\n"; // 数据段 + ss << ".data\n"; for (const auto& global : module->getGlobals()) { - ss << ".globl " << global->getName() << "\n"; // 声明全局符号 - ss << global->getName() << ":\n"; // 标签 + ss << ".globl " << global->getName() << "\n"; + ss << global->getName() << ":\n"; const auto& init_values = global->getInitValues(); for (size_t i = 0; i < init_values.getValues().size(); ++i) { auto val = init_values.getValues()[i]; @@ -125,1375 +30,59 @@ std::string RISCv64CodeGen::module_gen() { if (auto constant = dynamic_cast(val)) { for (unsigned j = 0; j < count; ++j) { if (constant->isInt()) { - ss << " .word " << constant->getInt() << "\n"; // 整数常量 (32位) + ss << " .word " << constant->getInt() << "\n"; } else { float f = constant->getFloat(); uint32_t float_bits = *(uint32_t*)&f; - ss << " .word " << float_bits << "\n"; // 浮点常量 (32位) + ss << " .word " << float_bits << "\n"; } } } } } } + + // 2. 处理函数 (.text段) if (!module->getFunctions().empty()) { - ss << ".text\n"; // 代码段 + ss << ".text\n"; for (const auto& func : module->getFunctions()) { - ss << function_gen(func.second.get()); + if (func.second.get()) { + ss << function_gen(func.second.get()); + } } } return ss.str(); } -// 函数级代码生成 +// function_gen 现在是新的、模块化的处理流水线 std::string RISCv64CodeGen::function_gen(Function* func) { - this->local_label_counter = 0; // 为每个函数重置本地标签计数器 - std::stringstream ss; - ss << ".globl " << func->getName() << "\n"; // 声明函数为全局符号 - ss << func->getName() << ":\n"; // 函数入口标签 - - RegAllocResult alloc_result = register_allocation(func); - int stack_size = alloc_result.stack_size; - - // 函数序言 (Prologue) - // RV64: ra 和 s0 都是64位(8字节)寄存器 - // 保存 ra 和 s0, 调整栈指针 - // s0 指向当前帧的底部(分配局部变量/溢出空间后的 sp) - // 确保栈大小 16 字节对齐 - int aligned_stack_size = (stack_size + 15) & ~15; - - // 只有当需要栈空间时才生成序言 - if (aligned_stack_size > 0) { - ss << " addi sp, sp, -" << aligned_stack_size << "\n"; // 调整栈指针 - // RV64 修改: 使用 sd (store doubleword) 保存 8 字节的 ra 和 s0 - // 同时更新偏移量,为每个寄存器保留8字节 - ss << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; // 保存返回地址 (8字节) - ss << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; // 保存帧指针 (8字节) - ss << " mv s0, sp\n"; // 设置新的帧指针 - } - - // 将传入的寄存器参数 (a0-a7 / f10-f17) 保存到对应的栈槽 (AllocaInst)。 - // RV64中,a0-a7是64位寄存器,但我们传入的int/float是32位。 - // 使用 sw/fsw 会正确地存储低32位,这是正确的行为。 - int arg_idx = 0; - BasicBlock* entry_bb = func->getEntryBlock(); // 获取函数的入口基本块 - - if (entry_bb) { // 确保入口基本块存在 - for (AllocaInst* alloca_for_param : entry_bb->getArguments()) { - if (arg_idx >= 8) { - std::cerr << "警告: 函数 '" << func->getName() << "' 的参数 (索引 " << arg_idx << ") 数量超过了 RISC-V 寄存器传递限制 (8个参数)。\n" - << " 这些参数目前未通过栈正确处理,可能导致错误。\n"; - break; - } - - if (alloc_result.stack_map.count(alloca_for_param)) { - int offset = alloc_result.stack_map.at(alloca_for_param); - Type* allocated_type = alloca_for_param->getType()->as()->getBaseType(); - - if (allocated_type->isInt()) { - PhysicalReg arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); - std::string arg_reg_str = reg_to_string(arg_reg); - // 使用 sw 保存 int (32位) 参数,这是正确的 - ss << " sw " << arg_reg_str << ", " << offset << "(s0)\n"; - } else if (allocated_type->isFloat()) { - PhysicalReg farg_reg = static_cast(static_cast(PhysicalReg::F10) + arg_idx); - std::string farg_reg_str = reg_to_string(farg_reg); - // 使用 fsw 保存 float (32位) 参数,这是正确的 - ss << " fsw " << farg_reg_str << ", " << offset << "(s0)\n"; - } else { - throw std::runtime_error("Unsupported function argument type encountered during parameter saving to stack."); - } - } else { - std::cerr << "警告: 函数参数对应的 AllocaInst '" - << (alloca_for_param->getName().empty() ? "anonymous" : alloca_for_param->getName()) - << "' 没有在栈映射中找到。这可能导致后续代码生成错误。\n"; - } - arg_idx++; - } - } else { - std::cerr << "错误: 函数 '" << func->getName() << "' 没有入口基本块。\n"; - } - - // 生成每个基本块的代码 - int block_idx = 0; - for (const auto& bb : func->getBasicBlocks()) { - ss << basicBlock_gen(bb.get(), alloc_result, block_idx++); - } - - // 函数尾声 (Epilogue) 由 RETURN DAGNode 的指令选择处理 - return ss.str(); -} - - -// 基本块代码生成 -std::string RISCv64CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx) { - std::stringstream ss; - - std::string bb_name = bb->getName(); - if (bb_name.empty()) { - bb_name = ENTRY_BLOCK_PSEUDO_NAME + std::to_string(block_idx); - if (block_idx == 0) { - bb_name = "entry"; - } - } - else { - ss << bb_name << ":\n"; // 基本块标签 - } - if (DEBUG) std::cerr << "=== 生成基本块: " << bb_name << " ===\n"; - - // 构建当前基本块的 DAG - auto dag_nodes_for_bb = build_dag(bb); - if (DEBUG) - print_dag(dag_nodes_for_bb, bb_name); // 打印 DAG 调试信息 - - // 存储最终生成的指令 - std::set emitted_nodes; // 跟踪已发射的节点,防止重复 - std::vector ordered_insts; // 用于收集指令并按序排列 - - // 在 DAG 中遍历并生成指令。由于 select_instructions 可能会递归地为操作数选择指令, - // 并且 emit_instructions 也会递归地发射,我们需要一个机制来确保指令的正确顺序和唯一性。 - // 最简单的方法是逆拓扑序遍历所有节点,确保其操作数先被处理。 - // 但是目前的 DAG 构建方式可能不支持直接的拓扑排序, - // 我们将依赖 emit_instructions 的递归特性来处理依赖。 - - // 遍历 DAG 的根节点(没有用户的节点,或者 Store/Return/Branch 节点) - // 从这些节点开始递归发射指令。 - // NOTE: 这种发射方式可能不总是产生最优的代码顺序,但可以确保依赖关系。 - for (auto it = dag_nodes_for_bb.rbegin(); it != dag_nodes_for_bb.rend(); ++it) { - DAGNode* node = it->get(); - // 只有那些没有用户(或者代表副作用,如STORE, RETURN, BRANCH)的节点才需要作为发射的“根” - // 否则,它们会被其用户节点递归地发射 - // 然而,为了确保所有指令都被发射,我们通常从所有节点(或者至少是副作用节点)开始发射 - // 并且利用 emitted_nodes 集合防止重复 - // 这里简化为对所有 DAG 节点进行一次 select_instructions 和 emit_instructions 调用。 - // emit_instructions 会通过递归处理其操作数来保证依赖顺序。 - select_instructions(node, alloc); // 为当前节点选择指令 - } - - // 收集所有指令到一个临时的 vector 中,然后进行排序 - // 注意:这里的发射逻辑需要重新设计,目前的 emit_instructions 是直接添加到 std::vector& insts 中 - // 并且期望是按顺序添加的,这在递归时难以保证。 - // 更好的方法是让 emit_instructions 直接输出到 stringstream,并控制递归顺序。 - // 但是为了最小化改动,我们先保持 emit_instructions 的现有签名, - // 然后在它内部处理指令的收集和去重。 - - // 重新设计 emit_instructions 的调用方式 - // 这里的思路是,每个 DAGNode 都存储了自己及其依赖(如果未被其他节点引用)的指令。 - // 最终,我们遍历 BasicBlock 中的所有原始 IR 指令,找到它们对应的 DAGNode,然后发射。 - // 这是因为 IR 指令的顺序决定了代码的逻辑顺序。 - - // 遍历 IR 指令,并找到对应的 DAGNode 进行发射 - // 由于 build_dag 是从 IR 指令顺序构建的,我们应该按照 IR 指令的顺序来发射。 - emitted_nodes.clear(); // 再次清空已发射节点集合 - // 临时存储每个 IR 指令对应的 DAGNode,因为 DAGNode 列表是平铺的 - std::map inst_to_dag_node; - for (const auto& dag_node_ptr : dag_nodes_for_bb) { - if (dag_node_ptr->value && dynamic_cast(dag_node_ptr->value)) { - inst_to_dag_node[dynamic_cast(dag_node_ptr->value)] = dag_node_ptr.get(); - } - } - - for (const auto& inst_ptr : bb->getInstructions()) { - DAGNode* node_to_emit = nullptr; - // 查找当前 IR 指令在 DAG 中对应的节点。 - // 注意:不是所有 IR 指令都会直接映射到一个“根”DAGNode (例如,某些值可能只作为操作数存在) - // 但终结符(如 Branch, Return)和 Store 指令总是重要的。 - // 对于 load/binary 等,我们应该在 build_dag 中确保它们有一个结果 vreg,并被后续指令使用。 - // 如果一个 IR 指令是某个 DAGNode 的 value,那么我们就发射那个 DAGNode。 - if (inst_to_dag_node.count(inst_ptr.get())) { - node_to_emit = inst_to_dag_node.at(inst_ptr.get()); - } - - if (node_to_emit) { - // 注意:select_instructions 已经在上面统一调用过,这里只需要 emit。 - // 但如果 select_instructions 没有递归地为所有依赖选择指令,这里可能需要重新考虑。 - // 为了简化,我们假定 select_instructions 在第一次被调用时(通常在 emit 之前)已经递归地为所有操作数选择了指令。 - - // 直接将指令添加到 ss 中,而不是通过 vector 中转 - emit_instructions(node_to_emit, ss, alloc, emitted_nodes); - } - } - - return ss.str(); -} - -// 辅助函数,用于创建 DAGNode 并管理其所有权 -sysy::RISCv64CodeGen::DAGNode* sysy::RISCv64CodeGen::create_node( - DAGNode::NodeKind kind, - Value* val, - std::map& value_to_node, // 需要外部传入 - std::vector>& nodes_storage // 需要外部传入 -) { - // 优化:如果一个值已经有节点并且它不是控制流/存储/Alloca地址/一元操作,则重用它 (CSE) - // 对于 AllocaInst,我们想创建一个代表其地址的节点,但不一定直接为 AllocaInst 本身分配虚拟寄存器。 - if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::ALLOCA_ADDR && kind != DAGNode::UNARY) { - return value_to_node[val]; - } - - auto node = std::make_unique(kind); - node->value = val; - - // 为产生结果的值分配虚拟寄存器 - if (val && value_vreg_map.count(val) && !dynamic_cast(val)) { // 排除 AllocaInst - node->result_vreg = value_vreg_map.at(val); - } - DAGNode* raw_node_ptr = node.get(); - nodes_storage.push_back(std::move(node)); // 存储 unique_ptr + // === 新的、完整的流水线 === - // 仅当 IR Value 表示一个计算值时,才将其映射到创建的 DAGNode - // 且它应该已经在 register_allocation 中被分配了 vreg - if (val && value_vreg_map.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && !dynamic_cast(val)) { - value_to_node[val] = raw_node_ptr; - } - return raw_node_ptr; -} + // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) + RISCv64ISel isel; + std::unique_ptr mfunc = isel.runOnFunction(func); + // 阶段 2: 寄存器分配前优化 (未来扩展点) + // 例如: + // auto pre_ra_scheduler = std::make_unique(); + // pre_ra_scheduler->runOnMachineFunction(mfunc.get()); -// 辅助函数:获取值的 DAG 节点。 -// 如果 value 已经映射到 DAG 节点,则直接返回。 -// 如果是常量,则创建 CONSTANT 节点。 -// 如果是 AllocaInst,则创建 ALLOCA_ADDR 节点。 -// 否则,假定需要通过 LOAD 获取该值。 -sysy::RISCv64CodeGen::DAGNode* sysy::RISCv64CodeGen::get_operand_node( - Value* val_ir, - std::map& value_to_node, // 接受 value_to_node - std::vector>& nodes_storage // 接受 nodes_storage -) { - if (value_to_node.count(val_ir)) { - return value_to_node[val_ir]; - } else if (auto constant = dynamic_cast(val_ir)) { - return create_node(DAGNode::CONSTANT, constant, value_to_node, nodes_storage); // 调用成员函数版 create_node - } else if (auto alloca = dynamic_cast(val_ir)) { - return create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); // 调用成员函数版 create_node - } else if (auto global = dynamic_cast(val_ir)) { - // 确保 GlobalValue 也能正确处理,如果 DAGNode::CONSTANT 无法存储 GlobalValue*, - // 则需要新的 DAGNode 类型,例如 DAGNode::GLOBAL_ADDR - return create_node(DAGNode::CONSTANT, global, value_to_node, nodes_storage); // 调用成员函数版 create_node - } - // 这是一个尚未在此块中计算的值,假设它需要加载 (从内存或参数) - return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); // 调用成员函数版 create_node -} + // 阶段 3: 物理寄存器分配 (virtual regs -> physical regs + spill code) + RISCv64RegAlloc reg_alloc(mfunc.get()); + reg_alloc.run(); -std::vector> RISCv64CodeGen::build_dag(BasicBlock* bb) { - std::vector> nodes_storage; // 存储所有 unique_ptr - std::map value_to_node; // 将 IR Value* 映射到原始 DAGNode*,用于快速查找 + // 阶段 4: 寄存器分配后优化 (未来扩展点) + // 例如: + // auto post_ra_peephole = std::make_unique(); + // post_ra_peephole->runOnMachineFunction(mfunc.get()); - for (const auto& inst_ptr : bb->getInstructions()) { - auto inst = inst_ptr.get(); - - if (auto alloca = dynamic_cast(inst)) { - // AllocaInst 本身不产生寄存器中的值,但其地址将被 load/store 使用。 - // 创建一个节点来表示分配内存的地址。 - // 这个地址将是 s0 (帧指针) 的偏移量。 - // 我们将 AllocaInst 指针存储在 DAGNode 的 `value` 字段中。 - // 修正:AllocaInst 类型的 DAGNode 应该有一个 value 对应 AllocaInst* - // 但它本身不应该有 result_vreg,因为不映射到物理寄存器。 - create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); - } else if (auto store = dynamic_cast(inst)) { - auto store_node = create_node(DAGNode::STORE, store, value_to_node, nodes_storage); - - // 获取要存储的值 - DAGNode* val_node = get_operand_node(store->getValue(), value_to_node, nodes_storage); - - // 获取内存位置的指针 (基地址) - Value* ptr_ir = store->getPointer(); - DAGNode* ptr_node = get_operand_node(ptr_ir, value_to_node, nodes_storage); - - store_node->operands.push_back(val_node); - store_node->operands.push_back(ptr_node); - ptr_node->users.push_back(store_node); - } else if (auto memset = dynamic_cast(inst)) { - auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage); - - // 根据 IR.h 中的定义,获取 MemsetInst 的操作数 - DAGNode* pointer_node = get_operand_node(memset->getPointer(), value_to_node, nodes_storage); - DAGNode* begin_node = get_operand_node(memset->getBegin(), value_to_node, nodes_storage); - DAGNode* size_node = get_operand_node(memset->getSize(), value_to_node, nodes_storage); - DAGNode* value_node = get_operand_node(memset->getValue(), value_to_node, nodes_storage); - - // 将操作数节点添加到 MEMSET 节点的依赖列表中 - memset_node->operands.push_back(pointer_node); - memset_node->operands.push_back(begin_node); - memset_node->operands.push_back(size_node); - memset_node->operands.push_back(value_node); - - // 建立反向链接 - pointer_node->users.push_back(memset_node); - begin_node->users.push_back(memset_node); - size_node->users.push_back(memset_node); - value_node->users.push_back(memset_node); - } else if (auto load = dynamic_cast(inst)) { - auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); - - // 获取内存位置的指针 (基地址) - Value* ptr_ir = load->getPointer(); - DAGNode* ptr_node = get_operand_node(ptr_ir, value_to_node, nodes_storage); - load_node->operands.push_back(ptr_node); - ptr_node->users.push_back(load_node); - } else if (auto bin = dynamic_cast(inst)) { - if (value_to_node.count(bin)) continue; // CSE - - if (bin->getKind() == BinaryInst::kSub || bin->getKind() == BinaryInst::kFSub) { - Value* lhs_ir = bin->getLhs(); - if (auto const_lhs = dynamic_cast(lhs_ir)) { - bool is_neg = false; - if (const_lhs->getType()->isInt()) { - if (const_lhs->getInt() == 0) { - is_neg = true; - } - } else if (const_lhs->getType()->isFloat()) { - if (std::fabs(const_lhs->getFloat()) < std::numeric_limits::epsilon()) { - is_neg = true; - } - } - - if (is_neg) { - auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage); // 传递参数 - Value* operand_ir = bin->getRhs(); - DAGNode* operand_node = get_operand_node(operand_ir, value_to_node, nodes_storage); // 传递参数 - unary_node->operands.push_back(operand_node); - operand_node->users.push_back(unary_node); - continue; - } - } - } - // 常规二进制操作 - auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); // 传递参数 - - DAGNode* lhs_node = get_operand_node(bin->getLhs(), value_to_node, nodes_storage); // 传递参数 - DAGNode* rhs_node = get_operand_node(bin->getRhs(), value_to_node, nodes_storage); // 传递参数 - - bin_node->operands.push_back(lhs_node); - bin_node->operands.push_back(rhs_node); - lhs_node->users.push_back(bin_node); - rhs_node->users.push_back(bin_node); - - } else if (auto un_inst = dynamic_cast(inst)) { - if (value_to_node.count(un_inst)) continue; - - auto unary_node = create_node(DAGNode::UNARY, un_inst, value_to_node, nodes_storage); // 传递参数 - - Value* operand_ir = un_inst->getOperand(); - DAGNode* operand_node = get_operand_node(operand_ir, value_to_node, nodes_storage); // 传递参数 - - unary_node->operands.push_back(operand_node); - operand_node->users.push_back(unary_node); - - } else if (auto call = dynamic_cast(inst)) { - if (value_to_node.count(call)) continue; - auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); // 传递参数 - for (auto arg : call->getArguments()) { - auto arg_val_ir = arg->getValue(); - DAGNode* arg_node = get_operand_node(arg_val_ir, value_to_node, nodes_storage); // 传递参数 - call_node->operands.push_back(arg_node); - arg_node->users.push_back(call_node); - } - } else if (auto ret = dynamic_cast(inst)) { - if (DEBUG) std::cerr << "处理 RETURN 指令: " << ret->getName() << "\n"; // 调试输出 - auto ret_node = create_node(DAGNode::RETURN, ret, value_to_node, nodes_storage); // 传递参数 - if (ret->hasReturnValue()) { - auto val_ir = ret->getReturnValue(); - DAGNode* val_node = get_operand_node(val_ir, value_to_node, nodes_storage); // 传递参数 - ret_node->operands.push_back(val_node); - val_node->users.push_back(ret_node); - } - } else if (auto cond_br = dynamic_cast(inst)) { - auto br_node = create_node(DAGNode::BRANCH, cond_br, value_to_node, nodes_storage); // 传递参数 - auto cond_ir = cond_br->getCondition(); - - if (auto constant_cond = dynamic_cast(cond_ir)) { - br_node->inst = "j " + (constant_cond->getInt() ? cond_br->getThenBlock()->getName() : cond_br->getElseBlock()->getName()); - } else { - DAGNode* cond_node = get_operand_node(cond_ir, value_to_node, nodes_storage); // 传递参数 - br_node->operands.push_back(cond_node); - cond_node->users.push_back(br_node); - } - } else if (auto uncond_br = dynamic_cast(inst)) { - auto br_node = create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); // 传递参数 - br_node->inst = "j " + uncond_br->getBlock()->getName(); - } else { - // 其他指令类型(如 PHI, 可能的自定义指令等) - // 目前假设未处理的指令类型不需要特殊 DAGNode 类型 - // 可以在这里添加更多的处理逻辑 - if (DEBUG) std::cerr << "未处理的指令类型: " << inst->getKindString() << "\n"; - continue; // 跳过未处理的指令 - } - } - return nodes_storage; -} - -// 打印 DAG -void RISCv64CodeGen::print_dag(const std::vector>& dag, const std::string& bb_name) { - std::cerr << "=== DAG for Basic Block: " << bb_name << " ===\n"; - std::set visited; - - // 辅助映射,用于在打印输出中为节点分配顺序 ID - std::map node_to_id; - int current_id = 0; - for (const auto& node_ptr : dag) { - node_to_id[node_ptr.get()] = current_id++; - } - - std::function print_node = [&](DAGNode* node, int indent) { - if (!node) return; - - std::string current_indent(indent, ' '); - int node_id = node_to_id.count(node) ? node_to_id[node] : -1; // 获取分配的 ID - - std::cerr << current_indent << "Node#" << node_id << ": " << node->getNodeKindString(); - if (!node->result_vreg.empty()) { - std::cerr << " (vreg: " << node->result_vreg << ")"; - } - - if (node->value) { - std::cerr << " ["; - if (auto inst = dynamic_cast(node->value)) { - std::cerr << inst->getKindString(); - if (!inst->getName().empty()) { - std::cerr << "(" << inst->getName() << ")"; - } - } else if (auto constant = dynamic_cast(node->value)) { - if (constant->isInt()) { - std::cerr << "ConstInt(" << constant->getInt() << ")"; - } else { - std::cerr << "ConstFloat(" << constant->getFloat() << ")"; - } - } else if (auto global = dynamic_cast(node->value)) { - std::cerr << "Global(" << global->getName() << ")"; - } else if (auto alloca = dynamic_cast(node->value)) { - std::cerr << "Alloca(" << (alloca->getName().empty() ? ("%" + std::to_string(reinterpret_cast(alloca) % 1000)) : alloca->getName()) << ")"; - } - std::cerr << "]"; - } - std::cerr << " -> Inst: \"" << node->inst << "\""; // 打印选定的指令 - std::cerr << "\n"; - - if (visited.find(node) != visited.end()) { - std::cerr << current_indent << " (已打印后代)\n"; - return; // 避免循环的无限递归 - } - visited.insert(node); - - if (!node->operands.empty()) { - std::cerr << current_indent << " 操作数:\n"; - for (auto operand : node->operands) { - print_node(operand, indent + 4); - } - } - // 移除了 users 打印,以简化输出并避免 DAG 中的冗余递归。 - // Users 更适用于向上遍历,而不是向下遍历。 - }; - - // 遍历 DAG,以尊重依赖的方式打印。 - // 当前实现:遍历所有节点,从作为“根”的节点开始打印(没有用户或副作用节点)。 - // 每次打印新的根时,重置 visited 集合,以允许共享子图被重新打印(尽管这不是最高效的方式)。 - for (const auto& node_ptr : dag) { - // 只有那些没有用户或者表示副作用(如 store/branch/return)的节点才被视为“根” - // 这样可以确保所有指令(包括那些没有明确结果的)都被打印 - if (node_ptr->users.empty() || node_ptr->kind == DAGNode::STORE || node_ptr->kind == DAGNode::RETURN || node_ptr->kind == DAGNode::BRANCH) { - visited.clear(); // 为每个根重置 visited,允许重新打印共享子图 - print_node(node_ptr.get(), 0); - } - } - std::cerr << "=== DAG 结束 ===\n\n"; -} - -// 指令选择 -void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) { - if (!node) return; - if (!node->inst.empty()) return; // 指令已选择,跳过重复处理 - - // 递归地为操作数选择指令,确保依赖先被处理 - for (auto operand : node->operands) { - if (operand) { - select_instructions(operand, alloc); - } - } - - std::stringstream ss_inst; // 使用 stringstream 构建指令 - - // 获取分配的物理寄存器,若未分配则回退到 t0 - auto get_preg_or_temp = [&](const std::string& vreg) { - if (vreg.empty()) { // 添加对空 vreg 的明确检查 - if (DEBUG) std::cerr << "警告: 虚拟寄存器 (空字符串) 没有分配物理寄存器,使用临时寄存器 t0 代替。\n"; - return reg_to_string(PhysicalReg::T0); - } - if (alloc.vreg_to_preg.count(vreg)) { - return reg_to_string(alloc.vreg_to_preg.at(vreg)); - } - if (DEBUG) std::cerr << "警告: 虚拟寄存器 " << vreg << " 没有分配物理寄存器,使用临时寄存器 t0 代替。\n"; - return reg_to_string(PhysicalReg::T0); // 回退到临时寄存器 t0 - }; - - // 获取栈变量的内存偏移量 - auto get_stack_offset = [&](Value* val) -> std::string { // 返回类型明确为 std::string - if (alloc.stack_map.count(val)) { - if (DEBUG) { // 避免在非DEBUG模式下打印大量内容 - std::cout << "获取栈变量的内存偏移量,变量名: " << (val ? val->getName() : "unknown") << std::endl; - } - return std::to_string(alloc.stack_map.at(val)); - } - if (DEBUG) std::cerr << "警告: 栈变量 " << (val ? val->getName() : "unknown") << " 没有在栈映射中找到,使用默认偏移 0。\n"; - // 如果没有找到映射,返回默认偏移量 "0" - return std::string("0"); // 默认或错误情况 - }; - - switch (node->kind) { - case DAGNode::CONSTANT: { - // [V2 特性] CONSTANT 节点自身不再生成指令。 - // 它的存在是为了在DAG中标记一个常数值或全局地址。 - // 加载这个值的责任转移给了使用它的节点 (STORE, BINARY, CALL, RETURN)。 - break; - } - case DAGNode::ALLOCA_ADDR: { - // ALLOCA_ADDR 节点不直接生成指令,它仅作为一个地址标记,由 LOAD/STORE 或地址计算使用 - break; - } - case DAGNode::LOAD: { - // 处理加载指令 - if (node->operands.empty() || !node->operands[0]) break; - std::string dest_reg = get_preg_or_temp(node->result_vreg); - DAGNode* ptr_node = node->operands[0]; - - if (ptr_node->kind == DAGNode::ALLOCA_ADDR) { - if (auto alloca_inst = dynamic_cast(ptr_node->value)) { - int offset = alloc.stack_map.at(alloca_inst); - ss_inst << "lw " << dest_reg << ", " << offset << "(s0)"; - } - } else { - std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); - - // [与STORE逻辑类似的修复] 如果是从一个全局地址加载,先用la加载地址 - if (ptr_node->kind == DAGNode::CONSTANT) { - if (auto global = dynamic_cast(ptr_node->value)) { - ss_inst << "la " << ptr_reg << ", " << global->getName() << "\n\t"; - } - } - - ss_inst << "lw " << dest_reg << ", 0(" << ptr_reg << ")"; - } - break; - } - case DAGNode::STORE: { - // 处理存储指令 - if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; - DAGNode* val_node = node->operands[0]; - DAGNode* ptr_node = node->operands[1]; - - std::string src_reg = get_preg_or_temp(val_node->result_vreg); - - // [V2 特性] 如果要存储的值是一个常数,在此处立即加载它 - if (val_node->kind == DAGNode::CONSTANT) { - if (auto constant = dynamic_cast(val_node->value)) { - if (constant->isInt()) { - ss_inst << "li " << src_reg << ", " << constant->getInt() << "\n\t"; - } else { // 浮点数常量 - float f = constant->getFloat(); - uint32_t float_bits = *(uint32_t*)&f; - ss_inst << "li " << src_reg << ", " << float_bits << "\n\t"; - // 注意:这里假设 src_reg 是一个通用寄存器,需要额外指令移动到浮点寄存器 - // 如果 STORE 的目标是浮点数,这里逻辑需要调整 - } - } else if (auto global = dynamic_cast(val_node->value)) { - // 如果要存储的值是另一个全局变量的地址 - ss_inst << "la " << src_reg << ", " << global->getName() << "\n\t"; - } - } - - // [本次修复] 处理目标地址,并生成最终的存储指令 - if (ptr_node->kind == DAGNode::ALLOCA_ADDR) { - // 情况1: 存储到栈上的局部变量 - if (auto alloca_inst = dynamic_cast(ptr_node->value)) { - int offset = alloc.stack_map.at(alloca_inst); - ss_inst << "sw " << src_reg << ", " << offset << "(s0)"; - } - } else { - // 情况2: 存储到指针指向的地址 (可能是全局变量或已计算的地址) - std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); - - // 如果指针本身是一个全局变量的地址,需要先用 la 加载它 - if (ptr_node->kind == DAGNode::CONSTANT) { - if (auto global = dynamic_cast(ptr_node->value)) { - ss_inst << "la " << ptr_reg << ", " << global->getName() << "\n\t"; - } - } - - // 生成最终的存储指令 - ss_inst << "sw " << src_reg << ", 0(" << ptr_reg << ")"; - } - break; - } - case DAGNode::BINARY: { - if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; - auto bin = dynamic_cast(node->value); - if (!bin) break; - - std::string dest_reg = get_preg_or_temp(node->result_vreg); - DAGNode* lhs_node = node->operands[0]; - DAGNode* rhs_node = node->operands[1]; - - // [V1 特性] 检查是否是 base + offset 的地址计算 - if (bin->getKind() == BinaryInst::kAdd) { - DAGNode* base_node = nullptr; - DAGNode* offset_node = nullptr; - - if (lhs_node->kind == DAGNode::ALLOCA_ADDR) { base_node = lhs_node; offset_node = rhs_node; } - else if (rhs_node->kind == DAGNode::ALLOCA_ADDR) { base_node = rhs_node; offset_node = lhs_node; } - - if (base_node) { - if (auto alloca_inst = dynamic_cast(base_node->value)) { - std::string stack_offset_str = get_stack_offset(alloca_inst); - std::string index_offset_reg = get_preg_or_temp(offset_node->result_vreg); - - // 生成两条指令来计算最终地址: - // 1. addi 将 s0 加上 b 的栈偏移量得到 b 的实际基地址 - ss_inst << "addi " << dest_reg << ", s0, " << stack_offset_str << "\n"; - // 2. [V3 修复] add 将基地址和索引偏移量相加,得到最终地址。地址计算必须用64位指令。 - ss_inst << "\tadd " << dest_reg << ", " << dest_reg << ", " << index_offset_reg; - node->inst = ss_inst.str(); - return; // 指令已生成,提前返回 - } - } - } - - std::string lhs_reg = get_preg_or_temp(lhs_node->result_vreg); - std::string rhs_reg = get_preg_or_temp(rhs_node->result_vreg); - - // [V2 特性] & [上次修复] 在生成二元运算指令前,检查操作数是否为常数(立即数或全局地址),并按需加载 - if (lhs_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(lhs_node->value)) { - ss_inst << "li " << lhs_reg << ", " << c->getInt() << "\n\t"; - } else if (auto g = dynamic_cast(lhs_node->value)) { - ss_inst << "la " << lhs_reg << ", " << g->getName() << "\n\t"; - } - } - if (rhs_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(rhs_node->value)) { - // 优化:对于加法和小立即数,直接使用 addi 指令 - if (bin->getKind() == BinaryInst::kAdd && c->getInt() >= -2048 && c->getInt() < 2048) { - ss_inst << "addi " << dest_reg << ", " << lhs_reg << ", " << c->getInt(); - node->inst = ss_inst.str(); - return; // 指令已生成,提前返回 - } - // 否则,正常加载 - ss_inst << "li " << rhs_reg << ", " << c->getInt() << "\n\t"; - } else if (auto g = dynamic_cast(rhs_node->value)) { - ss_inst << "la " << rhs_reg << ", " << g->getName() << "\n\t"; - } - } - - std::string opcode; - switch (bin->getKind()) { - // [V1 特性] & [V3 修复] RV64 修改: 整数运算使用带 'w' 后缀的32位指令,地址运算使用64位指令 - case BinaryInst::kAdd: - // 通过检查操作数类型来决定使用 64位(add) 还是 32位(addw) 加法 - if (bin->getLhs()->getType()->isPointer() || bin->getRhs()->getType()->isPointer()) { - opcode = "add"; // 指针/地址运算 - } else { - opcode = "addw"; // 普通整数运算 - } - break; - case BinaryInst::kSub: opcode = "subw"; break; - case BinaryInst::kMul: opcode = "mulw"; break; - case Instruction::kDiv: opcode = "divw"; break; - case Instruction::kRem: opcode = "remw"; break; - case BinaryInst::kICmpEQ: - ss_inst << "subw " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\tseqz " << dest_reg << ", " << dest_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpGE: - ss_inst << "slt " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\txori " << dest_reg << ", " << dest_reg << ", 1"; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpGT: - ss_inst << "slt " << dest_reg << ", " << rhs_reg << ", " << lhs_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpLE: - ss_inst << "slt " << dest_reg << ", " << rhs_reg << ", " << lhs_reg << "\n"; - ss_inst << "\txori " << dest_reg << ", " << dest_reg << ", 1"; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpLT: - ss_inst << "slt " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpNE: - ss_inst << "subw " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\tsnez " << dest_reg << ", " << dest_reg; - node->inst = ss_inst.str(); - return; - default: - // [V1 特性] 保留对未实现指令的定义和报错 - throw std::runtime_error("不支持的二元指令类型: " + bin->getKindString()); - } - if (!opcode.empty()) { - ss_inst << opcode << " " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; - } - break; - } - case DAGNode::UNARY: { - if (node->operands.empty() || !node->operands[0]) break; - auto unary = dynamic_cast(node->value); - if (!unary) break; - - std::string dest_reg = get_preg_or_temp(node->result_vreg); - std::string src_reg = get_preg_or_temp(node->operands[0]->result_vreg); - - switch (unary->getKind()) { - case UnaryInst::kNeg: - // RV64 修改: 使用 subw 实现32位取负 (negw 伪指令) - ss_inst << "subw " << dest_reg << ", x0, " << src_reg; - break; - case UnaryInst::kNot: - // 整数逻辑非:seqz rd, rs (rs == 0 时 rd = 1,否则 rd = 0) - ss_inst << "seqz " << dest_reg << ", " << src_reg; - break; - // [V1 特性] 保留对未实现指令的定义和报错 - case UnaryInst::kFNeg: - case UnaryInst::kFNot: - case UnaryInst::kFtoI: - case UnaryInst::kItoF: - case UnaryInst::kBitFtoI: - case UnaryInst::kBitItoF: - throw std::runtime_error("不支持的浮点一元指令类型: " + unary->getKindString()); - default: - throw std::runtime_error("不支持的一元指令类型: " + unary->getKindString()); - } - break; - } - case DAGNode::CALL: { - // 处理函数调用指令 - if (!node->value) break; - auto call = dynamic_cast(node->value); - if (!call) break; - - // [V2/V3 特性] 修正参数处理逻辑 - for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { - DAGNode* arg_node = node->operands[i]; - if (!arg_node) continue; - - std::string arg_preg = reg_to_string(static_cast(static_cast(PhysicalReg::A0) + i)); - - // 优先检查参数节点是否为常量,并直接加载 - if (arg_node->kind == DAGNode::CONSTANT) { - if (auto const_val = dynamic_cast(arg_node->value)) { - ss_inst << "li " << arg_preg << ", " << const_val->getInt() << "\n\t"; - } else if (auto global_val = dynamic_cast(arg_node->value)) { - ss_inst << "la " << arg_preg << ", " << global_val->getName() << "\n\t"; - } - } - // 如果不是常量,说明是一个计算出的值,使用 mv 指令移动 - else { - std::string src_reg = get_preg_or_temp(arg_node->result_vreg); - ss_inst << "mv " << arg_preg << ", " << src_reg << "\n\t"; - } - } - ss_inst << "call " << call->getCallee()->getName(); - - // 处理返回值 - if ((call->getType()->isInt() || call->getType()->isFloat()) && !node->result_vreg.empty()) { - ss_inst << "\nmv " << get_preg_or_temp(node->result_vreg) << ", a0"; - } - break; - } - case DAGNode::RETURN: { - // 处理返回指令 - if (!node->operands.empty() && node->operands[0]) { - DAGNode* ret_val_node = node->operands[0]; - - // [V2 特性] & [上次修复] 如果返回值是常量(立即数或全局地址),直接加载到 a0 - if (ret_val_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(ret_val_node->value)) { - ss_inst << "li a0, " << c->getInt() << "\n"; - } else if (auto g = dynamic_cast(ret_val_node->value)) { - ss_inst << "la a0, " << g->getName() << "\n"; - } - } else { - std::string return_val_reg = get_preg_or_temp(ret_val_node->result_vreg); - ss_inst << "mv a0, " << return_val_reg << "\n"; - } - } - - // 恢复栈帧 - if (alloc.stack_size > 0) { - int aligned_stack_size = (alloc.stack_size + 15) & ~15; - // RV64 修改: 使用 ld (load doubleword) 恢复 8 字节的 ra 和 s0 - ss_inst << "\tld ra, " << (aligned_stack_size - 8) << "(sp)\n"; - ss_inst << "\tld s0, " << (aligned_stack_size - 16) << "(sp)\n"; - ss_inst << "\taddi sp, sp, " << aligned_stack_size << "\n"; - } - ss_inst << "\tret"; - break; - } - case DAGNode::BRANCH: { - // 处理分支指令 - auto br = dynamic_cast(node->value); - auto uncond_br = dynamic_cast(node->value); - - if (br) { // 条件分支 - if (node->operands.empty() || !node->operands[0]) break; - std::string cond_reg = get_preg_or_temp(node->operands[0]->result_vreg); - std::string then_block = br->getThenBlock()->getName(); - std::string else_block = br->getElseBlock()->getName(); - - // [V1 特性] 如果基本块没有名字(例如,匿名块),给它一个伪名称 - if (then_block.empty()) { - then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string(this->local_label_counter++); - } - if (else_block.empty()) { - else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string(this->local_label_counter++); - } - - ss_inst << "bnez " << cond_reg << ", " << then_block << "\n"; - ss_inst << "\tj " << else_block; - } else if (uncond_br) { // 无条件分支 - std::string target_block = uncond_br->getBlock()->getName(); - if (target_block.empty()) { - target_block = ENTRY_BLOCK_PSEUDO_NAME + "_target_" + std::to_string((uintptr_t)uncond_br); - } - ss_inst << "j " << target_block; - } - break; - } - case DAGNode::MEMSET: { - if (node->operands.size() < 4) break; - - // 1. 获取操作数被分配到的物理寄存器 - // 您的 IR 中 pointer 和 begin 可能是同一个值,这里我们假设 pointer 是基地址 - DAGNode* ptr_node = node->operands[0]; - DAGNode* size_node = node->operands[2]; - DAGNode* value_node = node->operands[3]; - - std::string R_DEST_ADDR = get_preg_or_temp(ptr_node->result_vreg); - std::string R_NUM_BYTES = get_preg_or_temp(size_node->result_vreg); - std::string R_VALUE_BYTE = get_preg_or_temp(value_node->result_vreg); - - // 2. 定义我们将要使用的临时寄存器 - // 由于我们在冲突图中添加了规则,可以安全地使用这些调用者保存寄存器 - std::string R_COUNTER = "t3"; // 循环计数器 (字节) - std::string R_END_ADDR = "t4"; // 结束地址 - std::string R_CURRENT_ADDR = "t5"; // 当前写入地址 - std::string R_TEMP_VAL = "t6"; // 64位填充值 - - // 使用 local_label_counter 为 memset 循环生成唯一的、整洁的标签 - int unique_id = this->local_label_counter++; - std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); - std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); - std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); - std::string done_label = "memset_done_" + std::to_string(unique_id); - - // 3. 生成汇编代码 - ss_inst << "# --- Memset Start ---\n"; - ss_inst << " andi " << R_TEMP_VAL << ", " << R_VALUE_BYTE << ", 255\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 8\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 16\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 32\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " add " << R_END_ADDR << ", " << R_DEST_ADDR << ", " << R_NUM_BYTES << "\n"; - ss_inst << " mv " << R_CURRENT_ADDR << ", " << R_DEST_ADDR << "\n"; - ss_inst << " andi " << R_COUNTER << ", " << R_NUM_BYTES << ", -8\n"; - ss_inst << " add " << R_COUNTER << ", " << R_DEST_ADDR << ", " << R_COUNTER << "\n"; - ss_inst << loop_start_label << ":\n"; - ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_COUNTER << ", " << loop_end_label << "\n"; - ss_inst << " sd " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; - ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 8\n"; - ss_inst << " j " << loop_start_label << "\n"; - ss_inst << loop_end_label << ":\n"; - ss_inst << remainder_label << ":\n"; - ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_END_ADDR << ", " << done_label << "\n"; - ss_inst << " sb " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; - ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 1\n"; - ss_inst << " j " << remainder_label << "\n"; - ss_inst << done_label << ":\n"; - ss_inst << "# --- Memset End ---"; - break; - } - default: - throw std::runtime_error("不支持的节点类型: " + node->getNodeKindString()); - } - node->inst = ss_inst.str(); // 存储生成的指令 -} - -// 指令发射 -void RISCv64CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set& emitted_nodes) { - if (!node || emitted_nodes.count(node)) { - return; // 如果节点为空或已经发射过,则返回 - } - - // 递归地发射操作数,以确保满足指令依赖 - for (auto operand : node->operands) { - if (operand) { - emit_instructions(operand, ss, alloc, emitted_nodes); - } - } - - // 标记当前节点为已发射,防止重复 - emitted_nodes.insert(node); - - // node->inst 中可能包含由 \n 分隔的多行指令和标签 - std::stringstream node_inst_ss(node->inst); - std::string line; - - while (std::getline(node_inst_ss, line, '\n')) { - // 首先,移除行首和行尾的空白字符,方便后续判断 - line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); - - if (line.empty()) { - continue; // 跳过空行 - } - - // ====================== 核心修正逻辑 ====================== - // 判断当前行是否是一个标签(即,非空且以':'结尾) - if (!line.empty() && line.back() == ':') { - // 如果是标签,直接打印,不加前导缩进 - ss << line << "\n"; - } else { - // 如果是常规指令,添加4个空格的前导缩进后再打印 - ss << " " << line << "\n"; - } - // ======================================================== - } -} - -// 辅助函数:将集合打印为字符串 -std::string print_set(const std::set& s) { + // 阶段 5: 代码发射 (LLIR with physical regs -> Assembly Text) std::stringstream ss; - ss << "{"; - bool first = true; - for (const auto& elem : s) { - if (!first) { - ss << ", "; - } - ss << elem; - first = false; - } - ss << "}"; + RISCv64AsmPrinter printer; + printer.runOnMachineFunction(mfunc.get(), ss); + return ss.str(); } -// 活跃性分析(更新以返回 live_in 和 live_out) -LivenessResult RISCv64CodeGen::liveness_analysis(Function* func) { - LivenessResult result; - bool changed = true; - - // 初始化 live_in 和 live_out - for (const auto& bb : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb->getInstructions()) { - result.live_in[inst_ptr.get()] = {}; - result.live_out[inst_ptr.get()] = {}; - } - } - - int iteration_count = 0; - while (changed) { - changed = false; - iteration_count++; - if (DEEPDEBUG) std::cerr << "\n--- 活跃性分析迭代: " << iteration_count << " ---" << std::endl; - - // 逆序遍历基本块 - for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) { - auto bb = it->get(); - if (DEEPDEBUG) std::cerr << " 基本块: " << bb->getName() << std::endl; - - // 计算基本块末尾的 live_out 集合,即所有后继基本块 live_in 的并集 - std::set live_out_for_bb; - for (const auto& succ_bb : bb->getSuccessors()) { - if (!succ_bb->getInstructions().empty()) { - Instruction* first_inst_in_succ = succ_bb->getInstructions().front().get(); - if (result.live_in.count(first_inst_in_succ)) { - const auto& succ_live_in = result.live_in.at(first_inst_in_succ); - live_out_for_bb.insert(succ_live_in.begin(), succ_live_in.end()); - } - } - } - - // 逆序遍历指令 - for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) { - auto inst = inst_it->get(); - if (DEEPDEBUG) std::cerr << " 指令 (BB: " << bb->getName() << ", 地址: " << static_cast(inst) << ")" << std::endl; - - std::set current_live_in = result.live_in[inst]; - std::set current_live_out = result.live_out[inst]; - std::set new_live_out; - - // 计算当前指令的 live_out - if (inst_it == bb->getInstructions().rbegin()) { - // 对于块中的最后一条指令,其 live_out 是块的 live_out - new_live_out = live_out_for_bb; - } else { - // 否则,其 live_out 是其后继指令的 live_in - auto next_inst_it = std::prev(inst_it); - new_live_out = result.live_in[next_inst_it->get()]; - } - - std::set use_set, def_set; - - // 定义 (Def) - if (value_vreg_map.count(inst) && !inst->getType()->isVoid() && !dynamic_cast(inst) && !dynamic_cast(inst) && - !dynamic_cast(inst) && !dynamic_cast(inst) && !dynamic_cast(inst)) { - def_set.insert(value_vreg_map.at(inst)); - } - - // 使用 (Use) - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand) && !dynamic_cast(operand)) { - use_set.insert(value_vreg_map.at(operand)); - } - } - - // 数据流方程: live_in[i] = use[i] U (live_out[i] - def[i]) - std::set new_live_in = use_set; - for (const auto& vreg : new_live_out) { - if (def_set.find(vreg) == def_set.end()) { - new_live_in.insert(vreg); - } - } - - // 如果活跃性集合发生变化,更新并继续迭代 - if (new_live_in != current_live_in || new_live_out != current_live_out) { - result.live_in[inst] = new_live_in; - result.live_out[inst] = new_live_out; - changed = true; - } - } - } - } - return result; -} - - -// 干扰图构建 (使用正确的 live_out 集合) -std::map> RISCv64CodeGen::build_interference_graph( - const LivenessResult& liveness) { - std::map> graph; - - // 初始化图,确保每个虚拟寄存器都有一个节点 - for (const auto& pair : value_vreg_map) { - graph[pair.second] = {}; - } - - // 遍历每条指令来构建干扰 - for (const auto& pair : liveness.live_out) { - Instruction* inst = pair.first; - const auto& live_out_set = pair.second; - - // 获取该指令定义的 vreg - std::string defined_vreg; - if (value_vreg_map.count(inst) && !inst->getType()->isVoid() && !dynamic_cast(inst) && !dynamic_cast(inst) && - !dynamic_cast(inst) && !dynamic_cast(inst) && !dynamic_cast(inst)) { - defined_vreg = value_vreg_map.at(inst); - } - - // --- 规则一 和 新增的规则三 --- - if (!defined_vreg.empty()) { - // 您的“规则一”:定义与出口活跃寄存器冲突 (保留) - for (const auto& live_vreg : live_out_set) { - if (defined_vreg != live_vreg) { - graph[defined_vreg].insert(live_vreg); - graph[live_vreg].insert(defined_vreg); - } - } - - // ====================== 新增的、缺失的规则三 ====================== - // 定义与在同一指令中并发使用的寄存器冲突 - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand)) { - const std::string& use_vreg = value_vreg_map.at(operand); - if (defined_vreg != use_vreg) { - graph[defined_vreg].insert(use_vreg); - graph[use_vreg].insert(defined_vreg); - } - } - } - // ================================================================= - } - - // --- 您的“规则二”:特殊指令内部的操作数之间也存在干扰 (保留) --- - if (auto store = dynamic_cast(inst)) { - Value* val_operand = store->getValue(); - Value* ptr_operand = store->getPointer(); - if (value_vreg_map.count(val_operand) && value_vreg_map.count(ptr_operand)) { - const std::string& val_vreg = value_vreg_map.at(val_operand); - const std::string& ptr_vreg = value_vreg_map.at(ptr_operand); - if (val_vreg != ptr_vreg) { - graph[val_vreg].insert(ptr_vreg); - graph[ptr_vreg].insert(val_vreg); - } - } - } else if (auto bin = dynamic_cast(inst)) { - Value* lhs_operand = bin->getLhs(); - Value* rhs_operand = bin->getRhs(); - if (value_vreg_map.count(lhs_operand) && value_vreg_map.count(rhs_operand)) { - const std::string& lhs_vreg = value_vreg_map.at(lhs_operand); - const std::string& rhs_vreg = value_vreg_map.at(rhs_operand); - if (lhs_vreg != rhs_vreg) { - graph[lhs_vreg].insert(rhs_vreg); - graph[rhs_vreg].insert(lhs_vreg); - } - } - } else if (auto call = dynamic_cast(inst)) { - std::vector arg_vregs; - for (auto arg_use : call->getArguments()) { - Value* arg_val = arg_use->getValue(); - if (value_vreg_map.count(arg_val)) { - arg_vregs.push_back(value_vreg_map.at(arg_val)); - } - } - for (size_t i = 0; i < arg_vregs.size(); ++i) { - for (size_t j = i + 1; j < arg_vregs.size(); ++j) { - graph[arg_vregs[i]].insert(arg_vregs[j]); - graph[arg_vregs[j]].insert(arg_vregs[i]); - } - } - } else if (auto memset = dynamic_cast(inst)) { - // 规则:MemsetInst 像一个函数调用,它会污染临时寄存器。 - // 因此,所有跨越这条指令的活跃变量(live_out), - // 都应该与这条指令的操作数(use)互相冲突。 - // 这会强制分配器将它们放入不同的寄存器中,或者安全地保存/恢复。 - - std::set use_set; - for (const auto& operand_use : memset->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand)) { - use_set.insert(value_vreg_map.at(operand)); - } - } - - for (const auto& live_vreg : live_out_set) { - for (const auto& use_vreg : use_set) { - if (live_vreg != use_vreg) { - graph[live_vreg].insert(use_vreg); - graph[use_vreg].insert(live_vreg); - } - } - } - } - } - return graph; -} - -// 图着色(支持浮点寄存器) -void RISCv64CodeGen::color_graph(std::map& vreg_to_preg, - const std::map>& interference_graph) { - vreg_to_preg.clear(); - - // 分离整数和浮点寄存器池 - std::vector int_regs = { - PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, - PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, - PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, - PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, - PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, - PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, - PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11 - }; - std::vector float_regs = { - PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, - PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7, - PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F10, PhysicalReg::F11, - PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, - PhysicalReg::F16, PhysicalReg::F17, PhysicalReg::F18, PhysicalReg::F19, - PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22, PhysicalReg::F23, - PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27, - PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31 - }; - - // 确定虚拟寄存器类型(整数或浮点) - auto is_float_vreg = [&](const std::string& vreg) -> bool { - for (const auto& pair : value_vreg_map) { - if (pair.second == vreg) { - if (auto inst = dynamic_cast(pair.first)) { - if (inst->isUnary()) { - switch (inst->getKind()) { - case Instruction::kFNeg: - case Instruction::kFNot: - case Instruction::kFtoI: - case Instruction::kItoF: - case Instruction::kBitFtoI: - case Instruction::kBitItoF: - return true; // 浮点相关指令 - default: - return inst->getType()->isFloat(); - } - } - return inst->getType()->isFloat(); - } else if (auto constant = dynamic_cast(pair.first)) { - return constant->isFloat(); - } - } - } - return false; // 默认整数 - }; - - // 按度数排序虚拟寄存器 - std::vector> vreg_degrees; - for (const auto& entry : interference_graph) { - vreg_degrees.push_back({entry.first, (int)entry.second.size()}); - } - std::sort(vreg_degrees.begin(), vreg_degrees.end(), - [](const auto& a, const auto& b) { return a.second > b.second; }); - - for (const auto& vreg_deg_pair : vreg_degrees) { - const std::string& vreg = vreg_deg_pair.first; - std::set used_colors; - bool is_float = is_float_vreg(vreg); - - // 收集邻居使用的颜色 - if (interference_graph.count(vreg)) { - for (const auto& neighbor_vreg : interference_graph.at(vreg)) { - if (vreg_to_preg.count(neighbor_vreg)) { - used_colors.insert(vreg_to_preg.at(neighbor_vreg)); - } - } - } - - // 选择合适的寄存器池 - const auto& available_regs = is_float ? float_regs : int_regs; - - // 查找第一个可用的寄存器 - bool colored = false; - for (PhysicalReg preg : available_regs) { - if (used_colors.find(preg) == used_colors.end()) { - vreg_to_preg[vreg] = preg; - colored = true; - break; - } - } - - if (!colored) { - std::cerr << "警告: 无法为 " << vreg << " 分配" << (is_float ? "浮点" : "整数") << "寄存器,将溢出到栈。\n"; - // 溢出处理:在 stack_map 中分配栈空间 - // 这里假设每个溢出变量占用 4 字节 - // 注意:实际中需要区分整数和浮点溢出的存储指令(如 sw vs fsw) - } - } -} - -// 寄存器分配 -RISCv64CodeGen::RegAllocResult RISCv64CodeGen::register_allocation(Function* func) { - eliminate_phi(func); - vreg_counter = 0; - value_vreg_map.clear(); - - // 为所有产生值的指令和操作数分配虚拟寄存器 - for (const auto& bb_ptr : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - Instruction* inst = inst_ptr.get(); - if (!inst->getType()->isVoid() && !dynamic_cast(inst)) { - if (value_vreg_map.find(inst) == value_vreg_map.end()) { - value_vreg_map[inst] = "v" + std::to_string(vreg_counter++); - } - } - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (dynamic_cast(operand) || dynamic_cast(operand)) { - if (value_vreg_map.find(operand) == value_vreg_map.end()) { - value_vreg_map[operand] = "v" + std::to_string(vreg_counter++); - } - } else if (auto op_inst = dynamic_cast(operand)) { - if (!op_inst->getType()->isVoid() && !dynamic_cast(operand)) { - if (value_vreg_map.find(operand) == value_vreg_map.end()) { - value_vreg_map[operand] = "v" + std::to_string(vreg_counter++); - } - } - } - } - } - } - - RegAllocResult alloc_result; - int current_stack_offset = 0; - std::set allocas_in_func; - - // 为 AllocaInst 计算栈空间并分配偏移量 - for (const auto& bb_ptr : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - if (auto alloca = dynamic_cast(inst_ptr.get())) { - allocas_in_func.insert(alloca); - } - } - } - for (auto alloca : allocas_in_func) { - int total_size = 4; - auto dims = alloca->getDims(); - if (!dims.empty()) { - int num_elements = 1; - for (const auto& dim_use : dims) { - Value* dim_value = dim_use->getValue(); - if (auto const_dim = dynamic_cast(dim_value)) { - num_elements *= const_dim->getInt(); - } else { - throw std::runtime_error("数组维度必须是编译时常量"); - } - } - total_size *= num_elements; - } - alloc_result.stack_map[alloca] = current_stack_offset; - current_stack_offset += total_size; - } - // 为保存的 ra 和 s0 (各8字节) 预留16字节空间 - alloc_result.stack_size = current_stack_offset + 16; - - // 活跃性分析 - LivenessResult liveness = liveness_analysis(func); - - // 构建干扰图 - std::map> interference_graph = build_interference_graph(liveness); - - // 图着色 - color_graph(alloc_result.vreg_to_preg, interference_graph); - - if (DEBUG) { - std::cerr << "=== 寄存器分配结果 (vreg_to_preg) ===\n"; - for (const auto& pair : alloc_result.vreg_to_preg) { - std::cerr << " " << pair.first << " -> " << reg_to_string(pair.second) << "\n"; - } - std::cerr << "=== 寄存器分配结果结束 ===\n\n"; - - std::cerr << "=== 活跃性分析结果 (live_in sets) ===\n"; - for (const auto& bb_ptr : func->getBasicBlocks()) { - std::cerr << "Basic Block: " << bb_ptr->getName() << "\n"; - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - std::cerr << " Inst: " << inst_ptr->getKindString(); - if (!inst_ptr->getName().empty()) std::cerr << "(" << inst_ptr->getName() << ")"; - if (value_vreg_map.count(inst_ptr.get())) std::cerr << " (Def vreg: " << value_vreg_map.at(inst_ptr.get()) << ")"; - - std::cerr << " (Live In: "; - if (liveness.live_in.count(inst_ptr.get())) { - std::cerr << print_set(liveness.live_in.at(inst_ptr.get())); - } else { - std::cerr << "{}"; - } - std::cerr << ")\n"; - } - } - std::cerr << "=== 活跃性分析结果结束 ===\n\n"; - - std::cerr << "=== 干扰图 ===\n"; - for (const auto& pair : interference_graph) { - std::cerr << " " << pair.first << ": " << print_set(pair.second) << "\n"; - } - std::cerr << "=== 干扰图结束 ===\n\n"; - } - - return alloc_result; -} - -// Phi 消除 (简化版,将 Phi 的结果直接复制到每个前驱基本块的末尾) -void RISCv64CodeGen::eliminate_phi(Function* func) { - // 这是一个占位符。适当的 phi 消除将涉及 - // 在每个前驱基本块的末尾插入 `mov` 指令,用于每个 phi 操作数。 - // 对于给定的 IR 示例,没有 phi 节点,所以这可能不是严格必要的, - // 但如果前端生成 phi 节点,则有此阶段是好的做法。 - // 目前,我们假设没有生成 phi 节点或者它们已在前端处理。 -} - } // namespace sysy \ No newline at end of file diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp new file mode 100644 index 0000000..bf56027 --- /dev/null +++ b/src/RISCv64ISel.cpp @@ -0,0 +1,605 @@ +#include "RISCv64ISel.h" +#include +#include +#include +#include + +namespace sysy { + +RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {} + +// 为一个IR Value获取或分配一个新的虚拟寄存器 +unsigned RISCv64ISel::getVReg(Value* val) { + if (!val) { // 安全检查 + throw std::runtime_error("Cannot get vreg for a null Value."); + } + if (vreg_map.find(val) == vreg_map.end()) { + if (vreg_counter == 0) { + // vreg 0 通常保留给物理寄存器x0(zero),我们从1开始分配 + vreg_counter = 1; + } + vreg_map[val] = vreg_counter++; + } + return vreg_map.at(val); +} + +// 主入口函数 +std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { + F = func; + if (!F) return nullptr; + MFunc = std::make_unique(F->getName()); + vreg_map.clear(); + bb_map.clear(); + vreg_counter = 0; + local_label_counter = 0; + + select(); + + return std::move(MFunc); +} + +// 指令选择主流程 +void RISCv64ISel::select() { + // 1. 为所有基本块创建对应的MachineBasicBlock + for (const auto& bb_ptr : F->getBasicBlocks()) { + BasicBlock* bb = bb_ptr.get(); + auto mbb = std::make_unique(bb->getName(), MFunc.get()); + bb_map[bb] = mbb.get(); + MFunc->addBlock(std::move(mbb)); + } + + // 2. 为函数参数创建虚拟寄存器 + // ====================== 已修正 ====================== + // 根据 IR.h, 参数列表存储在入口基本块中 + if (F->getEntryBlock()) { + for (auto* arg_alloca : F->getEntryBlock()->getArguments()) { + getVReg(arg_alloca); + } + } + // ===================================================== + + // 3. 遍历每个基本块,生成指令 + for (const auto& bb_ptr : F->getBasicBlocks()) { + selectBasicBlock(bb_ptr.get()); + } + + // 4. 设置基本块的前驱后继关系 + for (const auto& bb_ptr : F->getBasicBlocks()) { + BasicBlock* bb = bb_ptr.get(); + CurMBB = bb_map.at(bb); + for (auto succ : bb->getSuccessors()) { + CurMBB->successors.push_back(bb_map.at(succ)); + } + for (auto pred : bb->getPredecessors()) { + CurMBB->predecessors.push_back(bb_map.at(pred)); + } + } +} + +// 处理单个基本块 +void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { + CurMBB = bb_map.at(bb); + auto dag = build_dag(bb); + + std::map value_to_node; + for(const auto& node : dag) { + if (node->value) { + value_to_node[node->value] = node.get(); + } + } + + std::set selected_nodes; + std::function select_recursive = + [&](DAGNode* node) { + if (!node || selected_nodes.count(node)) return; + + for (auto operand : node->operands) { + select_recursive(operand); + } + + // 只有当所有操作数都选择完毕后,才选择当前节点 + selectNode(node); + selected_nodes.insert(node); + }; + + // 按照IR指令的原始顺序来驱动指令选择 + for (const auto& inst_ptr : bb->getInstructions()) { + DAGNode* node_to_select = nullptr; + // 查找当前IR指令对应的DAG节点 + if (value_to_node.count(inst_ptr.get())) { + node_to_select = value_to_node.at(inst_ptr.get()); + } else { + // 对于没有返回值的指令或某些特殊情况 + for(const auto& node : dag) { + if(node->value == inst_ptr.get()) { + node_to_select = node.get(); + break; + } + } + } + if(node_to_select) { + select_recursive(node_to_select); + } + } +} + +void RISCv64ISel::selectNode(DAGNode* node) { + // 注意:不再生成字符串,而是创建MachineInstr对象并加入到CurMBB + switch (node->kind) { + case DAGNode::CONSTANT: + case DAGNode::ALLOCA_ADDR: + // 这些节点本身不生成指令。使用它们的指令会按需处理。 + // 为Alloca地址分配一个vreg是必要的,代表地址。 + if (node->value) getVReg(node->value); + break; + + case DAGNode::LOAD: { + // lw rd, offset(base) + auto dest_vreg = getVReg(node->value); + auto ptr_vreg = getVReg(node->operands[0]->value); + + auto instr = std::make_unique(RVOpcodes::LW); + instr->addOperand(std::make_unique(dest_vreg)); + // 暂时生成0(ptr),后续pass会将其优化为 offset(s0) + instr->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(instr)); + break; + } + + case DAGNode::STORE: { + // sw rs2, offset(rs1) + // 先加载常量 + if (auto val_const = dynamic_cast(node->operands[0]->value)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(getVReg(val_const))); + li->addOperand(std::make_unique(val_const->getInt())); + CurMBB->addInstruction(std::move(li)); + } + + auto val_vreg = getVReg(node->operands[0]->value); + auto ptr_vreg = getVReg(node->operands[1]->value); + + auto instr = std::make_unique(RVOpcodes::SW); + instr->addOperand(std::make_unique(val_vreg)); // value to store + instr->addOperand(std::make_unique( + std::make_unique(ptr_vreg), // base address + std::make_unique(0) // offset + )); + CurMBB->addInstruction(std::move(instr)); + break; + } + + case DAGNode::BINARY: { + auto bin = dynamic_cast(node->value); + if (!bin) break; + + Value* lhs = bin->getLhs(); + Value* rhs = bin->getRhs(); + + // 检查是否为 addi 优化 + if (bin->getKind() == BinaryInst::kAdd) { + if (auto rhs_const = dynamic_cast(rhs)) { + if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { + auto instr = std::make_unique(RVOpcodes::ADDIW); + instr->addOperand(std::make_unique(getVReg(bin))); + instr->addOperand(std::make_unique(getVReg(lhs))); + instr->addOperand(std::make_unique(rhs_const->getInt())); + CurMBB->addInstruction(std::move(instr)); + return; // 指令已生成,提前返回 + } + } + } + + // 为操作数加载立即数或地址 + auto load_val_if_const = [&](Value* val) { + if (auto c = dynamic_cast(val)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(getVReg(c))); + li->addOperand(std::make_unique(c->getInt())); + CurMBB->addInstruction(std::move(li)); + } else if (auto g = dynamic_cast(val)) { + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(getVReg(g))); + la->addOperand(std::make_unique(g->getName())); + CurMBB->addInstruction(std::move(la)); + } + }; + load_val_if_const(lhs); + load_val_if_const(rhs); + + auto dest_vreg = getVReg(bin); + auto lhs_vreg = getVReg(lhs); + auto rhs_vreg = getVReg(rhs); + + // 生成二元运算指令 + switch (bin->getKind()) { + case BinaryInst::kAdd: { + RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW; + auto instr = std::make_unique(opcode); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kSub: { + auto instr = std::make_unique(RVOpcodes::SUBW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kMul: { + auto instr = std::make_unique(RVOpcodes::MULW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case Instruction::kDiv: { + auto instr = std::make_unique(RVOpcodes::DIVW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case Instruction::kRem: { + auto instr = std::make_unique(RVOpcodes::REMW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpEQ: { + auto sub = std::make_unique(RVOpcodes::SUBW); + sub->addOperand(std::make_unique(dest_vreg)); + sub->addOperand(std::make_unique(lhs_vreg)); + sub->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(sub)); + + auto seqz = std::make_unique(RVOpcodes::SEQZ); + seqz->addOperand(std::make_unique(dest_vreg)); + seqz->addOperand(std::make_unique(dest_vreg)); + CurMBB->addInstruction(std::move(seqz)); + break; + } + case BinaryInst::kICmpNE: { + auto sub = std::make_unique(RVOpcodes::SUBW); + sub->addOperand(std::make_unique(dest_vreg)); + sub->addOperand(std::make_unique(lhs_vreg)); + sub->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(sub)); + + auto snez = std::make_unique(RVOpcodes::SNEZ); + snez->addOperand(std::make_unique(dest_vreg)); + snez->addOperand(std::make_unique(dest_vreg)); + CurMBB->addInstruction(std::move(snez)); + break; + } + case BinaryInst::kICmpLT: { + auto instr = std::make_unique(RVOpcodes::SLT); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpGT: { + auto instr = std::make_unique(RVOpcodes::SLT); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); // Swapped + instr->addOperand(std::make_unique(lhs_vreg)); // Swapped + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpLE: { + auto slt = std::make_unique(RVOpcodes::SLT); + slt->addOperand(std::make_unique(dest_vreg)); + slt->addOperand(std::make_unique(rhs_vreg)); // Swapped + slt->addOperand(std::make_unique(lhs_vreg)); // Swapped + CurMBB->addInstruction(std::move(slt)); + + auto xori = std::make_unique(RVOpcodes::XORI); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(1)); + CurMBB->addInstruction(std::move(xori)); + break; + } + case BinaryInst::kICmpGE: { + auto slt = std::make_unique(RVOpcodes::SLT); + slt->addOperand(std::make_unique(dest_vreg)); + slt->addOperand(std::make_unique(lhs_vreg)); + slt->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(slt)); + + auto xori = std::make_unique(RVOpcodes::XORI); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(1)); + CurMBB->addInstruction(std::move(xori)); + break; + } + default: + throw std::runtime_error("Unsupported binary instruction in ISel"); + } + break; + } + + case DAGNode::UNARY: { + auto unary = dynamic_cast(node->value); + if (!unary) break; + + auto dest_vreg = getVReg(unary); + auto src_vreg = getVReg(unary->getOperand()); + + switch (unary->getKind()) { + case UnaryInst::kNeg: { + auto instr = std::make_unique(RVOpcodes::SUBW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(PhysicalReg::ZERO)); // x0 + instr->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case UnaryInst::kNot: { + auto instr = std::make_unique(RVOpcodes::SEQZ); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + default: + throw std::runtime_error("Unsupported unary instruction in ISel"); + } + break; + } + + case DAGNode::CALL: { + auto call = dynamic_cast(node->value); + if (!call) break; + + // 在此阶段,我们只处理函数调用本身和返回值的移动 + // 参数的传递将在一个专门的 Calling Convention Pass 中处理 + + auto call_instr = std::make_unique(RVOpcodes::CALL); + call_instr->addOperand(std::make_unique(call->getCallee()->getName())); + CurMBB->addInstruction(std::move(call_instr)); + + if (!call->getType()->isVoid()) { + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(getVReg(call))); // dest + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); // src + CurMBB->addInstruction(std::move(mv_instr)); + } + break; + } + + case DAGNode::RETURN: { + auto ret_inst = dynamic_cast(node->value); + if (ret_inst && ret_inst->hasReturnValue()) { + // 如果有返回值,生成一条mv指令将其放入a0 + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + mv_instr->addOperand(std::make_unique(getVReg(ret_inst->getReturnValue()))); + CurMBB->addInstruction(std::move(mv_instr)); + } + // 生成ret伪指令 + auto instr = std::make_unique(RVOpcodes::RET); + CurMBB->addInstruction(std::move(instr)); + break; + } + + case DAGNode::BRANCH: { + if (auto cond_br = dynamic_cast(node->value)) { + // bne cond, x0, then_block + auto br_instr = std::make_unique(RVOpcodes::BNE); + br_instr->addOperand(std::make_unique(getVReg(cond_br->getCondition()))); + br_instr->addOperand(std::make_unique(PhysicalReg::ZERO)); + br_instr->addOperand(std::make_unique(cond_br->getThenBlock()->getName())); + CurMBB->addInstruction(std::move(br_instr)); + + // j else_block + // 注意:这里会产生一个fallthrough问题,后续的分支优化pass会解决它 + // 一个更健壮的生成方式是 bne -> j else; then: ...; else: ... + } else if (auto uncond_br = dynamic_cast(node->value)) { + auto j_instr = std::make_unique(RVOpcodes::J); + j_instr->addOperand(std::make_unique(uncond_br->getBlock()->getName())); + CurMBB->addInstruction(std::move(j_instr)); + } + break; + } + case DAGNode::MEMSET: { + // 这是对原memset逻辑的完整LLIR翻译 + auto memset = dynamic_cast(node->value); + if (!memset) break; + + auto r_dest_addr = getVReg(memset->getPointer()); + auto r_num_bytes = getVReg(memset->getSize()); + auto r_value_byte = getVReg(memset->getValue()); + + // 为临时值创建虚拟寄存器 + auto r_counter = vreg_counter++; + auto r_end_addr = vreg_counter++; + auto r_current_addr = vreg_counter++; + auto r_temp_val = vreg_counter++; + + auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rd)); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(rs2)); + CurMBB->addInstruction(std::move(i)); + }; + auto addi_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, int64_t imm) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rd)); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(imm)); + CurMBB->addInstruction(std::move(i)); + }; + auto store_instr = [&](RVOpcodes op, unsigned src, unsigned base, int64_t off) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(src)); + i->addOperand(std::make_unique(std::make_unique(base), std::make_unique(off))); + CurMBB->addInstruction(std::move(i)); + }; + auto branch_instr = [&](RVOpcodes op, unsigned rs1, unsigned rs2, const std::string& label) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(rs2)); + i->addOperand(std::make_unique(label)); + CurMBB->addInstruction(std::move(i)); + }; + auto jump_instr = [&](const std::string& label) { + auto i = std::make_unique(RVOpcodes::J); + i->addOperand(std::make_unique(label)); + CurMBB->addInstruction(std::move(i)); + }; + auto label_instr = [&](const std::string& name) { + auto i = std::make_unique(RVOpcodes::LABEL); + i->addOperand(std::make_unique(name)); + CurMBB->addInstruction(std::move(i)); + }; + + int unique_id = this->local_label_counter++; + std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); + std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); + std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); + std::string done_label = "memset_done_" + std::to_string(unique_id); + + // 构造64位的填充值 + addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 16); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + + // 设置循环变量 + add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(r_current_addr)); + mv->addOperand(std::make_unique(r_dest_addr)); + CurMBB->addInstruction(std::move(mv)); + addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8); + add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter); + + // 64位写入循环 + label_instr(loop_start_label); + branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label); + store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0); + addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8); + jump_instr(loop_start_label); + label_instr(loop_end_label); + + // 剩余字节写入循环 + label_instr(remainder_label); + branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label); + store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0); + addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 1); + jump_instr(remainder_label); + label_instr(done_label); + break; + } + + default: + throw std::runtime_error("Unsupported DAGNode kind in ISel: " + std::to_string(node->kind)); + } +} + + +// --- DAG构建函数 (从原RISCv64Backend.cpp几乎原样迁移, 保持不变) --- +RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { + if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) { + return value_to_node[val]; + } + auto node = std::make_unique(kind); + node->value = val; + DAGNode* raw_node_ptr = node.get(); + nodes_storage.push_back(std::move(node)); + // 只有产生值的节点才应该被记录,以备复用 + if (val && !val->getType()->isVoid() && dynamic_cast(val)) { + value_to_node[val] = raw_node_ptr; + } else if (val && dynamic_cast(val)) { + value_to_node[val] = raw_node_ptr; + } + return raw_node_ptr; +} + +RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map& value_to_node, std::vector>& nodes_storage) { + if (value_to_node.count(val_ir)) { + return value_to_node[val_ir]; + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage); + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage); + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage); + } + // Fallback: Assume it needs to be loaded if not found (might be a parameter or a value from another block) + return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); +} + +std::vector> RISCv64ISel::build_dag(BasicBlock* bb) { + std::vector> nodes_storage; + std::map value_to_node; + + for (const auto& inst_ptr : bb->getInstructions()) { + Instruction* inst = inst_ptr.get(); + if (auto alloca = dynamic_cast(inst)) { + create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); + } else if (auto store = dynamic_cast(inst)) { + auto store_node = create_node(DAGNode::STORE, store, value_to_node, nodes_storage); + store_node->operands.push_back(get_operand_node(store->getValue(), value_to_node, nodes_storage)); + store_node->operands.push_back(get_operand_node(store->getPointer(), value_to_node, nodes_storage)); + } else if (auto memset = dynamic_cast(inst)) { + auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage); + memset_node->operands.push_back(get_operand_node(memset->getPointer(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage)); + } + else if (auto load = dynamic_cast(inst)) { + auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); + load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); + } else if (auto bin = dynamic_cast(inst)) { + if(value_to_node.count(bin)) continue; + auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); + bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); + bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); + } else if (auto un = dynamic_cast(inst)) { + if(value_to_node.count(un)) continue; + auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage); + unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage)); + } + else if (auto call = dynamic_cast(inst)) { + if(value_to_node.count(call)) continue; + auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); + for (auto arg : call->getArguments()) { + call_node->operands.push_back(get_operand_node(arg->getValue(), value_to_node, nodes_storage)); + } + } else if (auto ret = dynamic_cast(inst)) { + auto ret_node = create_node(DAGNode::RETURN, ret, value_to_node, nodes_storage); + if (ret->hasReturnValue()) { + ret_node->operands.push_back(get_operand_node(ret->getReturnValue(), value_to_node, nodes_storage)); + } + } else if (auto cond_br = dynamic_cast(inst)) { + auto br_node = create_node(DAGNode::BRANCH, cond_br, value_to_node, nodes_storage); + br_node->operands.push_back(get_operand_node(cond_br->getCondition(), value_to_node, nodes_storage)); + } else if (auto uncond_br = dynamic_cast(inst)) { + create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); + } + } + return nodes_storage; +} + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp new file mode 100644 index 0000000..4471868 --- /dev/null +++ b/src/RISCv64RegAlloc.cpp @@ -0,0 +1,265 @@ +#include "RISCv64RegAlloc.h" +#include +#include + +namespace sysy { + +RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { + // 初始化可分配的整数寄存器池 (排除特殊用途的) + allocable_int_regs = { + PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, + PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, + PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, + PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, + PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, + PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, + PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11, + }; +} + +void RISCv64RegAlloc::run() { + analyzeLiveness(); + buildInterferenceGraph(); + colorGraph(); + rewriteFunction(); +} + +void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) { + // 这是一个简化的版本,实际需要根据RVOpcodes精确定义 + // 通常第一个RegOperand是def,其余是use + bool is_def = true; + for (const auto& op : instr->getOperands()) { + if (op->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op.get()); + if (reg_op->isVirtual()) { + if (is_def) { + def.insert(reg_op->getVRegNum()); + is_def = false; // 假设每条指令最多一个def + } else { + use.insert(reg_op->getVRegNum()); + } + } + } else if (op->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op.get()); + if (mem_op->getBase()->isVirtual()) { + use.insert(mem_op->getBase()->getVRegNum()); + } + } + } + + // 特殊处理store和branch指令,它们没有显式的def + auto opcode = instr->getOpcode(); + if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::BNE || opcode == RVOpcodes::BEQ) { + def.clear(); // 清空错误的def + use.clear(); + for (const auto& op : instr->getOperands()) { + if (op->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op.get()); + if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum()); + } else if (op->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op.get()); + if(mem_op->getBase()->isVirtual()) use.insert(mem_op->getBase()->getVRegNum()); + } + } + } +} + + +void RISCv64RegAlloc::analyzeLiveness() { + bool changed = true; + while (changed) { + changed = false; + // 逆序遍历基本块 + for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { + auto& mbb = *it; + LiveSet live_out; + for (auto succ : mbb->successors) { + // live_out[B] = Union(live_in[S]) for all S in succ(B) + if (!succ->getInstructions().empty()) { + auto first_instr = succ->getInstructions().front().get(); + if (live_in_map.count(first_instr)) { + live_out.insert(live_in_map.at(first_instr).begin(), live_in_map.at(first_instr).end()); + } + } + } + + // 逆序遍历指令 + for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) { + MachineInstr* instr = instr_it->get(); + LiveSet old_live_in = live_in_map[instr]; + LiveSet old_live_out = live_out_map[instr]; + + // 更新 live_out + live_out_map[instr] = live_out; + + LiveSet use, def; + getInstrUseDef(instr, use, def); + + // live_in[i] = use[i] U (live_out[i] - def[i]) + LiveSet live_in = use; + LiveSet diff = live_out; + for (auto vreg : def) { + diff.erase(vreg); + } + live_in.insert(diff.begin(), diff.end()); + live_in_map[instr] = live_in; + + // 为下一次迭代准备live_out + live_out = live_in; + + if (live_in_map[instr] != old_live_in || live_out_map[instr] != old_live_out) { + changed = true; + } + } + } + } +} + +void RISCv64RegAlloc::buildInterferenceGraph() { + std::set all_vregs; + // 收集所有虚拟寄存器 + for (auto const& [instr, live_set] : live_out_map) { + all_vregs.insert(live_set.begin(), live_set.end()); + } + + // 初始化图 + for (auto vreg : all_vregs) { + interference_graph[vreg] = {}; + } + + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr : mbb->getInstructions()) { + LiveSet def, use; + getInstrUseDef(instr.get(), use, def); + + const LiveSet& live_out = live_out_map.at(instr.get()); + + for (unsigned d : def) { + for (unsigned l : live_out) { + if (d != l) { + interference_graph[d].insert(l); + interference_graph[l].insert(d); + } + } + } + } + } +} + +void RISCv64RegAlloc::colorGraph() { + std::vector sorted_vregs; + for (auto const& [vreg, neighbors] : interference_graph) { + sorted_vregs.push_back(vreg); + } + + // 按度数降序排序 (简单贪心策略) + std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) { + return interference_graph[a].size() > interference_graph[b].size(); + }); + + for (unsigned vreg : sorted_vregs) { + std::set used_colors; + // 查找邻居已用的颜色 + for (unsigned neighbor : interference_graph.at(vreg)) { + if (color_map.count(neighbor)) { + used_colors.insert(color_map.at(neighbor)); + } + } + + // 寻找一个可用的颜色 + bool colored = false; + for (PhysicalReg preg : allocable_int_regs) { + if (used_colors.find(preg) == used_colors.end()) { + color_map[vreg] = preg; + colored = true; + break; + } + } + + if (!colored) { + // 无法分配,需要溢出 + spilled_vregs.insert(vreg); + } + } +} + +void RISCv64RegAlloc::rewriteFunction() { + // 1. 为所有溢出的vreg分配栈槽 + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int current_offset = frame_info.frame_size; // 假设从现有栈大小后开始分配 + for (unsigned vreg : spilled_vregs) { + current_offset += 4; // 假设所有溢出变量都占4字节 + frame_info.spill_slots[vreg] = -current_offset; // 栈向下增长,所以是负偏移 + } + frame_info.frame_size = current_offset; + + // 2. 遍历所有指令,替换vreg并插入spill代码 + for (auto& mbb : MFunc->getBlocks()) { + std::vector> new_instructions; + for (auto& instr_ptr : mbb->getInstructions()) { + LiveSet use, def; + getInstrUseDef(instr_ptr.get(), use, def); + + // 为use的溢出变量插入LOAD + for (unsigned vreg : use) { + if (spilled_vregs.count(vreg)) { + int offset = frame_info.spill_slots.at(vreg); + auto load = std::make_unique(RVOpcodes::LW); + load->addOperand(std::make_unique(vreg)); // 临时用vreg号代表,稍后替换 + load->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), // 基址用帧指针s0 + std::make_unique(offset) + )); + new_instructions.push_back(std::move(load)); + } + } + + // 添加原始指令 + new_instructions.push_back(std::move(instr_ptr)); + + // 为def的溢出变量插入STORE + for (unsigned vreg : def) { + if (spilled_vregs.count(vreg)) { + int offset = frame_info.spill_slots.at(vreg); + auto store = std::make_unique(RVOpcodes::SW); + store->addOperand(std::make_unique(vreg)); // 临时用vreg号代表 + store->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), + std::make_unique(offset) + )); + new_instructions.push_back(std::move(store)); + } + } + } + mbb->getInstructions() = std::move(new_instructions); + } + + // 3. 最后一遍扫描,将所有RegOperand从vreg替换为preg + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr_ptr : mbb->getInstructions()) { + for (auto& op_ptr : instr_ptr->getOperands()) { + if(op_ptr->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op_ptr.get()); + if (reg_op->isVirtual()) { + unsigned vreg = reg_op->getVRegNum(); + if (color_map.count(vreg)) { + reg_op->setPReg(color_map.at(vreg)); + } else if (spilled_vregs.count(vreg)) { + // 对于spill的vreg, 使用一个固定的临时寄存器, 比如t6 + reg_op->setPReg(PhysicalReg::T6); + } + } + } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op_ptr.get()); + auto base_reg_op = mem_op->getBase(); + if(base_reg_op->isVirtual()){ + unsigned vreg = base_reg_op->getVRegNum(); + if(color_map.count(vreg)) base_reg_op->setPReg(color_map.at(vreg)); + } + } + } + } + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/include/RISCv64AsmPrinter.h b/src/include/RISCv64AsmPrinter.h new file mode 100644 index 0000000..7df35f6 --- /dev/null +++ b/src/include/RISCv64AsmPrinter.h @@ -0,0 +1,38 @@ +#ifndef RISCV64_ASMPRINTER_H +#define RISCV64_ASMPRINTER_H + +#include "RISCv64LLIR.h" +#include + +namespace sysy { + +class RISCv64AsmPrinter { +public: + // 主入口,将整个MachineFunction打印到指定的输出流 + void runOnMachineFunction(MachineFunction* mfunc, std::ostream& os); + +private: + // 打印单个基本块 + void printBasicBlock(MachineBasicBlock* mbb); + + // 打印单条指令 + void printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb); + + // 打印函数序言 + void printPrologue(MachineFunction* mfunc); + + // 打印函数尾声 + void printEpilogue(MachineFunction* mfunc); + + // 将物理寄存器枚举转换为字符串 (从原RISCv64Backend迁移) + std::string regToString(PhysicalReg reg); + + // 打印单个操作数 + void printOperand(MachineOperand* op); + + std::ostream* OS; // 指向当前输出流 +}; + +} // namespace sysy + +#endif // RISCV64_ASMPRINTER_H \ No newline at end of file diff --git a/src/include/RISCv64Backend.h b/src/include/RISCv64Backend.h index 429aba2..f929007 100644 --- a/src/include/RISCv64Backend.h +++ b/src/include/RISCv64Backend.h @@ -1,130 +1,29 @@ #ifndef RISCV64_BACKEND_H #define RISCV64_BACKEND_H -#include "IR.h" +#include "IR.h" // 只需包含高层IR定义 #include #include -#include -#include #include -#include -#include // For std::function - -extern int DEBUG; -extern int DEEPDEBUG; - - namespace sysy { -// 为活跃性分析的结果定义一个结构体,以同时持有 live_in 和 live_out 集合 -struct LivenessResult { - std::map> live_in; - std::map> live_out; -}; - +// RISCv64CodeGen 现在是一个高层驱动器 class RISCv64CodeGen { public: - enum class PhysicalReg { - ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, - F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 - }; - - // Move DAGNode and RegAllocResult to public section - struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; - NodeKind kind; - Value* value = nullptr; // For IR Value - std::string inst; // Generated RISC-V instruction(s) for this node - std::string result_vreg; // Virtual register assigned to this node's result - std::vector operands; - std::vector users; // For debugging and potentially optimizations - DAGNode(NodeKind k) : kind(k) {} - - // Debugging / helper - std::string getNodeKindString() const { - switch (kind) { - case CONSTANT: return "CONSTANT"; - case LOAD: return "LOAD"; - case STORE: return "STORE"; - case BINARY: return "BINARY"; - case CALL: return "CALL"; - case RETURN: return "RETURN"; - case BRANCH: return "BRANCH"; - case ALLOCA_ADDR: return "ALLOCA_ADDR"; - case UNARY: return "UNARY"; - case MEMSET: return "MEMSET"; - default: return "UNKNOWN"; - } - } - }; - - struct RegAllocResult { - std::map vreg_to_preg; // Virtual register to Physical Register mapping - std::map stack_map; // Value (AllocaInst) to stack offset - int stack_size = 0; // Total stack frame size for locals and spills - }; - RISCv64CodeGen(Module* mod) : module(mod) {} + // 唯一的公共入口点 std::string code_gen(); - std::string module_gen(); - std::string function_gen(Function* func); - // 修改 basicBlock_gen 的声明,添加 int block_idx 参数 - std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx); - - // DAG related - std::vector> build_dag(BasicBlock* bb); - void select_instructions(DAGNode* node, const RegAllocResult& alloc); - // 改变 emit_instructions 的参数,使其可以直接添加汇编指令到 main ss - void emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set& emitted_nodes); - - // Register Allocation related - LivenessResult liveness_analysis(Function* func); - std::map> build_interference_graph(const LivenessResult& liveness); - void color_graph(std::map& vreg_to_preg, - const std::map>& interference_graph); - RegAllocResult register_allocation(Function* func); - void eliminate_phi(Function* func); // Phi elimination is typically done before DAG building - - // Utility - std::string reg_to_string(PhysicalReg reg); - void print_dag(const std::vector>& dag, const std::string& bb_name); private: - static const std::vector allocable_regs; - std::map value_vreg_map; // Maps IR Value* to its virtual register name + // 模块级代码生成 (处理全局变量和驱动函数生成) + std::string module_gen(); + + // 函数级代码生成 (实现新的流水线) + std::string function_gen(Function* func); + Module* module; - int vreg_counter = 0; // Counter for unique virtual register names - int alloca_offset_counter = 0; // Counter for alloca offsets - - // 新增一个成员变量来存储当前函数的所有 DAGNode,以确保其生命周期贯穿整个函数代码生成 - // 这样可以在多个 BasicBlock_gen 调用中访问到完整的 DAG 节点 - std::vector> current_function_dag_nodes; - - // 为空标签定义一个伪名称前缀,加上块索引以确保唯一性 - const std::string ENTRY_BLOCK_PSEUDO_NAME = "entry_block_"; - - int local_label_counter = 0; // 用于生成唯一的本地标签 (如 memset 循环, 匿名块跳转等) - - // !!! 修改:get_operand_node 辅助函数现在需要传入 value_to_node 和 nodes_storage 的引用 - // 因为它们是 build_dag 局部管理的 - DAGNode* get_operand_node( - Value* val_ir, - std::map& value_to_node, - std::vector>& nodes_storage - ); - - // !!! 新增:create_node 辅助函数也需要传入 value_to_node 和 nodes_storage 的引用 - // 并且它应该不再是 lambda,而是一个真正的成员函数 - DAGNode* create_node( - DAGNode::NodeKind kind, - Value* val, - std::map& value_to_node, - std::vector>& nodes_storage - ); - - std::vector> temp_instructions_storage; // 用于存储 build_dag 中创建的临时 BinaryInst }; } // namespace sysy diff --git a/src/include/RISCv64ISel.h b/src/include/RISCv64ISel.h new file mode 100644 index 0000000..e122ee0 --- /dev/null +++ b/src/include/RISCv64ISel.h @@ -0,0 +1,61 @@ +#ifndef RISCV64_ISEL_H +#define RISCV64_ISEL_H + +#include "IR.h" +#include "RISCv64LLIR.h" +#include +#include + +namespace sysy { + +class RISCv64ISel { +public: + RISCv64ISel(); + // 模块主入口:将一个高层IR函数转换为底层LLIR函数 + std::unique_ptr runOnFunction(Function* func); + +private: + // DAG节点定义,作为ISel的内部实现细节 + struct DAGNode { + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; + NodeKind kind; + Value* value = nullptr; + std::vector operands; + DAGNode(NodeKind k) : kind(k) {} + }; + + // 为当前函数生成LLIR + void select(); + + // 为单个基本块生成指令 + void selectBasicBlock(BasicBlock* bb); + + // 核心函数:为DAG节点选择并生成MachineInstr + void selectNode(DAGNode* node); + + // --- DAG 构建相关函数 (从原RISCv64Backend迁移) --- + std::vector> build_dag(BasicBlock* bb); + DAGNode* get_operand_node(Value* val_ir, std::map& value_to_node, std::vector>& nodes_storage); + DAGNode* create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage); + + // --- 辅助函数 --- + // 为一个IR Value获取/分配一个虚拟寄存器号 + unsigned getVReg(Value* val); + + Function* F; // 当前处理的高层IR函数 + std::unique_ptr MFunc; // 正在构建的底层LLIR函数 + + MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块 + + // 映射关系 + std::map vreg_map; + std::map bb_map; + std::map value_to_node_map; // 用于selectNode中查找 + + unsigned vreg_counter; + int local_label_counter; +}; + +} // namespace sysy + +#endif // RISCV64_ISEL_H \ No newline at end of file diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h index acbeec5..c1df3bf 100644 --- a/src/include/RISCv64LLIR.h +++ b/src/include/RISCv64LLIR.h @@ -163,8 +163,11 @@ public: MachineInstr(RVOpcodes opcode) : opcode(opcode) {} RVOpcodes getOpcode() const { return opcode; } + // 注意:返回const引用,因为通常不直接修改指令的操作数列表 const std::vector>& getOperands() const { return operands; } - + // 提供一个非const版本,用于内部修改 + std::vector>& getOperands() { return operands; } + void addOperand(std::unique_ptr operand) { operands.push_back(std::move(operand)); } @@ -181,9 +184,12 @@ public: : name(name), parent(parent) {} const std::string& getName() const { return name; } - const std::vector>& getInstructions() const { return instructions; } MachineFunction* getParent() const { return parent; } + // 同时提供 const 和 non-const 版本 + const std::vector>& getInstructions() const { return instructions; } + std::vector>& getInstructions() { return instructions; } + void addInstruction(std::unique_ptr instr) { instructions.push_back(std::move(instr)); } @@ -210,9 +216,12 @@ public: MachineFunction(const std::string& name) : name(name) {} const std::string& getName() const { return name; } - const std::vector>& getBlocks() const { return blocks; } StackFrameInfo& getFrameInfo() { return frame_info; } + // 同时提供 const 和 non-const 版本 + const std::vector>& getBlocks() const { return blocks; } + std::vector>& getBlocks() { return blocks; } + void addBlock(std::unique_ptr block) { blocks.push_back(std::move(block)); } diff --git a/src/include/RISCv64RegAlloc.h b/src/include/RISCv64RegAlloc.h new file mode 100644 index 0000000..ee8ab16 --- /dev/null +++ b/src/include/RISCv64RegAlloc.h @@ -0,0 +1,57 @@ +#ifndef RISCV64_REGALLOC_H +#define RISCV64_REGALLOC_H + +#include "RISCv64LLIR.h" +#include +#include +#include + +namespace sysy { + +class RISCv64RegAlloc { +public: + RISCv64RegAlloc(MachineFunction* mfunc); + + // 模块主入口 + void run(); + +private: + using LiveSet = std::set; // 活跃虚拟寄存器集合 + using InterferenceGraph = std::map>; + + // 活跃性分析 + void analyzeLiveness(); + + // 构建干扰图 + void buildInterferenceGraph(); + + // 图着色分配寄存器 + void colorGraph(); + + // 重写函数,将虚拟寄存器替换为物理寄存器,并插入溢出代码 + void rewriteFunction(); + + // 辅助函数,获取指令的Use/Def集合 + void getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def); + + MachineFunction* MFunc; + + // 活跃性分析结果 + std::map live_in_map; + std::map live_out_map; + + // 干扰图 + InterferenceGraph interference_graph; + + // 图着色结果 + std::map color_map; // vreg -> preg + std::set spilled_vregs; // 被溢出的vreg集合 + + // 可用的物理寄存器池 + std::vector allocable_int_regs; + std::vector allocable_float_regs; // (为未来浮点支持预留) +}; + +} // namespace sysy + +#endif // RISCV64_REGALLOC_H \ No newline at end of file From 9528335a046c2024590e9e454aace01fc8fca8c8 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Sat, 19 Jul 2025 17:50:14 +0800 Subject: [PATCH 3/3] =?UTF-8?q?[backend-llir]=E4=BF=AE=E5=A4=8D=E4=BA=86?= =?UTF-8?q?=E8=AE=B8=E5=A4=9A=E9=87=8D=E6=9E=84=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/RISCv64AsmPrinter.cpp | 281 ++++++++++++++--------------- src/RISCv64Backend.cpp | 33 +--- src/RISCv64ISel.cpp | 306 ++++++++++++++++++-------------- src/RISCv64Passes.cpp | 8 + src/RISCv64RegAlloc.cpp | 179 ++++++++++++------- src/include/RISCv64AsmPrinter.h | 26 ++- src/include/RISCv64Backend.h | 8 +- src/include/RISCv64ISel.h | 34 ++-- src/include/RISCv64LLIR.h | 104 ++++------- src/include/RISCv64Passes.h | 18 ++ src/include/RISCv64RegAlloc.h | 13 +- 11 files changed, 513 insertions(+), 497 deletions(-) create mode 100644 src/RISCv64Passes.cpp create mode 100644 src/include/RISCv64Passes.h diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp index 0688095..0ad1c81 100644 --- a/src/RISCv64AsmPrinter.cpp +++ b/src/RISCv64AsmPrinter.cpp @@ -1,44 +1,71 @@ #include "RISCv64AsmPrinter.h" +#include "RISCv64ISel.h" #include namespace sysy { -void RISCv64AsmPrinter::runOnMachineFunction(MachineFunction* mfunc, std::ostream& os) { +// 检查是否为内存加载/存储指令,以处理特殊的打印格式 +bool isMemoryOp(RVOpcodes opcode) { + switch (opcode) { + case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD: + case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU: + case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD: + return true; + default: + return false; + } +} + +RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {} + +void RISCv64AsmPrinter::run(std::ostream& os) { OS = &os; - // 打印函数声明和全局符号 - *OS << ".text\n"; - *OS << ".globl " << mfunc->getName() << "\n"; - *OS << mfunc->getName() << ":\n"; + *OS << ".globl " << MFunc->getName() << "\n"; + *OS << MFunc->getName() << ":\n"; - // 打印函数序言 - printPrologue(mfunc); + printPrologue(); - // 遍历并打印所有基本块 - for (auto& mbb : mfunc->getBlocks()) { + for (auto& mbb : MFunc->getBlocks()) { printBasicBlock(mbb.get()); } } -void RISCv64AsmPrinter::printPrologue(MachineFunction* mfunc) { - int stack_size = mfunc->getFrameInfo().frame_size; - - // 确保栈大小是16字节对齐 - int aligned_stack_size = (stack_size + 15) & ~15; +void RISCv64AsmPrinter::printPrologue() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + // 序言需要为保存ra和s0预留16字节 + int total_stack_size = frame_info.locals_size + frame_info.spill_size + 16; + int aligned_stack_size = (total_stack_size + 15) & ~15; + frame_info.total_size = aligned_stack_size; if (aligned_stack_size > 0) { *OS << " addi sp, sp, -" << aligned_stack_size << "\n"; - // RV64中ra和s0都是8字节 *OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; *OS << " mv s0, sp\n"; } + + // 忠实还原保存函数入口参数的逻辑 + Function* F = MFunc->getFunc(); + if (F && F->getEntryBlock()) { + int arg_idx = 0; + RISCv64ISel* isel = MFunc->getISel(); + for (AllocaInst* alloca_for_param : F->getEntryBlock()->getArguments()) { + if (arg_idx >= 8) break; + + unsigned vreg = isel->getVReg(alloca_for_param); + if (frame_info.alloca_offsets.count(vreg)) { + int offset = frame_info.alloca_offsets.at(vreg); + auto arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); + *OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n"; + } + arg_idx++; + } + } } -void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { - int stack_size = mfunc->getFrameInfo().frame_size; - int aligned_stack_size = (stack_size + 15) & ~15; - +void RISCv64AsmPrinter::printEpilogue() { + int aligned_stack_size = MFunc->getFrameInfo().total_size; if (aligned_stack_size > 0) { *OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n"; @@ -46,129 +73,85 @@ void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { } } - void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) { - // 打印基本块标签 if (!mbb->getName().empty()) { *OS << mbb->getName() << ":\n"; } - - // 打印指令 for (auto& instr : mbb->getInstructions()) { - printInstruction(instr.get(), mbb); + printInstruction(instr.get()); } } -void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb) { - *OS << " "; // 指令缩进 - +void RISCv64AsmPrinter::printInstruction(MachineInstr* instr) { auto opcode = instr->getOpcode(); - - // RET指令需要特殊处理,在打印ret之前先打印函数尾声 if (opcode == RVOpcodes::RET) { - printEpilogue(parent_bb->getParent()); + printEpilogue(); } - - // 使用switch将Opcode转换为汇编助记符 + if (opcode != RVOpcodes::LABEL) { + *OS << " "; + } + switch (opcode) { - // Arithmatic - case RVOpcodes::ADD: *OS << "add "; break; - case RVOpcodes::ADDI: *OS << "addi "; break; - case RVOpcodes::ADDW: *OS << "addw "; break; - case RVOpcodes::ADDIW: *OS << "addiw "; break; - case RVOpcodes::SUB: *OS << "sub "; break; - case RVOpcodes::SUBW: *OS << "subw "; break; - case RVOpcodes::MUL: *OS << "mul "; break; - case RVOpcodes::MULW: *OS << "mulw "; break; - case RVOpcodes::DIV: *OS << "div "; break; - case RVOpcodes::DIVW: *OS << "divw "; break; - case RVOpcodes::REM: *OS << "rem "; break; - case RVOpcodes::REMW: *OS << "remw "; break; - // Logical - case RVOpcodes::XOR: *OS << "xor "; break; - case RVOpcodes::XORI: *OS << "xori "; break; - case RVOpcodes::OR: *OS << "or "; break; - case RVOpcodes::ORI: *OS << "ori "; break; - case RVOpcodes::AND: *OS << "and "; break; - case RVOpcodes::ANDI: *OS << "andi "; break; - // Shift - case RVOpcodes::SLL: *OS << "sll "; break; - case RVOpcodes::SLLI: *OS << "slli "; break; - case RVOpcodes::SLLW: *OS << "sllw "; break; - case RVOpcodes::SLLIW: *OS << "slliw "; break; - case RVOpcodes::SRL: *OS << "srl "; break; - case RVOpcodes::SRLI: *OS << "srli "; break; - case RVOpcodes::SRLW: *OS << "srlw "; break; - case RVOpcodes::SRLIW: *OS << "srliw "; break; - case RVOpcodes::SRA: *OS << "sra "; break; - case RVOpcodes::SRAI: *OS << "srai "; break; - case RVOpcodes::SRAW: *OS << "sraw "; break; - case RVOpcodes::SRAIW: *OS << "sraiw "; break; - // Compare - case RVOpcodes::SLT: *OS << "slt "; break; - case RVOpcodes::SLTI: *OS << "slti "; break; - case RVOpcodes::SLTU: *OS << "sltu "; break; - case RVOpcodes::SLTIU: *OS << "sltiu "; break; - // Memory - case RVOpcodes::LW: *OS << "lw "; break; - case RVOpcodes::LH: *OS << "lh "; break; - case RVOpcodes::LB: *OS << "lb "; break; - case RVOpcodes::LWU: *OS << "lwu "; break; - case RVOpcodes::LHU: *OS << "lhu "; break; - case RVOpcodes::LBU: *OS << "lbu "; break; - case RVOpcodes::SW: *OS << "sw "; break; - case RVOpcodes::SH: *OS << "sh "; break; - case RVOpcodes::SB: *OS << "sb "; break; - case RVOpcodes::LD: *OS << "ld "; break; + case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break; + case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break; + case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break; + case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break; + case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break; + case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::ORI: *OS << "ori "; break; + case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::ANDI: *OS << "andi "; break; + case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SLLI: *OS << "slli "; break; + case RVOpcodes::SLLW: *OS << "sllw "; break; case RVOpcodes::SLLIW: *OS << "slliw "; break; + case RVOpcodes::SRL: *OS << "srl "; break; case RVOpcodes::SRLI: *OS << "srli "; break; + case RVOpcodes::SRLW: *OS << "srlw "; break; case RVOpcodes::SRLIW: *OS << "srliw "; break; + case RVOpcodes::SRA: *OS << "sra "; break; case RVOpcodes::SRAI: *OS << "srai "; break; + case RVOpcodes::SRAW: *OS << "sraw "; break; case RVOpcodes::SRAIW: *OS << "sraiw "; break; + case RVOpcodes::SLT: *OS << "slt "; break; case RVOpcodes::SLTI: *OS << "slti "; break; + case RVOpcodes::SLTU: *OS << "sltu "; break; case RVOpcodes::SLTIU: *OS << "sltiu "; break; + case RVOpcodes::LW: *OS << "lw "; break; case RVOpcodes::LH: *OS << "lh "; break; + case RVOpcodes::LB: *OS << "lb "; break; case RVOpcodes::LWU: *OS << "lwu "; break; + case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break; + case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break; + case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break; case RVOpcodes::SD: *OS << "sd "; break; - // Control Flow - case RVOpcodes::J: *OS << "j "; break; - case RVOpcodes::JAL: *OS << "jal "; break; - case RVOpcodes::JALR: *OS << "jalr "; break; - case RVOpcodes::RET: *OS << "ret"; break; - case RVOpcodes::BEQ: *OS << "beq "; break; - case RVOpcodes::BNE: *OS << "bne "; break; - case RVOpcodes::BLT: *OS << "blt "; break; - case RVOpcodes::BGE: *OS << "bge "; break; - case RVOpcodes::BLTU: *OS << "bltu "; break; - case RVOpcodes::BGEU: *OS << "bgeu "; break; - // Pseudo-Instructions - case RVOpcodes::LI: *OS << "li "; break; - case RVOpcodes::LA: *OS << "la "; break; - case RVOpcodes::MV: *OS << "mv "; break; - case RVOpcodes::NEG: *OS << "neg "; break; - case RVOpcodes::NEGW: *OS << "negw "; break; - case RVOpcodes::SEQZ: *OS << "seqz "; break; + case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break; + case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break; + case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break; + case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::BGE: *OS << "bge "; break; + case RVOpcodes::BLTU: *OS << "bltu "; break; case RVOpcodes::BGEU: *OS << "bgeu "; break; + case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break; + case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break; + case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break; case RVOpcodes::SNEZ: *OS << "snez "; break; - // Call case RVOpcodes::CALL: *OS << "call "; break; - // Special case RVOpcodes::LABEL: - *OS << "\b\b\b\b"; printOperand(instr->getOperands()[0].get()); *OS << ":"; break; + case RVOpcodes::FRAME_LOAD: + case RVOpcodes::FRAME_STORE: + // These should have been eliminated by RegAlloc + throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); default: throw std::runtime_error("Unknown opcode in AsmPrinter"); } - // 打印操作数 const auto& operands = instr->getOperands(); - for (size_t i = 0; i < operands.size(); ++i) { - // 对于LW/SW, 操作数格式是 rd, offset(rs1) - if (opcode == RVOpcodes::LW || opcode == RVOpcodes::SW || opcode == RVOpcodes::LD || opcode == RVOpcodes::SD) { + if (!operands.empty()) { + if (isMemoryOp(opcode)) { printOperand(operands[0].get()); *OS << ", "; printOperand(operands[1].get()); - break; // LW/SW只有两个操作数部分 - } - - printOperand(operands[i].get()); - if (i < operands.size() - 1) { - *OS << ", "; + } else { + for (size_t i = 0; i < operands.size(); ++i) { + printOperand(operands[i].get()); + if (i < operands.size() - 1) { + *OS << ", "; + } + } } } - *OS << "\n"; } @@ -178,21 +161,18 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) { case MachineOperand::KIND_REG: { auto reg_op = static_cast(op); if (reg_op->isVirtual()) { - // 在这个阶段不应该再有虚拟寄存器了 *OS << "%vreg" << reg_op->getVRegNum(); } else { *OS << regToString(reg_op->getPReg()); } break; } - case MachineOperand::KIND_IMM: { + case MachineOperand::KIND_IMM: *OS << static_cast(op)->getValue(); break; - } - case MachineOperand::KIND_LABEL: { + case MachineOperand::KIND_LABEL: *OS << static_cast(op)->getName(); break; - } case MachineOperand::KIND_MEM: { auto mem_op = static_cast(op); printOperand(mem_op->getOffset()); @@ -204,41 +184,40 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) { } } -// 物理寄存器到字符串的转换 (从原RISCv64Backend.cpp迁移) std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) { switch (reg) { - case PhysicalReg::ZERO: return "x0"; - case PhysicalReg::RA: return "ra"; - case PhysicalReg::SP: return "sp"; - case PhysicalReg::GP: return "gp"; - case PhysicalReg::TP: return "tp"; - case PhysicalReg::T0: return "t0"; - case PhysicalReg::T1: return "t1"; - case PhysicalReg::T2: return "t2"; - case PhysicalReg::S0: return "s0"; - case PhysicalReg::S1: return "s1"; - case PhysicalReg::A0: return "a0"; - case PhysicalReg::A1: return "a1"; - case PhysicalReg::A2: return "a2"; - case PhysicalReg::A3: return "a3"; - case PhysicalReg::A4: return "a4"; - case PhysicalReg::A5: return "a5"; - case PhysicalReg::A6: return "a6"; - case PhysicalReg::A7: return "a7"; - case PhysicalReg::S2: return "s2"; - case PhysicalReg::S3: return "s3"; - case PhysicalReg::S4: return "s4"; - case PhysicalReg::S5: return "s5"; - case PhysicalReg::S6: return "s6"; - case PhysicalReg::S7: return "s7"; - case PhysicalReg::S8: return "s8"; - case PhysicalReg::S9: return "s9"; - case PhysicalReg::S10: return "s10"; - case PhysicalReg::S11: return "s11"; - case PhysicalReg::T3: return "t3"; - case PhysicalReg::T4: return "t4"; - case PhysicalReg::T5: return "t5"; - case PhysicalReg::T6: return "t6"; + case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra"; + case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp"; + case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0"; + case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2"; + case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1"; + case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1"; + case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3"; + case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5"; + case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7"; + case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3"; + case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5"; + case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7"; + case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9"; + case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11"; + case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4"; + case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6"; + case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1"; + case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3"; + case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5"; + case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7"; + case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9"; + case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11"; + case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13"; + case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15"; + case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17"; + case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19"; + case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21"; + case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23"; + case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25"; + case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27"; + case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29"; + case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31"; default: return "UNKNOWN_REG"; } } diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index c9fdead..b2a7da0 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -3,7 +3,6 @@ #include "RISCv64RegAlloc.h" #include "RISCv64AsmPrinter.h" #include -#include namespace sysy { @@ -12,13 +11,12 @@ std::string RISCv64CodeGen::code_gen() { return module_gen(); } -// module_gen 的逻辑基本不变,它负责处理.data段和驱动每个函数的生成 +// 模块级代码生成 (移植自原文件,处理.data段和驱动函数生成) std::string RISCv64CodeGen::module_gen() { std::stringstream ss; // 1. 处理全局变量 (.data段) - bool has_globals = !module->getGlobals().empty(); - if (has_globals) { + if (!module->getGlobals().empty()) { ss << ".data\n"; for (const auto& global : module->getGlobals()) { ss << ".globl " << global->getName() << "\n"; @@ -45,9 +43,9 @@ std::string RISCv64CodeGen::module_gen() { // 2. 处理函数 (.text段) if (!module->getFunctions().empty()) { ss << ".text\n"; - for (const auto& func : module->getFunctions()) { - if (func.second.get()) { - ss << function_gen(func.second.get()); + for (const auto& func_pair : module->getFunctions()) { + if (func_pair.second.get()) { + ss << function_gen(func_pair.second.get()); } } } @@ -56,31 +54,18 @@ std::string RISCv64CodeGen::module_gen() { // function_gen 现在是新的、模块化的处理流水线 std::string RISCv64CodeGen::function_gen(Function* func) { - - // === 新的、完整的流水线 === - // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) RISCv64ISel isel; std::unique_ptr mfunc = isel.runOnFunction(func); - // 阶段 2: 寄存器分配前优化 (未来扩展点) - // 例如: - // auto pre_ra_scheduler = std::make_unique(); - // pre_ra_scheduler->runOnMachineFunction(mfunc.get()); - - // 阶段 3: 物理寄存器分配 (virtual regs -> physical regs + spill code) + // 阶段 2: 寄存器分配 (包含栈帧布局, 活跃性分析, 图着色, spill/rewrite) RISCv64RegAlloc reg_alloc(mfunc.get()); reg_alloc.run(); - // 阶段 4: 寄存器分配后优化 (未来扩展点) - // 例如: - // auto post_ra_peephole = std::make_unique(); - // post_ra_peephole->runOnMachineFunction(mfunc.get()); - - // 阶段 5: 代码发射 (LLIR with physical regs -> Assembly Text) + // 阶段 3: 代码发射 (LLIR with physical regs -> Assembly Text) std::stringstream ss; - RISCv64AsmPrinter printer; - printer.runOnMachineFunction(mfunc.get(), ss); + RISCv64AsmPrinter printer(mfunc.get()); + printer.run(ss); return ss.str(); } diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp index bf56027..005ba96 100644 --- a/src/RISCv64ISel.cpp +++ b/src/RISCv64ISel.cpp @@ -1,22 +1,32 @@ #include "RISCv64ISel.h" #include -#include -#include #include +#include +#include // For std::fabs +#include // For std::numeric_limits namespace sysy { +// DAG节点定义 (内部实现) +struct RISCv64ISel::DAGNode { + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; + NodeKind kind; + Value* value = nullptr; + std::vector operands; + std::vector users; + DAGNode(NodeKind k) : kind(k) {} +}; + RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {} // 为一个IR Value获取或分配一个新的虚拟寄存器 unsigned RISCv64ISel::getVReg(Value* val) { - if (!val) { // 安全检查 + if (!val) { throw std::runtime_error("Cannot get vreg for a null Value."); } if (vreg_map.find(val) == vreg_map.end()) { if (vreg_counter == 0) { - // vreg 0 通常保留给物理寄存器x0(zero),我们从1开始分配 - vreg_counter = 1; + vreg_counter = 1; // vreg 0 保留 } vreg_map[val] = vreg_counter++; } @@ -27,7 +37,7 @@ unsigned RISCv64ISel::getVReg(Value* val) { std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { F = func; if (!F) return nullptr; - MFunc = std::make_unique(F->getName()); + MFunc = std::make_unique(F, this); vreg_map.clear(); bb_map.clear(); vreg_counter = 0; @@ -40,37 +50,28 @@ std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { // 指令选择主流程 void RISCv64ISel::select() { - // 1. 为所有基本块创建对应的MachineBasicBlock for (const auto& bb_ptr : F->getBasicBlocks()) { - BasicBlock* bb = bb_ptr.get(); - auto mbb = std::make_unique(bb->getName(), MFunc.get()); - bb_map[bb] = mbb.get(); + auto mbb = std::make_unique(bb_ptr->getName(), MFunc.get()); + bb_map[bb_ptr.get()] = mbb.get(); MFunc->addBlock(std::move(mbb)); } - // 2. 为函数参数创建虚拟寄存器 - // ====================== 已修正 ====================== - // 根据 IR.h, 参数列表存储在入口基本块中 if (F->getEntryBlock()) { for (auto* arg_alloca : F->getEntryBlock()->getArguments()) { getVReg(arg_alloca); } } - // ===================================================== - // 3. 遍历每个基本块,生成指令 for (const auto& bb_ptr : F->getBasicBlocks()) { selectBasicBlock(bb_ptr.get()); } - // 4. 设置基本块的前驱后继关系 for (const auto& bb_ptr : F->getBasicBlocks()) { - BasicBlock* bb = bb_ptr.get(); - CurMBB = bb_map.at(bb); - for (auto succ : bb->getSuccessors()) { + CurMBB = bb_map.at(bb_ptr.get()); + for (auto succ : bb_ptr->getSuccessors()) { CurMBB->successors.push_back(bb_map.at(succ)); } - for (auto pred : bb->getPredecessors()) { + for (auto pred : bb_ptr->getPredecessors()) { CurMBB->predecessors.push_back(bb_map.at(pred)); } } @@ -87,29 +88,23 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { value_to_node[node->value] = node.get(); } } - + std::set selected_nodes; std::function select_recursive = [&](DAGNode* node) { if (!node || selected_nodes.count(node)) return; - for (auto operand : node->operands) { select_recursive(operand); } - - // 只有当所有操作数都选择完毕后,才选择当前节点 selectNode(node); selected_nodes.insert(node); }; - // 按照IR指令的原始顺序来驱动指令选择 for (const auto& inst_ptr : bb->getInstructions()) { DAGNode* node_to_select = nullptr; - // 查找当前IR指令对应的DAG节点 if (value_to_node.count(inst_ptr.get())) { node_to_select = value_to_node.at(inst_ptr.get()); } else { - // 对于没有返回值的指令或某些特殊情况 for(const auto& node : dag) { if(node->value == inst_ptr.get()) { node_to_select = node.get(); @@ -123,88 +118,105 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { } } +// 核心函数:为DAG节点选择并生成MachineInstr (忠实移植版) void RISCv64ISel::selectNode(DAGNode* node) { - // 注意:不再生成字符串,而是创建MachineInstr对象并加入到CurMBB switch (node->kind) { case DAGNode::CONSTANT: case DAGNode::ALLOCA_ADDR: - // 这些节点本身不生成指令。使用它们的指令会按需处理。 - // 为Alloca地址分配一个vreg是必要的,代表地址。 if (node->value) getVReg(node->value); break; case DAGNode::LOAD: { - // lw rd, offset(base) auto dest_vreg = getVReg(node->value); - auto ptr_vreg = getVReg(node->operands[0]->value); + Value* ptr_val = node->operands[0]->value; - auto instr = std::make_unique(RVOpcodes::LW); - instr->addOperand(std::make_unique(dest_vreg)); - // 暂时生成0(ptr),后续pass会将其优化为 offset(s0) - instr->addOperand(std::make_unique( - std::make_unique(ptr_vreg), - std::make_unique(0) - )); - CurMBB->addInstruction(std::move(instr)); + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_LOAD); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } break; } case DAGNode::STORE: { - // sw rs2, offset(rs1) - // 先加载常量 - if (auto val_const = dynamic_cast(node->operands[0]->value)) { + Value* val_to_store = node->operands[0]->value; + Value* ptr_val = node->operands[1]->value; + + if (auto val_const = dynamic_cast(val_to_store)) { auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(getVReg(val_const))); li->addOperand(std::make_unique(val_const->getInt())); CurMBB->addInstruction(std::move(li)); } + auto val_vreg = getVReg(val_to_store); - auto val_vreg = getVReg(node->operands[0]->value); - auto ptr_vreg = getVReg(node->operands[1]->value); + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_STORE); + instr->addOperand(std::make_unique(val_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); - auto instr = std::make_unique(RVOpcodes::SW); - instr->addOperand(std::make_unique(val_vreg)); // value to store - instr->addOperand(std::make_unique( - std::make_unique(ptr_vreg), // base address - std::make_unique(0) // offset - )); - CurMBB->addInstruction(std::move(instr)); + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } break; } case DAGNode::BINARY: { auto bin = dynamic_cast(node->value); - if (!bin) break; - Value* lhs = bin->getLhs(); Value* rhs = bin->getRhs(); - // 检查是否为 addi 优化 - if (bin->getKind() == BinaryInst::kAdd) { - if (auto rhs_const = dynamic_cast(rhs)) { - if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { - auto instr = std::make_unique(RVOpcodes::ADDIW); - instr->addOperand(std::make_unique(getVReg(bin))); - instr->addOperand(std::make_unique(getVReg(lhs))); - instr->addOperand(std::make_unique(rhs_const->getInt())); - CurMBB->addInstruction(std::move(instr)); - return; // 指令已生成,提前返回 - } - } - } - - // 为操作数加载立即数或地址 auto load_val_if_const = [&](Value* val) { if (auto c = dynamic_cast(val)) { auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(getVReg(c))); li->addOperand(std::make_unique(c->getInt())); CurMBB->addInstruction(std::move(li)); - } else if (auto g = dynamic_cast(val)) { - auto la = std::make_unique(RVOpcodes::LA); - la->addOperand(std::make_unique(getVReg(g))); - la->addOperand(std::make_unique(g->getName())); - CurMBB->addInstruction(std::move(la)); } }; load_val_if_const(lhs); @@ -214,7 +226,19 @@ void RISCv64ISel::selectNode(DAGNode* node) { auto lhs_vreg = getVReg(lhs); auto rhs_vreg = getVReg(rhs); - // 生成二元运算指令 + if (bin->getKind() == BinaryInst::kAdd) { + if (auto rhs_const = dynamic_cast(rhs)) { + if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { + auto instr = std::make_unique(RVOpcodes::ADDIW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_const->getInt())); + CurMBB->addInstruction(std::move(instr)); + return; + } + } + } + switch (bin->getKind()) { case BinaryInst::kAdd: { RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW; @@ -294,16 +318,16 @@ void RISCv64ISel::selectNode(DAGNode* node) { case BinaryInst::kICmpGT: { auto instr = std::make_unique(RVOpcodes::SLT); instr->addOperand(std::make_unique(dest_vreg)); - instr->addOperand(std::make_unique(rhs_vreg)); // Swapped - instr->addOperand(std::make_unique(lhs_vreg)); // Swapped + instr->addOperand(std::make_unique(rhs_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); CurMBB->addInstruction(std::move(instr)); break; } case BinaryInst::kICmpLE: { auto slt = std::make_unique(RVOpcodes::SLT); slt->addOperand(std::make_unique(dest_vreg)); - slt->addOperand(std::make_unique(rhs_vreg)); // Swapped - slt->addOperand(std::make_unique(lhs_vreg)); // Swapped + slt->addOperand(std::make_unique(rhs_vreg)); + slt->addOperand(std::make_unique(lhs_vreg)); CurMBB->addInstruction(std::move(slt)); auto xori = std::make_unique(RVOpcodes::XORI); @@ -335,8 +359,6 @@ void RISCv64ISel::selectNode(DAGNode* node) { case DAGNode::UNARY: { auto unary = dynamic_cast(node->value); - if (!unary) break; - auto dest_vreg = getVReg(unary); auto src_vreg = getVReg(unary->getOperand()); @@ -344,7 +366,7 @@ void RISCv64ISel::selectNode(DAGNode* node) { case UnaryInst::kNeg: { auto instr = std::make_unique(RVOpcodes::SUBW); instr->addOperand(std::make_unique(dest_vreg)); - instr->addOperand(std::make_unique(PhysicalReg::ZERO)); // x0 + instr->addOperand(std::make_unique(PhysicalReg::ZERO)); instr->addOperand(std::make_unique(src_vreg)); CurMBB->addInstruction(std::move(instr)); break; @@ -364,51 +386,67 @@ void RISCv64ISel::selectNode(DAGNode* node) { case DAGNode::CALL: { auto call = dynamic_cast(node->value); - if (!call) break; + for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { + DAGNode* arg_node = node->operands[i]; + auto arg_preg = static_cast(static_cast(PhysicalReg::A0) + i); + + if (arg_node->kind == DAGNode::CONSTANT) { + if (auto const_val = dynamic_cast(arg_node->value)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(arg_preg)); + li->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li)); + } + } else { + auto src_vreg = getVReg(arg_node->value); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(arg_preg)); + mv->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(mv)); + } + } - // 在此阶段,我们只处理函数调用本身和返回值的移动 - // 参数的传递将在一个专门的 Calling Convention Pass 中处理 - auto call_instr = std::make_unique(RVOpcodes::CALL); call_instr->addOperand(std::make_unique(call->getCallee()->getName())); CurMBB->addInstruction(std::move(call_instr)); if (!call->getType()->isVoid()) { auto mv_instr = std::make_unique(RVOpcodes::MV); - mv_instr->addOperand(std::make_unique(getVReg(call))); // dest - mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); // src + mv_instr->addOperand(std::make_unique(getVReg(call))); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); CurMBB->addInstruction(std::move(mv_instr)); } break; } case DAGNode::RETURN: { - auto ret_inst = dynamic_cast(node->value); - if (ret_inst && ret_inst->hasReturnValue()) { - // 如果有返回值,生成一条mv指令将其放入a0 - auto mv_instr = std::make_unique(RVOpcodes::MV); - mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); - mv_instr->addOperand(std::make_unique(getVReg(ret_inst->getReturnValue()))); - CurMBB->addInstruction(std::move(mv_instr)); + auto ret_inst_ir = dynamic_cast(node->value); + if (ret_inst_ir && ret_inst_ir->hasReturnValue()) { + Value* ret_val = ret_inst_ir->getReturnValue(); + if (auto const_val = dynamic_cast(ret_val)) { + auto li_instr = std::make_unique(RVOpcodes::LI); + li_instr->addOperand(std::make_unique(PhysicalReg::A0)); + li_instr->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li_instr)); + } else { + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + mv_instr->addOperand(std::make_unique(getVReg(ret_val))); + CurMBB->addInstruction(std::move(mv_instr)); + } } - // 生成ret伪指令 - auto instr = std::make_unique(RVOpcodes::RET); - CurMBB->addInstruction(std::move(instr)); + auto ret_mi = std::make_unique(RVOpcodes::RET); + CurMBB->addInstruction(std::move(ret_mi)); break; } case DAGNode::BRANCH: { if (auto cond_br = dynamic_cast(node->value)) { - // bne cond, x0, then_block auto br_instr = std::make_unique(RVOpcodes::BNE); br_instr->addOperand(std::make_unique(getVReg(cond_br->getCondition()))); br_instr->addOperand(std::make_unique(PhysicalReg::ZERO)); br_instr->addOperand(std::make_unique(cond_br->getThenBlock()->getName())); CurMBB->addInstruction(std::move(br_instr)); - - // j else_block - // 注意:这里会产生一个fallthrough问题,后续的分支优化pass会解决它 - // 一个更健壮的生成方式是 bne -> j else; then: ...; else: ... } else if (auto uncond_br = dynamic_cast(node->value)) { auto j_instr = std::make_unique(RVOpcodes::J); j_instr->addOperand(std::make_unique(uncond_br->getBlock()->getName())); @@ -416,20 +454,16 @@ void RISCv64ISel::selectNode(DAGNode* node) { } break; } - case DAGNode::MEMSET: { - // 这是对原memset逻辑的完整LLIR翻译 - auto memset = dynamic_cast(node->value); - if (!memset) break; + case DAGNode::MEMSET: { + auto memset = dynamic_cast(node->value); auto r_dest_addr = getVReg(memset->getPointer()); auto r_num_bytes = getVReg(memset->getSize()); auto r_value_byte = getVReg(memset->getValue()); - - // 为临时值创建虚拟寄存器 - auto r_counter = vreg_counter++; - auto r_end_addr = vreg_counter++; - auto r_current_addr = vreg_counter++; - auto r_temp_val = vreg_counter++; + auto r_counter = getNewVReg(); + auto r_end_addr = getNewVReg(); + auto r_current_addr = getNewVReg(); + auto r_temp_val = getNewVReg(); auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { auto i = std::make_unique(op); @@ -470,12 +504,11 @@ void RISCv64ISel::selectNode(DAGNode* node) { }; int unique_id = this->local_label_counter++; - std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); - std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); - std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); - std::string done_label = "memset_done_" + std::to_string(unique_id); - - // 构造64位的填充值 + std::string loop_start_label = MFunc->getName() + "_memset_loop_start_" + std::to_string(unique_id); + std::string loop_end_label = MFunc->getName() + "_memset_loop_end_" + std::to_string(unique_id); + std::string remainder_label = MFunc->getName() + "_memset_remainder_" + std::to_string(unique_id); + std::string done_label = MFunc->getName() + "_memset_done_" + std::to_string(unique_id); + addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); @@ -483,8 +516,6 @@ void RISCv64ISel::selectNode(DAGNode* node) { add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); - - // 设置循环变量 add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes); auto mv = std::make_unique(RVOpcodes::MV); mv->addOperand(std::make_unique(r_current_addr)); @@ -492,16 +523,12 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(mv)); addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8); add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter); - - // 64位写入循环 label_instr(loop_start_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label); store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0); addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8); jump_instr(loop_start_label); label_instr(loop_end_label); - - // 剩余字节写入循环 label_instr(remainder_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label); store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0); @@ -512,13 +539,13 @@ void RISCv64ISel::selectNode(DAGNode* node) { } default: - throw std::runtime_error("Unsupported DAGNode kind in ISel: " + std::to_string(node->kind)); + throw std::runtime_error("Unsupported DAGNode kind in ISel"); } } - -// --- DAG构建函数 (从原RISCv64Backend.cpp几乎原样迁移, 保持不变) --- -RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { +// 以下是忠实移植的DAG构建函数 +RISCv64ISel::DAGNode* RISCv64ISel::create_node(int kind_int, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { + auto kind = static_cast(kind_int); if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) { return value_to_node[val]; } @@ -526,10 +553,7 @@ RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* va node->value = val; DAGNode* raw_node_ptr = node.get(); nodes_storage.push_back(std::move(node)); - // 只有产生值的节点才应该被记录,以备复用 - if (val && !val->getType()->isVoid() && dynamic_cast(val)) { - value_to_node[val] = raw_node_ptr; - } else if (val && dynamic_cast(val)) { + if (val && !val->getType()->isVoid() && (dynamic_cast(val) || dynamic_cast(val))) { value_to_node[val] = raw_node_ptr; } return raw_node_ptr; @@ -545,7 +569,6 @@ RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map(val_ir)) { return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage); } - // Fallback: Assume it needs to be loaded if not found (might be a parameter or a value from another block) return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); } @@ -567,12 +590,20 @@ std::vector> RISCv64ISel::build_dag(BasicB memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage)); - } - else if (auto load = dynamic_cast(inst)) { + } else if (auto load = dynamic_cast(inst)) { auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); } else if (auto bin = dynamic_cast(inst)) { if(value_to_node.count(bin)) continue; + if (bin->getKind() == BinaryInst::kSub) { + if (auto const_lhs = dynamic_cast(bin->getLhs())) { + if (const_lhs->getInt() == 0) { + auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage); + unary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); + continue; + } + } + } auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); @@ -580,8 +611,7 @@ std::vector> RISCv64ISel::build_dag(BasicB if(value_to_node.count(un)) continue; auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage); unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage)); - } - else if (auto call = dynamic_cast(inst)) { + } else if (auto call = dynamic_cast(inst)) { if(value_to_node.count(call)) continue; auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); for (auto arg : call->getArguments()) { diff --git a/src/RISCv64Passes.cpp b/src/RISCv64Passes.cpp new file mode 100644 index 0000000..40aff21 --- /dev/null +++ b/src/RISCv64Passes.cpp @@ -0,0 +1,8 @@ +// RISCv64Passes.cpp +#include "RISCv64Passes.h" + +namespace sysy { + +// 此处为未来优化Pass的实现 + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp index 4471868..2695f3b 100644 --- a/src/RISCv64RegAlloc.cpp +++ b/src/RISCv64RegAlloc.cpp @@ -1,11 +1,11 @@ #include "RISCv64RegAlloc.h" +#include "RISCv64ISel.h" #include #include namespace sysy { RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { - // 初始化可分配的整数寄存器池 (排除特殊用途的) allocable_int_regs = { PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, @@ -18,23 +18,113 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { } void RISCv64RegAlloc::run() { + eliminateFrameIndices(); analyzeLiveness(); buildInterferenceGraph(); colorGraph(); rewriteFunction(); } +void RISCv64RegAlloc::eliminateFrameIndices() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int current_offset = 0; + Function* F = MFunc->getFunc(); + RISCv64ISel* isel = MFunc->getISel(); + + for (auto& bb : F->getBasicBlocks()) { + for (auto& inst : bb->getInstructions()) { + if (auto alloca = dynamic_cast(inst.get())) { + int size = 4; + if (!alloca->getDims().empty()) { + int num_elements = 1; + for (const auto& dim_use : alloca->getDims()) { + if (auto const_dim = dynamic_cast(dim_use->getValue())) { + num_elements *= const_dim->getInt(); + } + } + size *= num_elements; + } + current_offset += size; + unsigned alloca_vreg = isel->getVReg(alloca); + frame_info.alloca_offsets[alloca_vreg] = -current_offset; + } + } + } + frame_info.locals_size = current_offset; + + for (auto& mbb : MFunc->getBlocks()) { + std::vector> new_instructions; + for (auto& instr_ptr : mbb->getInstructions()) { + if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) { + auto& operands = instr_ptr->getOperands(); + unsigned dest_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(lw)); + + } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) { + auto& operands = instr_ptr->getOperands(); + unsigned src_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(src_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(sw)); + } else { + new_instructions.push_back(std::move(instr_ptr)); + } + } + mbb->getInstructions() = std::move(new_instructions); + } +} + void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) { - // 这是一个简化的版本,实际需要根据RVOpcodes精确定义 - // 通常第一个RegOperand是def,其余是use bool is_def = true; + auto opcode = instr->getOpcode(); + + // 预定义def和use规则 + if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || + opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE || + opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE || + opcode == RVOpcodes::RET || opcode == RVOpcodes::J) { + is_def = false; + } + if (opcode == RVOpcodes::CALL) { + // CALL会杀死所有调用者保存寄存器,这是一个简化处理 + // 同时也使用了传入a0-a7的参数 + } + for (const auto& op : instr->getOperands()) { if (op->getKind() == MachineOperand::KIND_REG) { auto reg_op = static_cast(op.get()); if (reg_op->isVirtual()) { if (is_def) { def.insert(reg_op->getVRegNum()); - is_def = false; // 假设每条指令最多一个def + is_def = false; } else { use.insert(reg_op->getVRegNum()); } @@ -46,35 +136,16 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& } } } - - // 特殊处理store和branch指令,它们没有显式的def - auto opcode = instr->getOpcode(); - if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::BNE || opcode == RVOpcodes::BEQ) { - def.clear(); // 清空错误的def - use.clear(); - for (const auto& op : instr->getOperands()) { - if (op->getKind() == MachineOperand::KIND_REG) { - auto reg_op = static_cast(op.get()); - if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum()); - } else if (op->getKind() == MachineOperand::KIND_MEM) { - auto mem_op = static_cast(op.get()); - if(mem_op->getBase()->isVirtual()) use.insert(mem_op->getBase()->getVRegNum()); - } - } - } } - void RISCv64RegAlloc::analyzeLiveness() { bool changed = true; while (changed) { changed = false; - // 逆序遍历基本块 for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { auto& mbb = *it; LiveSet live_out; for (auto succ : mbb->successors) { - // live_out[B] = Union(live_in[S]) for all S in succ(B) if (!succ->getInstructions().empty()) { auto first_instr = succ->getInstructions().front().get(); if (live_in_map.count(first_instr)) { @@ -83,19 +154,14 @@ void RISCv64RegAlloc::analyzeLiveness() { } } - // 逆序遍历指令 for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) { MachineInstr* instr = instr_it->get(); LiveSet old_live_in = live_in_map[instr]; - LiveSet old_live_out = live_out_map[instr]; - - // 更新 live_out live_out_map[instr] = live_out; LiveSet use, def; getInstrUseDef(instr, use, def); - // live_in[i] = use[i] U (live_out[i] - def[i]) LiveSet live_in = use; LiveSet diff = live_out; for (auto vreg : def) { @@ -104,10 +170,9 @@ void RISCv64RegAlloc::analyzeLiveness() { live_in.insert(diff.begin(), diff.end()); live_in_map[instr] = live_in; - // 为下一次迭代准备live_out live_out = live_in; - if (live_in_map[instr] != old_live_in || live_out_map[instr] != old_live_out) { + if (live_in_map[instr] != old_live_in) { changed = true; } } @@ -117,21 +182,21 @@ void RISCv64RegAlloc::analyzeLiveness() { void RISCv64RegAlloc::buildInterferenceGraph() { std::set all_vregs; - // 收集所有虚拟寄存器 - for (auto const& [instr, live_set] : live_out_map) { - all_vregs.insert(live_set.begin(), live_set.end()); + for (auto& mbb : MFunc->getBlocks()) { + for(auto& instr : mbb->getInstructions()) { + LiveSet use, def; + getInstrUseDef(instr.get(), use, def); + for(auto u : use) all_vregs.insert(u); + for(auto d : def) all_vregs.insert(d); + } } - // 初始化图 - for (auto vreg : all_vregs) { - interference_graph[vreg] = {}; - } + for (auto vreg : all_vregs) { interference_graph[vreg] = {}; } for (auto& mbb : MFunc->getBlocks()) { for (auto& instr : mbb->getInstructions()) { LiveSet def, use; getInstrUseDef(instr.get(), use, def); - const LiveSet& live_out = live_out_map.at(instr.get()); for (unsigned d : def) { @@ -152,21 +217,18 @@ void RISCv64RegAlloc::colorGraph() { sorted_vregs.push_back(vreg); } - // 按度数降序排序 (简单贪心策略) std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) { return interference_graph[a].size() > interference_graph[b].size(); }); for (unsigned vreg : sorted_vregs) { std::set used_colors; - // 查找邻居已用的颜色 for (unsigned neighbor : interference_graph.at(vreg)) { if (color_map.count(neighbor)) { used_colors.insert(color_map.at(neighbor)); } } - // 寻找一个可用的颜色 bool colored = false; for (PhysicalReg preg : allocable_int_regs) { if (used_colors.find(preg) == used_colors.end()) { @@ -175,54 +237,47 @@ void RISCv64RegAlloc::colorGraph() { break; } } - if (!colored) { - // 无法分配,需要溢出 spilled_vregs.insert(vreg); } } } void RISCv64RegAlloc::rewriteFunction() { - // 1. 为所有溢出的vreg分配栈槽 StackFrameInfo& frame_info = MFunc->getFrameInfo(); - int current_offset = frame_info.frame_size; // 假设从现有栈大小后开始分配 + int current_offset = frame_info.locals_size; for (unsigned vreg : spilled_vregs) { - current_offset += 4; // 假设所有溢出变量都占4字节 - frame_info.spill_slots[vreg] = -current_offset; // 栈向下增长,所以是负偏移 + current_offset += 4; + frame_info.spill_offsets[vreg] = -current_offset; } - frame_info.frame_size = current_offset; + frame_info.spill_size = current_offset - frame_info.locals_size; - // 2. 遍历所有指令,替换vreg并插入spill代码 for (auto& mbb : MFunc->getBlocks()) { std::vector> new_instructions; for (auto& instr_ptr : mbb->getInstructions()) { LiveSet use, def; getInstrUseDef(instr_ptr.get(), use, def); - // 为use的溢出变量插入LOAD for (unsigned vreg : use) { if (spilled_vregs.count(vreg)) { - int offset = frame_info.spill_slots.at(vreg); + int offset = frame_info.spill_offsets.at(vreg); auto load = std::make_unique(RVOpcodes::LW); - load->addOperand(std::make_unique(vreg)); // 临时用vreg号代表,稍后替换 + load->addOperand(std::make_unique(vreg)); load->addOperand(std::make_unique( - std::make_unique(PhysicalReg::S0), // 基址用帧指针s0 + std::make_unique(PhysicalReg::S0), std::make_unique(offset) )); new_instructions.push_back(std::move(load)); } } - // 添加原始指令 new_instructions.push_back(std::move(instr_ptr)); - // 为def的溢出变量插入STORE for (unsigned vreg : def) { if (spilled_vregs.count(vreg)) { - int offset = frame_info.spill_slots.at(vreg); + int offset = frame_info.spill_offsets.at(vreg); auto store = std::make_unique(RVOpcodes::SW); - store->addOperand(std::make_unique(vreg)); // 临时用vreg号代表 + store->addOperand(std::make_unique(vreg)); store->addOperand(std::make_unique( std::make_unique(PhysicalReg::S0), std::make_unique(offset) @@ -234,7 +289,6 @@ void RISCv64RegAlloc::rewriteFunction() { mbb->getInstructions() = std::move(new_instructions); } - // 3. 最后一遍扫描,将所有RegOperand从vreg替换为preg for (auto& mbb : MFunc->getBlocks()) { for (auto& instr_ptr : mbb->getInstructions()) { for (auto& op_ptr : instr_ptr->getOperands()) { @@ -245,8 +299,7 @@ void RISCv64RegAlloc::rewriteFunction() { if (color_map.count(vreg)) { reg_op->setPReg(color_map.at(vreg)); } else if (spilled_vregs.count(vreg)) { - // 对于spill的vreg, 使用一个固定的临时寄存器, 比如t6 - reg_op->setPReg(PhysicalReg::T6); + reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6 } } } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { @@ -254,7 +307,11 @@ void RISCv64RegAlloc::rewriteFunction() { auto base_reg_op = mem_op->getBase(); if(base_reg_op->isVirtual()){ unsigned vreg = base_reg_op->getVRegNum(); - if(color_map.count(vreg)) base_reg_op->setPReg(color_map.at(vreg)); + if(color_map.count(vreg)) { + base_reg_op->setPReg(color_map.at(vreg)); + } else if (spilled_vregs.count(vreg)) { + base_reg_op->setPReg(PhysicalReg::T6); + } } } } diff --git a/src/include/RISCv64AsmPrinter.h b/src/include/RISCv64AsmPrinter.h index 7df35f6..3ea71f6 100644 --- a/src/include/RISCv64AsmPrinter.h +++ b/src/include/RISCv64AsmPrinter.h @@ -8,29 +8,23 @@ namespace sysy { class RISCv64AsmPrinter { public: - // 主入口,将整个MachineFunction打印到指定的输出流 - void runOnMachineFunction(MachineFunction* mfunc, std::ostream& os); + RISCv64AsmPrinter(MachineFunction* mfunc); + // 主入口 + void run(std::ostream& os); private: - // 打印单个基本块 + // 打印各个部分 + void printPrologue(); + void printEpilogue(); void printBasicBlock(MachineBasicBlock* mbb); - - // 打印单条指令 - void printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb); - - // 打印函数序言 - void printPrologue(MachineFunction* mfunc); + void printInstruction(MachineInstr* instr); - // 打印函数尾声 - void printEpilogue(MachineFunction* mfunc); - - // 将物理寄存器枚举转换为字符串 (从原RISCv64Backend迁移) + // 辅助函数 std::string regToString(PhysicalReg reg); - - // 打印单个操作数 void printOperand(MachineOperand* op); - std::ostream* OS; // 指向当前输出流 + MachineFunction* MFunc; + std::ostream* OS; }; } // namespace sysy diff --git a/src/include/RISCv64Backend.h b/src/include/RISCv64Backend.h index f929007..33f7831 100644 --- a/src/include/RISCv64Backend.h +++ b/src/include/RISCv64Backend.h @@ -1,10 +1,8 @@ #ifndef RISCV64_BACKEND_H #define RISCV64_BACKEND_H -#include "IR.h" // 只需包含高层IR定义 +#include "IR.h" #include -#include -#include namespace sysy { @@ -12,14 +10,12 @@ namespace sysy { class RISCv64CodeGen { public: RISCv64CodeGen(Module* mod) : module(mod) {} - // 唯一的公共入口点 std::string code_gen(); private: - // 模块级代码生成 (处理全局变量和驱动函数生成) + // 模块级代码生成 std::string module_gen(); - // 函数级代码生成 (实现新的流水线) std::string function_gen(Function* func); diff --git a/src/include/RISCv64ISel.h b/src/include/RISCv64ISel.h index e122ee0..795b2b8 100644 --- a/src/include/RISCv64ISel.h +++ b/src/include/RISCv64ISel.h @@ -1,10 +1,7 @@ #ifndef RISCV64_ISEL_H #define RISCV64_ISEL_H -#include "IR.h" #include "RISCv64LLIR.h" -#include -#include namespace sysy { @@ -14,43 +11,34 @@ public: // 模块主入口:将一个高层IR函数转换为底层LLIR函数 std::unique_ptr runOnFunction(Function* func); + // 公开接口,以便后续模块(如RegAlloc)可以查询或创建vreg + unsigned getVReg(Value* val); + unsigned getNewVReg() { return vreg_counter++; } + private: // DAG节点定义,作为ISel的内部实现细节 - struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; - NodeKind kind; - Value* value = nullptr; - std::vector operands; - DAGNode(NodeKind k) : kind(k) {} - }; - - // 为当前函数生成LLIR + struct DAGNode; + + // 指令选择主流程 void select(); - // 为单个基本块生成指令 void selectBasicBlock(BasicBlock* bb); - // 核心函数:为DAG节点选择并生成MachineInstr void selectNode(DAGNode* node); - // --- DAG 构建相关函数 (从原RISCv64Backend迁移) --- + // DAG 构建相关函数 (从原RISCv64Backend迁移) std::vector> build_dag(BasicBlock* bb); - DAGNode* get_operand_node(Value* val_ir, std::map& value_to_node, std::vector>& nodes_storage); - DAGNode* create_node(DAGNode::NodeKind kind, Value* val, std::map& value_to_node, std::vector>& nodes_storage); - - // --- 辅助函数 --- - // 为一个IR Value获取/分配一个虚拟寄存器号 - unsigned getVReg(Value* val); + DAGNode* get_operand_node(Value* val_ir, std::map&, std::vector>&); + DAGNode* create_node(int kind, Value* val, std::map&, std::vector>&); + // 状态 Function* F; // 当前处理的高层IR函数 std::unique_ptr MFunc; // 正在构建的底层LLIR函数 - MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块 // 映射关系 std::map vreg_map; std::map bb_map; - std::map value_to_node_map; // 用于selectNode中查找 unsigned vreg_counter; int local_label_counter; diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h index c1df3bf..6310741 100644 --- a/src/include/RISCv64LLIR.h +++ b/src/include/RISCv64LLIR.h @@ -1,66 +1,51 @@ #ifndef RISCV64_LLIR_H #define RISCV64_LLIR_H +#include "IR.h" // 确保包含了您自己的IR头文件 #include #include #include #include #include +// 前向声明,避免循环引用 +namespace sysy { +class Function; +class RISCv64ISel; +} + namespace sysy { -// 物理寄存器定义 (从 RISCv64Backend.h 移至此) +// 物理寄存器定义 enum class PhysicalReg { ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 }; - // RISC-V 指令操作码枚举 enum class RVOpcodes { // 算术指令 - ADD, ADDI, ADDW, ADDIW, - SUB, SUBW, - MUL, MULW, - DIV, DIVW, - REM, REMW, - + ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW, // 逻辑指令 - XOR, XORI, - OR, ORI, - AND, ANDI, - + XOR, XORI, OR, ORI, AND, ANDI, // 移位指令 - SLL, SLLI, SLLW, SLLIW, - SRL, SRLI, SRLW, SRLIW, - SRA, SRAI, SRAW, SRAIW, - + SLL, SLLI, SLLW, SLLIW, SRL, SRLI, SRLW, SRLIW, SRA, SRAI, SRAW, SRAIW, // 比较指令 SLT, SLTI, SLTU, SLTIU, - // 内存访问指令 - LW, LH, LB, LWU, LHU, LBU, - SW, SH, SB, - LD, SD, // 64位 - + LW, LH, LB, LWU, LHU, LBU, SW, SH, SB, LD, SD, // 控制流指令 - J, JAL, JALR, RET, // RET 是 JALR x0, 0(ra) 的伪指令 + J, JAL, JALR, RET, BEQ, BNE, BLT, BGE, BLTU, BGEU, - - // 伪指令 (方便指令选择) - LI, // Load Immediate - LA, // Load Address - MV, // Move register - NEG, // Negate - NEGW, // Negate Word - SEQZ, // Set if Equal to Zero - SNEZ, // Set if Not Equal to Zero - + // 伪指令 + LI, LA, MV, NEG, NEGW, SEQZ, SNEZ, // 函数调用 CALL, - // 特殊标记,非指令 - LABEL, // 用于表示一个标签位置 + LABEL, + // 新增伪指令,用于解耦栈帧处理 + FRAME_LOAD, // 从栈帧加载 (AllocaInst) + FRAME_STORE, // 保存到栈帧 (AllocaInst) }; class MachineOperand; @@ -72,22 +57,13 @@ class MachineInstr; class MachineBasicBlock; class MachineFunction; -// --- 操作数定义 --- - // 操作数基类 class MachineOperand { public: - enum OperandKind { - KIND_REG, - KIND_IMM, - KIND_LABEL, - KIND_MEM - }; - + enum OperandKind { KIND_REG, KIND_IMM, KIND_LABEL, KIND_MEM }; MachineOperand(OperandKind kind) : kind(kind) {} virtual ~MachineOperand() = default; OperandKind getKind() const { return kind; } - private: OperandKind kind; }; @@ -111,7 +87,6 @@ public: preg = new_preg; is_virtual = false; } - private: unsigned vreg_num = 0; PhysicalReg preg = PhysicalReg::ZERO; @@ -121,9 +96,7 @@ private: // 立即数操作数 class ImmOperand : public MachineOperand { public: - ImmOperand(int64_t value) - : MachineOperand(KIND_IMM), value(value) {} - + ImmOperand(int64_t value) : MachineOperand(KIND_IMM), value(value) {} int64_t getValue() const { return value; } private: int64_t value; @@ -132,9 +105,7 @@ private: // 标签操作数 class LabelOperand : public MachineOperand { public: - LabelOperand(const std::string& name) - : MachineOperand(KIND_LABEL), name(name) {} - + LabelOperand(const std::string& name) : MachineOperand(KIND_LABEL), name(name) {} const std::string& getName() const { return name; } private: std::string name; @@ -145,33 +116,25 @@ class MemOperand : public MachineOperand { public: MemOperand(std::unique_ptr base, std::unique_ptr offset) : MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {} - RegOperand* getBase() const { return base.get(); } ImmOperand* getOffset() const { return offset.get(); } - private: std::unique_ptr base; std::unique_ptr offset; }; - -// --- 组织结构定义 --- - // 机器指令 class MachineInstr { public: MachineInstr(RVOpcodes opcode) : opcode(opcode) {} RVOpcodes getOpcode() const { return opcode; } - // 注意:返回const引用,因为通常不直接修改指令的操作数列表 const std::vector>& getOperands() const { return operands; } - // 提供一个非const版本,用于内部修改 std::vector>& getOperands() { return operands; } void addOperand(std::unique_ptr operand) { operands.push_back(std::move(operand)); } - private: RVOpcodes opcode; std::vector> operands; @@ -185,8 +148,6 @@ public: const std::string& getName() const { return name; } MachineFunction* getParent() const { return parent; } - - // 同时提供 const 和 non-const 版本 const std::vector>& getInstructions() const { return instructions; } std::vector>& getInstructions() { return instructions; } @@ -196,43 +157,44 @@ public: std::vector successors; std::vector predecessors; - private: std::string name; std::vector> instructions; - MachineFunction* parent; // 指向所属函数 + MachineFunction* parent; }; // 栈帧信息 struct StackFrameInfo { - int frame_size = 0; - std::map spill_slots; // <虚拟寄存器号, 栈偏移> - // ... 未来可以添加更多信息 + int locals_size = 0; // 仅为AllocaInst分配的大小 + int spill_size = 0; // 仅为溢出分配的大小 + int total_size = 0; // 总大小 + std::map alloca_offsets; // + std::map spill_offsets; // <溢出vreg, 栈偏移> }; // 机器函数 class MachineFunction { public: - MachineFunction(const std::string& name) : name(name) {} + MachineFunction(Function* func, RISCv64ISel* isel) : F(func), name(func->getName()), isel(isel) {} + Function* getFunc() const { return F; } + RISCv64ISel* getISel() const { return isel; } const std::string& getName() const { return name; } StackFrameInfo& getFrameInfo() { return frame_info; } - - // 同时提供 const 和 non-const 版本 const std::vector>& getBlocks() const { return blocks; } std::vector>& getBlocks() { return blocks; } void addBlock(std::unique_ptr block) { blocks.push_back(std::move(block)); } - private: + Function* F; + RISCv64ISel* isel; // 指向创建它的ISel,用于获取vreg映射等信息 std::string name; std::vector> blocks; StackFrameInfo frame_info; }; - } // namespace sysy #endif // RISCV64_LLIR_H \ No newline at end of file diff --git a/src/include/RISCv64Passes.h b/src/include/RISCv64Passes.h new file mode 100644 index 0000000..3a4bcd1 --- /dev/null +++ b/src/include/RISCv64Passes.h @@ -0,0 +1,18 @@ +// RISCv64Passes.h +#ifndef RISCV64_PASSES_H +#define RISCV64_PASSES_H + +#include "RISCv64LLIR.h" + +namespace sysy { + +// 此处为未来优化Pass的基类或独立类定义 +// 例如: +// class PeepholeOptimizer { +// public: +// void runOnMachineFunction(MachineFunction* mfunc); +// }; + +} // namespace sysy + +#endif // RISCV64_PASSES_H \ No newline at end of file diff --git a/src/include/RISCv64RegAlloc.h b/src/include/RISCv64RegAlloc.h index ee8ab16..c786bde 100644 --- a/src/include/RISCv64RegAlloc.h +++ b/src/include/RISCv64RegAlloc.h @@ -2,9 +2,6 @@ #define RISCV64_REGALLOC_H #include "RISCv64LLIR.h" -#include -#include -#include namespace sysy { @@ -19,6 +16,9 @@ private: using LiveSet = std::set; // 活跃虚拟寄存器集合 using InterferenceGraph = std::map>; + // 栈帧管理 + void eliminateFrameIndices(); + // 活跃性分析 void analyzeLiveness(); @@ -28,7 +28,7 @@ private: // 图着色分配寄存器 void colorGraph(); - // 重写函数,将虚拟寄存器替换为物理寄存器,并插入溢出代码 + // 重写函数,替换vreg并插入溢出代码 void rewriteFunction(); // 辅助函数,获取指令的Use/Def集合 @@ -37,8 +37,8 @@ private: MachineFunction* MFunc; // 活跃性分析结果 - std::map live_in_map; - std::map live_out_map; + std::map live_in_map; + std::map live_out_map; // 干扰图 InterferenceGraph interference_graph; @@ -49,7 +49,6 @@ private: // 可用的物理寄存器池 std::vector allocable_int_regs; - std::vector allocable_float_regs; // (为未来浮点支持预留) }; } // namespace sysy