diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c67a64e..b5053bc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,6 +28,9 @@ add_executable(sysyc # Mem2Reg.cpp # Reg2Mem.cpp RISCv64Backend.cpp + RISCv64ISel.cpp + RISCv64RegAlloc.cpp + RISCv64AsmPrinter.cpp ) # 设置 include 路径,包含 ANTLR 运行时库和项目头文件 diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp new file mode 100644 index 0000000..0ad1c81 --- /dev/null +++ b/src/RISCv64AsmPrinter.cpp @@ -0,0 +1,225 @@ +#include "RISCv64AsmPrinter.h" +#include "RISCv64ISel.h" +#include + +namespace sysy { + +// 检查是否为内存加载/存储指令,以处理特殊的打印格式 +bool isMemoryOp(RVOpcodes opcode) { + switch (opcode) { + case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD: + case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU: + case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD: + return true; + default: + return false; + } +} + +RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {} + +void RISCv64AsmPrinter::run(std::ostream& os) { + OS = &os; + + *OS << ".globl " << MFunc->getName() << "\n"; + *OS << MFunc->getName() << ":\n"; + + printPrologue(); + + for (auto& mbb : MFunc->getBlocks()) { + printBasicBlock(mbb.get()); + } +} + +void RISCv64AsmPrinter::printPrologue() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + // 序言需要为保存ra和s0预留16字节 + int total_stack_size = frame_info.locals_size + frame_info.spill_size + 16; + int aligned_stack_size = (total_stack_size + 15) & ~15; + frame_info.total_size = aligned_stack_size; + + if (aligned_stack_size > 0) { + *OS << " addi sp, sp, -" << aligned_stack_size << "\n"; + *OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; + *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; + *OS << " mv s0, sp\n"; + } + + // 忠实还原保存函数入口参数的逻辑 + Function* F = MFunc->getFunc(); + if (F && F->getEntryBlock()) { + int arg_idx = 0; + RISCv64ISel* isel = MFunc->getISel(); + for (AllocaInst* alloca_for_param : F->getEntryBlock()->getArguments()) { + if (arg_idx >= 8) break; + + unsigned vreg = isel->getVReg(alloca_for_param); + if (frame_info.alloca_offsets.count(vreg)) { + int offset = frame_info.alloca_offsets.at(vreg); + auto arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); + *OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n"; + } + arg_idx++; + } + } +} + +void RISCv64AsmPrinter::printEpilogue() { + int aligned_stack_size = MFunc->getFrameInfo().total_size; + if (aligned_stack_size > 0) { + *OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n"; + *OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n"; + *OS << " addi sp, sp, " << aligned_stack_size << "\n"; + } +} + +void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) { + if (!mbb->getName().empty()) { + *OS << mbb->getName() << ":\n"; + } + for (auto& instr : mbb->getInstructions()) { + printInstruction(instr.get()); + } +} + +void RISCv64AsmPrinter::printInstruction(MachineInstr* instr) { + auto opcode = instr->getOpcode(); + if (opcode == RVOpcodes::RET) { + printEpilogue(); + } + if (opcode != RVOpcodes::LABEL) { + *OS << " "; + } + + switch (opcode) { + case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break; + case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break; + case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break; + case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break; + case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break; + case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break; + case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break; + case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::ORI: *OS << "ori "; break; + case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::ANDI: *OS << "andi "; break; + case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SLLI: *OS << "slli "; break; + case RVOpcodes::SLLW: *OS << "sllw "; break; case RVOpcodes::SLLIW: *OS << "slliw "; break; + case RVOpcodes::SRL: *OS << "srl "; break; case RVOpcodes::SRLI: *OS << "srli "; break; + case RVOpcodes::SRLW: *OS << "srlw "; break; case RVOpcodes::SRLIW: *OS << "srliw "; break; + case RVOpcodes::SRA: *OS << "sra "; break; case RVOpcodes::SRAI: *OS << "srai "; break; + case RVOpcodes::SRAW: *OS << "sraw "; break; case RVOpcodes::SRAIW: *OS << "sraiw "; break; + case RVOpcodes::SLT: *OS << "slt "; break; case RVOpcodes::SLTI: *OS << "slti "; break; + case RVOpcodes::SLTU: *OS << "sltu "; break; case RVOpcodes::SLTIU: *OS << "sltiu "; break; + case RVOpcodes::LW: *OS << "lw "; break; case RVOpcodes::LH: *OS << "lh "; break; + case RVOpcodes::LB: *OS << "lb "; break; case RVOpcodes::LWU: *OS << "lwu "; break; + case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break; + case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break; + case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break; + case RVOpcodes::SD: *OS << "sd "; break; + case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break; + case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break; + case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break; + case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::BGE: *OS << "bge "; break; + case RVOpcodes::BLTU: *OS << "bltu "; break; case RVOpcodes::BGEU: *OS << "bgeu "; break; + case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break; + case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break; + case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break; + case RVOpcodes::SNEZ: *OS << "snez "; break; + case RVOpcodes::CALL: *OS << "call "; break; + case RVOpcodes::LABEL: + printOperand(instr->getOperands()[0].get()); + *OS << ":"; + break; + case RVOpcodes::FRAME_LOAD: + case RVOpcodes::FRAME_STORE: + // These should have been eliminated by RegAlloc + throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); + default: + throw std::runtime_error("Unknown opcode in AsmPrinter"); + } + + const auto& operands = instr->getOperands(); + if (!operands.empty()) { + if (isMemoryOp(opcode)) { + printOperand(operands[0].get()); + *OS << ", "; + printOperand(operands[1].get()); + } else { + for (size_t i = 0; i < operands.size(); ++i) { + printOperand(operands[i].get()); + if (i < operands.size() - 1) { + *OS << ", "; + } + } + } + } + *OS << "\n"; +} + +void RISCv64AsmPrinter::printOperand(MachineOperand* op) { + if (!op) return; + switch(op->getKind()) { + case MachineOperand::KIND_REG: { + auto reg_op = static_cast(op); + if (reg_op->isVirtual()) { + *OS << "%vreg" << reg_op->getVRegNum(); + } else { + *OS << regToString(reg_op->getPReg()); + } + break; + } + case MachineOperand::KIND_IMM: + *OS << static_cast(op)->getValue(); + break; + case MachineOperand::KIND_LABEL: + *OS << static_cast(op)->getName(); + break; + case MachineOperand::KIND_MEM: { + auto mem_op = static_cast(op); + printOperand(mem_op->getOffset()); + *OS << "("; + printOperand(mem_op->getBase()); + *OS << ")"; + break; + } + } +} + +std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) { + switch (reg) { + case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra"; + case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp"; + case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0"; + case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2"; + case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1"; + case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1"; + case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3"; + case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5"; + case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7"; + case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3"; + case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5"; + case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7"; + case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9"; + case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11"; + case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4"; + case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6"; + case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1"; + case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3"; + case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5"; + case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7"; + case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9"; + case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11"; + case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13"; + case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15"; + case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17"; + case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19"; + case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21"; + case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23"; + case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25"; + case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27"; + case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29"; + case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31"; + default: return "UNKNOWN_REG"; + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index 29b0bfd..b2a7da0 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -1,122 +1,26 @@ #include "RISCv64Backend.h" +#include "RISCv64ISel.h" +#include "RISCv64RegAlloc.h" +#include "RISCv64AsmPrinter.h" #include -#include -#include -#include -#include -#include -namespace sysy { +namespace sysy { -// 可用于分配的寄存器 -const std::vector RISCv64CodeGen::allocable_regs = { - // 整数寄存器 - PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, - PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, - PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, - PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, - PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, - PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, - PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11, - // 浮点寄存器 - PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, - PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7, - PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F10, PhysicalReg::F11, - PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, - PhysicalReg::F16, PhysicalReg::F17, PhysicalReg::F18, PhysicalReg::F19, - PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22, PhysicalReg::F23, - PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27, - PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31 -}; - -// 将物理寄存器枚举转换为字符串 -std::string RISCv64CodeGen::reg_to_string(PhysicalReg reg) { - switch (reg) { - case PhysicalReg::ZERO: return "x0"; - case PhysicalReg::RA: return "ra"; - case PhysicalReg::SP: return "sp"; - case PhysicalReg::GP: return "gp"; - case PhysicalReg::TP: return "tp"; - case PhysicalReg::T0: return "t0"; - case PhysicalReg::T1: return "t1"; - case PhysicalReg::T2: return "t2"; - case PhysicalReg::S0: return "s0"; - case PhysicalReg::S1: return "s1"; - case PhysicalReg::A0: return "a0"; - case PhysicalReg::A1: return "a1"; - case PhysicalReg::A2: return "a2"; - case PhysicalReg::A3: return "a3"; - case PhysicalReg::A4: return "a4"; - case PhysicalReg::A5: return "a5"; - case PhysicalReg::A6: return "a6"; - case PhysicalReg::A7: return "a7"; - case PhysicalReg::S2: return "s2"; - case PhysicalReg::S3: return "s3"; - case PhysicalReg::S4: return "s4"; - case PhysicalReg::S5: return "s5"; - case PhysicalReg::S6: return "s6"; - case PhysicalReg::S7: return "s7"; - case PhysicalReg::S8: return "s8"; - case PhysicalReg::S9: return "s9"; - case PhysicalReg::S10: return "s10"; - case PhysicalReg::S11: return "s11"; - case PhysicalReg::T3: return "t3"; - case PhysicalReg::T4: return "t4"; - case PhysicalReg::T5: return "t5"; - case PhysicalReg::T6: return "t6"; - // 浮点寄存器 - case PhysicalReg::F0: return "f0"; - case PhysicalReg::F1: return "f1"; - case PhysicalReg::F2: return "f2"; - case PhysicalReg::F3: return "f3"; - case PhysicalReg::F4: return "f4"; - case PhysicalReg::F5: return "f5"; - case PhysicalReg::F6: return "f6"; - case PhysicalReg::F7: return "f7"; - case PhysicalReg::F8: return "f8"; - case PhysicalReg::F9: return "f9"; - case PhysicalReg::F10: return "f10"; - case PhysicalReg::F11: return "f11"; - case PhysicalReg::F12: return "f12"; - case PhysicalReg::F13: return "f13"; - case PhysicalReg::F14: return "f14"; - case PhysicalReg::F15: return "f15"; - case PhysicalReg::F16: return "f16"; - case PhysicalReg::F17: return "f17"; - case PhysicalReg::F18: return "f18"; - case PhysicalReg::F19: return "f19"; - case PhysicalReg::F20: return "f20"; - case PhysicalReg::F21: return "f21"; - case PhysicalReg::F22: return "f22"; - case PhysicalReg::F23: return "f23"; - case PhysicalReg::F24: return "f24"; - case PhysicalReg::F25: return "f25"; - case PhysicalReg::F26: return "f26"; - case PhysicalReg::F27: return "f27"; - case PhysicalReg::F28: return "f28"; - case PhysicalReg::F29: return "f29"; - case PhysicalReg::F30: return "f30"; - case PhysicalReg::F31: return "f31"; - default: return "UNKNOWN_REG"; - } -} - -// 总体代码生成入口 +// 顶层入口 std::string RISCv64CodeGen::code_gen() { - std::stringstream ss; - ss << module_gen(); - return ss.str(); + return module_gen(); } -// 模块级代码生成 (处理全局变量和函数) +// 模块级代码生成 (移植自原文件,处理.data段和驱动函数生成) std::string RISCv64CodeGen::module_gen() { std::stringstream ss; - bool has_globals = !module->getGlobals().empty(); - if (has_globals) { - ss << ".data\n"; // 数据段 + + // 1. 处理全局变量 (.data段) + if (!module->getGlobals().empty()) { + ss << ".data\n"; for (const auto& global : module->getGlobals()) { - ss << ".globl " << global->getName() << "\n"; // 声明全局符号 - ss << global->getName() << ":\n"; // 标签 + ss << ".globl " << global->getName() << "\n"; + ss << global->getName() << ":\n"; const auto& init_values = global->getInitValues(); for (size_t i = 0; i < init_values.getValues().size(); ++i) { auto val = init_values.getValues()[i]; @@ -124,1375 +28,46 @@ std::string RISCv64CodeGen::module_gen() { if (auto constant = dynamic_cast(val)) { for (unsigned j = 0; j < count; ++j) { if (constant->isInt()) { - ss << " .word " << constant->getInt() << "\n"; // 整数常量 (32位) + ss << " .word " << constant->getInt() << "\n"; } else { float f = constant->getFloat(); uint32_t float_bits = *(uint32_t*)&f; - ss << " .word " << float_bits << "\n"; // 浮点常量 (32位) + ss << " .word " << float_bits << "\n"; } } } } } } + + // 2. 处理函数 (.text段) if (!module->getFunctions().empty()) { - ss << ".text\n"; // 代码段 - for (const auto& func : module->getFunctions()) { - ss << function_gen(func.second.get()); + ss << ".text\n"; + for (const auto& func_pair : module->getFunctions()) { + if (func_pair.second.get()) { + ss << function_gen(func_pair.second.get()); + } } } return ss.str(); } -// 函数级代码生成 +// function_gen 现在是新的、模块化的处理流水线 std::string RISCv64CodeGen::function_gen(Function* func) { - this->local_label_counter = 0; // 为每个函数重置本地标签计数器 + // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) + RISCv64ISel isel; + std::unique_ptr mfunc = isel.runOnFunction(func); + + // 阶段 2: 寄存器分配 (包含栈帧布局, 活跃性分析, 图着色, spill/rewrite) + RISCv64RegAlloc reg_alloc(mfunc.get()); + reg_alloc.run(); + + // 阶段 3: 代码发射 (LLIR with physical regs -> Assembly Text) std::stringstream ss; - ss << ".globl " << func->getName() << "\n"; // 声明函数为全局符号 - ss << func->getName() << ":\n"; // 函数入口标签 - - RegAllocResult alloc_result = register_allocation(func); - int stack_size = alloc_result.stack_size; - - // 函数序言 (Prologue) - // RV64: ra 和 s0 都是64位(8字节)寄存器 - // 保存 ra 和 s0, 调整栈指针 - // s0 指向当前帧的底部(分配局部变量/溢出空间后的 sp) - // 确保栈大小 16 字节对齐 - int aligned_stack_size = (stack_size + 15) & ~15; - - // 只有当需要栈空间时才生成序言 - if (aligned_stack_size > 0) { - ss << " addi sp, sp, -" << aligned_stack_size << "\n"; // 调整栈指针 - // RV64 修改: 使用 sd (store doubleword) 保存 8 字节的 ra 和 s0 - // 同时更新偏移量,为每个寄存器保留8字节 - ss << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; // 保存返回地址 (8字节) - ss << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; // 保存帧指针 (8字节) - ss << " mv s0, sp\n"; // 设置新的帧指针 - } - - // 将传入的寄存器参数 (a0-a7 / f10-f17) 保存到对应的栈槽 (AllocaInst)。 - // RV64中,a0-a7是64位寄存器,但我们传入的int/float是32位。 - // 使用 sw/fsw 会正确地存储低32位,这是正确的行为。 - int arg_idx = 0; - BasicBlock* entry_bb = func->getEntryBlock(); // 获取函数的入口基本块 - - if (entry_bb) { // 确保入口基本块存在 - for (AllocaInst* alloca_for_param : entry_bb->getArguments()) { - if (arg_idx >= 8) { - std::cerr << "警告: 函数 '" << func->getName() << "' 的参数 (索引 " << arg_idx << ") 数量超过了 RISC-V 寄存器传递限制 (8个参数)。\n" - << " 这些参数目前未通过栈正确处理,可能导致错误。\n"; - break; - } - - if (alloc_result.stack_map.count(alloca_for_param)) { - int offset = alloc_result.stack_map.at(alloca_for_param); - Type* allocated_type = alloca_for_param->getType()->as()->getBaseType(); - - if (allocated_type->isInt()) { - PhysicalReg arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); - std::string arg_reg_str = reg_to_string(arg_reg); - // 使用 sw 保存 int (32位) 参数,这是正确的 - ss << " sw " << arg_reg_str << ", " << offset << "(s0)\n"; - } else if (allocated_type->isFloat()) { - PhysicalReg farg_reg = static_cast(static_cast(PhysicalReg::F10) + arg_idx); - std::string farg_reg_str = reg_to_string(farg_reg); - // 使用 fsw 保存 float (32位) 参数,这是正确的 - ss << " fsw " << farg_reg_str << ", " << offset << "(s0)\n"; - } else { - throw std::runtime_error("Unsupported function argument type encountered during parameter saving to stack."); - } - } else { - std::cerr << "警告: 函数参数对应的 AllocaInst '" - << (alloca_for_param->getName().empty() ? "anonymous" : alloca_for_param->getName()) - << "' 没有在栈映射中找到。这可能导致后续代码生成错误。\n"; - } - arg_idx++; - } - } else { - std::cerr << "错误: 函数 '" << func->getName() << "' 没有入口基本块。\n"; - } - - // 生成每个基本块的代码 - int block_idx = 0; - for (const auto& bb : func->getBasicBlocks()) { - ss << basicBlock_gen(bb.get(), alloc_result, block_idx++); - } - - // 函数尾声 (Epilogue) 由 RETURN DAGNode 的指令选择处理 - return ss.str(); -} - - -// 基本块代码生成 -std::string RISCv64CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx) { - std::stringstream ss; - - std::string bb_name = bb->getName(); - if (bb_name.empty()) { - bb_name = ENTRY_BLOCK_PSEUDO_NAME + std::to_string(block_idx); - if (block_idx == 0) { - bb_name = "entry"; - } - } - else { - ss << bb_name << ":\n"; // 基本块标签 - } - if (DEBUG) std::cerr << "=== 生成基本块: " << bb_name << " ===\n"; - - // 构建当前基本块的 DAG - auto dag_nodes_for_bb = build_dag(bb); - if (DEBUG) - print_dag(dag_nodes_for_bb, bb_name); // 打印 DAG 调试信息 - - // 存储最终生成的指令 - std::set emitted_nodes; // 跟踪已发射的节点,防止重复 - std::vector ordered_insts; // 用于收集指令并按序排列 - - // 在 DAG 中遍历并生成指令。由于 select_instructions 可能会递归地为操作数选择指令, - // 并且 emit_instructions 也会递归地发射,我们需要一个机制来确保指令的正确顺序和唯一性。 - // 最简单的方法是逆拓扑序遍历所有节点,确保其操作数先被处理。 - // 但是目前的 DAG 构建方式可能不支持直接的拓扑排序, - // 我们将依赖 emit_instructions 的递归特性来处理依赖。 - - // 遍历 DAG 的根节点(没有用户的节点,或者 Store/Return/Branch 节点) - // 从这些节点开始递归发射指令。 - // NOTE: 这种发射方式可能不总是产生最优的代码顺序,但可以确保依赖关系。 - for (auto it = dag_nodes_for_bb.rbegin(); it != dag_nodes_for_bb.rend(); ++it) { - DAGNode* node = it->get(); - // 只有那些没有用户(或者代表副作用,如STORE, RETURN, BRANCH)的节点才需要作为发射的“根” - // 否则,它们会被其用户节点递归地发射 - // 然而,为了确保所有指令都被发射,我们通常从所有节点(或者至少是副作用节点)开始发射 - // 并且利用 emitted_nodes 集合防止重复 - // 这里简化为对所有 DAG 节点进行一次 select_instructions 和 emit_instructions 调用。 - // emit_instructions 会通过递归处理其操作数来保证依赖顺序。 - select_instructions(node, alloc); // 为当前节点选择指令 - } - - // 收集所有指令到一个临时的 vector 中,然后进行排序 - // 注意:这里的发射逻辑需要重新设计,目前的 emit_instructions 是直接添加到 std::vector& insts 中 - // 并且期望是按顺序添加的,这在递归时难以保证。 - // 更好的方法是让 emit_instructions 直接输出到 stringstream,并控制递归顺序。 - // 但是为了最小化改动,我们先保持 emit_instructions 的现有签名, - // 然后在它内部处理指令的收集和去重。 - - // 重新设计 emit_instructions 的调用方式 - // 这里的思路是,每个 DAGNode 都存储了自己及其依赖(如果未被其他节点引用)的指令。 - // 最终,我们遍历 BasicBlock 中的所有原始 IR 指令,找到它们对应的 DAGNode,然后发射。 - // 这是因为 IR 指令的顺序决定了代码的逻辑顺序。 - - // 遍历 IR 指令,并找到对应的 DAGNode 进行发射 - // 由于 build_dag 是从 IR 指令顺序构建的,我们应该按照 IR 指令的顺序来发射。 - emitted_nodes.clear(); // 再次清空已发射节点集合 - // 临时存储每个 IR 指令对应的 DAGNode,因为 DAGNode 列表是平铺的 - std::map inst_to_dag_node; - for (const auto& dag_node_ptr : dag_nodes_for_bb) { - if (dag_node_ptr->value && dynamic_cast(dag_node_ptr->value)) { - inst_to_dag_node[dynamic_cast(dag_node_ptr->value)] = dag_node_ptr.get(); - } - } - - for (const auto& inst_ptr : bb->getInstructions()) { - DAGNode* node_to_emit = nullptr; - // 查找当前 IR 指令在 DAG 中对应的节点。 - // 注意:不是所有 IR 指令都会直接映射到一个“根”DAGNode (例如,某些值可能只作为操作数存在) - // 但终结符(如 Branch, Return)和 Store 指令总是重要的。 - // 对于 load/binary 等,我们应该在 build_dag 中确保它们有一个结果 vreg,并被后续指令使用。 - // 如果一个 IR 指令是某个 DAGNode 的 value,那么我们就发射那个 DAGNode。 - if (inst_to_dag_node.count(inst_ptr.get())) { - node_to_emit = inst_to_dag_node.at(inst_ptr.get()); - } - - if (node_to_emit) { - // 注意:select_instructions 已经在上面统一调用过,这里只需要 emit。 - // 但如果 select_instructions 没有递归地为所有依赖选择指令,这里可能需要重新考虑。 - // 为了简化,我们假定 select_instructions 在第一次被调用时(通常在 emit 之前)已经递归地为所有操作数选择了指令。 - - // 直接将指令添加到 ss 中,而不是通过 vector 中转 - emit_instructions(node_to_emit, ss, alloc, emitted_nodes); - } - } + RISCv64AsmPrinter printer(mfunc.get()); + printer.run(ss); return ss.str(); } -// 辅助函数,用于创建 DAGNode 并管理其所有权 -sysy::RISCv64CodeGen::DAGNode* sysy::RISCv64CodeGen::create_node( - DAGNode::NodeKind kind, - Value* val, - std::map& value_to_node, // 需要外部传入 - std::vector>& nodes_storage // 需要外部传入 -) { - // 优化:如果一个值已经有节点并且它不是控制流/存储/Alloca地址/一元操作,则重用它 (CSE) - // 对于 AllocaInst,我们想创建一个代表其地址的节点,但不一定直接为 AllocaInst 本身分配虚拟寄存器。 - if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::ALLOCA_ADDR && kind != DAGNode::UNARY) { - return value_to_node[val]; - } - - auto node = std::make_unique(kind); - node->value = val; - - // 为产生结果的值分配虚拟寄存器 - if (val && value_vreg_map.count(val) && !dynamic_cast(val)) { // 排除 AllocaInst - node->result_vreg = value_vreg_map.at(val); - } - - DAGNode* raw_node_ptr = node.get(); - nodes_storage.push_back(std::move(node)); // 存储 unique_ptr - - // 仅当 IR Value 表示一个计算值时,才将其映射到创建的 DAGNode - // 且它应该已经在 register_allocation 中被分配了 vreg - if (val && value_vreg_map.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && !dynamic_cast(val)) { - value_to_node[val] = raw_node_ptr; - } - return raw_node_ptr; -} - - -// 辅助函数:获取值的 DAG 节点。 -// 如果 value 已经映射到 DAG 节点,则直接返回。 -// 如果是常量,则创建 CONSTANT 节点。 -// 如果是 AllocaInst,则创建 ALLOCA_ADDR 节点。 -// 否则,假定需要通过 LOAD 获取该值。 -sysy::RISCv64CodeGen::DAGNode* sysy::RISCv64CodeGen::get_operand_node( - Value* val_ir, - std::map& value_to_node, // 接受 value_to_node - std::vector>& nodes_storage // 接受 nodes_storage -) { - if (value_to_node.count(val_ir)) { - return value_to_node[val_ir]; - } else if (auto constant = dynamic_cast(val_ir)) { - return create_node(DAGNode::CONSTANT, constant, value_to_node, nodes_storage); // 调用成员函数版 create_node - } else if (auto alloca = dynamic_cast(val_ir)) { - return create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); // 调用成员函数版 create_node - } else if (auto global = dynamic_cast(val_ir)) { - // 确保 GlobalValue 也能正确处理,如果 DAGNode::CONSTANT 无法存储 GlobalValue*, - // 则需要新的 DAGNode 类型,例如 DAGNode::GLOBAL_ADDR - return create_node(DAGNode::CONSTANT, global, value_to_node, nodes_storage); // 调用成员函数版 create_node - } - // 这是一个尚未在此块中计算的值,假设它需要加载 (从内存或参数) - return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); // 调用成员函数版 create_node -} - -std::vector> RISCv64CodeGen::build_dag(BasicBlock* bb) { - std::vector> nodes_storage; // 存储所有 unique_ptr - std::map value_to_node; // 将 IR Value* 映射到原始 DAGNode*,用于快速查找 - - for (const auto& inst_ptr : bb->getInstructions()) { - auto inst = inst_ptr.get(); - - if (auto alloca = dynamic_cast(inst)) { - // AllocaInst 本身不产生寄存器中的值,但其地址将被 load/store 使用。 - // 创建一个节点来表示分配内存的地址。 - // 这个地址将是 s0 (帧指针) 的偏移量。 - // 我们将 AllocaInst 指针存储在 DAGNode 的 `value` 字段中。 - // 修正:AllocaInst 类型的 DAGNode 应该有一个 value 对应 AllocaInst* - // 但它本身不应该有 result_vreg,因为不映射到物理寄存器。 - create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); - } else if (auto store = dynamic_cast(inst)) { - auto store_node = create_node(DAGNode::STORE, store, value_to_node, nodes_storage); - - // 获取要存储的值 - DAGNode* val_node = get_operand_node(store->getValue(), value_to_node, nodes_storage); - - // 获取内存位置的指针 (基地址) - Value* ptr_ir = store->getPointer(); - DAGNode* ptr_node = get_operand_node(ptr_ir, value_to_node, nodes_storage); - - store_node->operands.push_back(val_node); - store_node->operands.push_back(ptr_node); - ptr_node->users.push_back(store_node); - } else if (auto memset = dynamic_cast(inst)) { - auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage); - - // 根据 IR.h 中的定义,获取 MemsetInst 的操作数 - DAGNode* pointer_node = get_operand_node(memset->getPointer(), value_to_node, nodes_storage); - DAGNode* begin_node = get_operand_node(memset->getBegin(), value_to_node, nodes_storage); - DAGNode* size_node = get_operand_node(memset->getSize(), value_to_node, nodes_storage); - DAGNode* value_node = get_operand_node(memset->getValue(), value_to_node, nodes_storage); - - // 将操作数节点添加到 MEMSET 节点的依赖列表中 - memset_node->operands.push_back(pointer_node); - memset_node->operands.push_back(begin_node); - memset_node->operands.push_back(size_node); - memset_node->operands.push_back(value_node); - - // 建立反向链接 - pointer_node->users.push_back(memset_node); - begin_node->users.push_back(memset_node); - size_node->users.push_back(memset_node); - value_node->users.push_back(memset_node); - } else if (auto load = dynamic_cast(inst)) { - auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); - - // 获取内存位置的指针 (基地址) - Value* ptr_ir = load->getPointer(); - DAGNode* ptr_node = get_operand_node(ptr_ir, value_to_node, nodes_storage); - load_node->operands.push_back(ptr_node); - ptr_node->users.push_back(load_node); - } else if (auto bin = dynamic_cast(inst)) { - if (value_to_node.count(bin)) continue; // CSE - - if (bin->getKind() == BinaryInst::kSub || bin->getKind() == BinaryInst::kFSub) { - Value* lhs_ir = bin->getLhs(); - if (auto const_lhs = dynamic_cast(lhs_ir)) { - bool is_neg = false; - if (const_lhs->getType()->isInt()) { - if (const_lhs->getInt() == 0) { - is_neg = true; - } - } else if (const_lhs->getType()->isFloat()) { - if (std::fabs(const_lhs->getFloat()) < std::numeric_limits::epsilon()) { - is_neg = true; - } - } - - if (is_neg) { - auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage); // 传递参数 - Value* operand_ir = bin->getRhs(); - DAGNode* operand_node = get_operand_node(operand_ir, value_to_node, nodes_storage); // 传递参数 - unary_node->operands.push_back(operand_node); - operand_node->users.push_back(unary_node); - continue; - } - } - } - // 常规二进制操作 - auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); // 传递参数 - - DAGNode* lhs_node = get_operand_node(bin->getLhs(), value_to_node, nodes_storage); // 传递参数 - DAGNode* rhs_node = get_operand_node(bin->getRhs(), value_to_node, nodes_storage); // 传递参数 - - bin_node->operands.push_back(lhs_node); - bin_node->operands.push_back(rhs_node); - lhs_node->users.push_back(bin_node); - rhs_node->users.push_back(bin_node); - - } else if (auto un_inst = dynamic_cast(inst)) { - if (value_to_node.count(un_inst)) continue; - - auto unary_node = create_node(DAGNode::UNARY, un_inst, value_to_node, nodes_storage); // 传递参数 - - Value* operand_ir = un_inst->getOperand(); - DAGNode* operand_node = get_operand_node(operand_ir, value_to_node, nodes_storage); // 传递参数 - - unary_node->operands.push_back(operand_node); - operand_node->users.push_back(unary_node); - - } else if (auto call = dynamic_cast(inst)) { - if (value_to_node.count(call)) continue; - auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); // 传递参数 - for (auto arg : call->getArguments()) { - auto arg_val_ir = arg->getValue(); - DAGNode* arg_node = get_operand_node(arg_val_ir, value_to_node, nodes_storage); // 传递参数 - call_node->operands.push_back(arg_node); - arg_node->users.push_back(call_node); - } - } else if (auto ret = dynamic_cast(inst)) { - if (DEBUG) std::cerr << "处理 RETURN 指令: " << ret->getName() << "\n"; // 调试输出 - auto ret_node = create_node(DAGNode::RETURN, ret, value_to_node, nodes_storage); // 传递参数 - if (ret->hasReturnValue()) { - auto val_ir = ret->getReturnValue(); - DAGNode* val_node = get_operand_node(val_ir, value_to_node, nodes_storage); // 传递参数 - ret_node->operands.push_back(val_node); - val_node->users.push_back(ret_node); - } - } else if (auto cond_br = dynamic_cast(inst)) { - auto br_node = create_node(DAGNode::BRANCH, cond_br, value_to_node, nodes_storage); // 传递参数 - auto cond_ir = cond_br->getCondition(); - - if (auto constant_cond = dynamic_cast(cond_ir)) { - br_node->inst = "j " + (constant_cond->getInt() ? cond_br->getThenBlock()->getName() : cond_br->getElseBlock()->getName()); - } else { - DAGNode* cond_node = get_operand_node(cond_ir, value_to_node, nodes_storage); // 传递参数 - br_node->operands.push_back(cond_node); - cond_node->users.push_back(br_node); - } - } else if (auto uncond_br = dynamic_cast(inst)) { - auto br_node = create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); // 传递参数 - br_node->inst = "j " + uncond_br->getBlock()->getName(); - } else { - // 其他指令类型(如 PHI, 可能的自定义指令等) - // 目前假设未处理的指令类型不需要特殊 DAGNode 类型 - // 可以在这里添加更多的处理逻辑 - if (DEBUG) std::cerr << "未处理的指令类型: " << inst->getKindString() << "\n"; - continue; // 跳过未处理的指令 - } - } - return nodes_storage; -} - -// 打印 DAG -void RISCv64CodeGen::print_dag(const std::vector>& dag, const std::string& bb_name) { - std::cerr << "=== DAG for Basic Block: " << bb_name << " ===\n"; - std::set visited; - - // 辅助映射,用于在打印输出中为节点分配顺序 ID - std::map node_to_id; - int current_id = 0; - for (const auto& node_ptr : dag) { - node_to_id[node_ptr.get()] = current_id++; - } - - std::function print_node = [&](DAGNode* node, int indent) { - if (!node) return; - - std::string current_indent(indent, ' '); - int node_id = node_to_id.count(node) ? node_to_id[node] : -1; // 获取分配的 ID - - std::cerr << current_indent << "Node#" << node_id << ": " << node->getNodeKindString(); - if (!node->result_vreg.empty()) { - std::cerr << " (vreg: " << node->result_vreg << ")"; - } - - if (node->value) { - std::cerr << " ["; - if (auto inst = dynamic_cast(node->value)) { - std::cerr << inst->getKindString(); - if (!inst->getName().empty()) { - std::cerr << "(" << inst->getName() << ")"; - } - } else if (auto constant = dynamic_cast(node->value)) { - if (constant->isInt()) { - std::cerr << "ConstInt(" << constant->getInt() << ")"; - } else { - std::cerr << "ConstFloat(" << constant->getFloat() << ")"; - } - } else if (auto global = dynamic_cast(node->value)) { - std::cerr << "Global(" << global->getName() << ")"; - } else if (auto alloca = dynamic_cast(node->value)) { - std::cerr << "Alloca(" << (alloca->getName().empty() ? ("%" + std::to_string(reinterpret_cast(alloca) % 1000)) : alloca->getName()) << ")"; - } - std::cerr << "]"; - } - std::cerr << " -> Inst: \"" << node->inst << "\""; // 打印选定的指令 - std::cerr << "\n"; - - if (visited.find(node) != visited.end()) { - std::cerr << current_indent << " (已打印后代)\n"; - return; // 避免循环的无限递归 - } - visited.insert(node); - - if (!node->operands.empty()) { - std::cerr << current_indent << " 操作数:\n"; - for (auto operand : node->operands) { - print_node(operand, indent + 4); - } - } - // 移除了 users 打印,以简化输出并避免 DAG 中的冗余递归。 - // Users 更适用于向上遍历,而不是向下遍历。 - }; - - // 遍历 DAG,以尊重依赖的方式打印。 - // 当前实现:遍历所有节点,从作为“根”的节点开始打印(没有用户或副作用节点)。 - // 每次打印新的根时,重置 visited 集合,以允许共享子图被重新打印(尽管这不是最高效的方式)。 - for (const auto& node_ptr : dag) { - // 只有那些没有用户或者表示副作用(如 store/branch/return)的节点才被视为“根” - // 这样可以确保所有指令(包括那些没有明确结果的)都被打印 - if (node_ptr->users.empty() || node_ptr->kind == DAGNode::STORE || node_ptr->kind == DAGNode::RETURN || node_ptr->kind == DAGNode::BRANCH) { - visited.clear(); // 为每个根重置 visited,允许重新打印共享子图 - print_node(node_ptr.get(), 0); - } - } - std::cerr << "=== DAG 结束 ===\n\n"; -} - -// 指令选择 -void RISCv64CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) { - if (!node) return; - if (!node->inst.empty()) return; // 指令已选择,跳过重复处理 - - // 递归地为操作数选择指令,确保依赖先被处理 - for (auto operand : node->operands) { - if (operand) { - select_instructions(operand, alloc); - } - } - - std::stringstream ss_inst; // 使用 stringstream 构建指令 - - // 获取分配的物理寄存器,若未分配则回退到 t0 - auto get_preg_or_temp = [&](const std::string& vreg) { - if (vreg.empty()) { // 添加对空 vreg 的明确检查 - if (DEBUG) std::cerr << "警告: 虚拟寄存器 (空字符串) 没有分配物理寄存器,使用临时寄存器 t0 代替。\n"; - return reg_to_string(PhysicalReg::T0); - } - if (alloc.vreg_to_preg.count(vreg)) { - return reg_to_string(alloc.vreg_to_preg.at(vreg)); - } - if (DEBUG) std::cerr << "警告: 虚拟寄存器 " << vreg << " 没有分配物理寄存器,使用临时寄存器 t0 代替。\n"; - return reg_to_string(PhysicalReg::T0); // 回退到临时寄存器 t0 - }; - - // 获取栈变量的内存偏移量 - auto get_stack_offset = [&](Value* val) -> std::string { // 返回类型明确为 std::string - if (alloc.stack_map.count(val)) { - if (DEBUG) { // 避免在非DEBUG模式下打印大量内容 - std::cout << "获取栈变量的内存偏移量,变量名: " << (val ? val->getName() : "unknown") << std::endl; - } - return std::to_string(alloc.stack_map.at(val)); - } - if (DEBUG) std::cerr << "警告: 栈变量 " << (val ? val->getName() : "unknown") << " 没有在栈映射中找到,使用默认偏移 0。\n"; - // 如果没有找到映射,返回默认偏移量 "0" - return std::string("0"); // 默认或错误情况 - }; - - switch (node->kind) { - case DAGNode::CONSTANT: { - // [V2 特性] CONSTANT 节点自身不再生成指令。 - // 它的存在是为了在DAG中标记一个常数值或全局地址。 - // 加载这个值的责任转移给了使用它的节点 (STORE, BINARY, CALL, RETURN)。 - break; - } - case DAGNode::ALLOCA_ADDR: { - // ALLOCA_ADDR 节点不直接生成指令,它仅作为一个地址标记,由 LOAD/STORE 或地址计算使用 - break; - } - case DAGNode::LOAD: { - // 处理加载指令 - if (node->operands.empty() || !node->operands[0]) break; - std::string dest_reg = get_preg_or_temp(node->result_vreg); - DAGNode* ptr_node = node->operands[0]; - - if (ptr_node->kind == DAGNode::ALLOCA_ADDR) { - if (auto alloca_inst = dynamic_cast(ptr_node->value)) { - int offset = alloc.stack_map.at(alloca_inst); - ss_inst << "lw " << dest_reg << ", " << offset << "(s0)"; - } - } else { - std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); - - // [与STORE逻辑类似的修复] 如果是从一个全局地址加载,先用la加载地址 - if (ptr_node->kind == DAGNode::CONSTANT) { - if (auto global = dynamic_cast(ptr_node->value)) { - ss_inst << "la " << ptr_reg << ", " << global->getName() << "\n\t"; - } - } - - ss_inst << "lw " << dest_reg << ", 0(" << ptr_reg << ")"; - } - break; - } - case DAGNode::STORE: { - // 处理存储指令 - if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; - DAGNode* val_node = node->operands[0]; - DAGNode* ptr_node = node->operands[1]; - - std::string src_reg = get_preg_or_temp(val_node->result_vreg); - - // [V2 特性] 如果要存储的值是一个常数,在此处立即加载它 - if (val_node->kind == DAGNode::CONSTANT) { - if (auto constant = dynamic_cast(val_node->value)) { - if (constant->isInt()) { - ss_inst << "li " << src_reg << ", " << constant->getInt() << "\n\t"; - } else { // 浮点数常量 - float f = constant->getFloat(); - uint32_t float_bits = *(uint32_t*)&f; - ss_inst << "li " << src_reg << ", " << float_bits << "\n\t"; - // 注意:这里假设 src_reg 是一个通用寄存器,需要额外指令移动到浮点寄存器 - // 如果 STORE 的目标是浮点数,这里逻辑需要调整 - } - } else if (auto global = dynamic_cast(val_node->value)) { - // 如果要存储的值是另一个全局变量的地址 - ss_inst << "la " << src_reg << ", " << global->getName() << "\n\t"; - } - } - - // [本次修复] 处理目标地址,并生成最终的存储指令 - if (ptr_node->kind == DAGNode::ALLOCA_ADDR) { - // 情况1: 存储到栈上的局部变量 - if (auto alloca_inst = dynamic_cast(ptr_node->value)) { - int offset = alloc.stack_map.at(alloca_inst); - ss_inst << "sw " << src_reg << ", " << offset << "(s0)"; - } - } else { - // 情况2: 存储到指针指向的地址 (可能是全局变量或已计算的地址) - std::string ptr_reg = get_preg_or_temp(ptr_node->result_vreg); - - // 如果指针本身是一个全局变量的地址,需要先用 la 加载它 - if (ptr_node->kind == DAGNode::CONSTANT) { - if (auto global = dynamic_cast(ptr_node->value)) { - ss_inst << "la " << ptr_reg << ", " << global->getName() << "\n\t"; - } - } - - // 生成最终的存储指令 - ss_inst << "sw " << src_reg << ", 0(" << ptr_reg << ")"; - } - break; - } - case DAGNode::BINARY: { - if (node->operands.size() < 2 || !node->operands[0] || !node->operands[1]) break; - auto bin = dynamic_cast(node->value); - if (!bin) break; - - std::string dest_reg = get_preg_or_temp(node->result_vreg); - DAGNode* lhs_node = node->operands[0]; - DAGNode* rhs_node = node->operands[1]; - - // [V1 特性] 检查是否是 base + offset 的地址计算 - if (bin->getKind() == BinaryInst::kAdd) { - DAGNode* base_node = nullptr; - DAGNode* offset_node = nullptr; - - if (lhs_node->kind == DAGNode::ALLOCA_ADDR) { base_node = lhs_node; offset_node = rhs_node; } - else if (rhs_node->kind == DAGNode::ALLOCA_ADDR) { base_node = rhs_node; offset_node = lhs_node; } - - if (base_node) { - if (auto alloca_inst = dynamic_cast(base_node->value)) { - std::string stack_offset_str = get_stack_offset(alloca_inst); - std::string index_offset_reg = get_preg_or_temp(offset_node->result_vreg); - - // 生成两条指令来计算最终地址: - // 1. addi 将 s0 加上 b 的栈偏移量得到 b 的实际基地址 - ss_inst << "addi " << dest_reg << ", s0, " << stack_offset_str << "\n"; - // 2. [V3 修复] add 将基地址和索引偏移量相加,得到最终地址。地址计算必须用64位指令。 - ss_inst << "\tadd " << dest_reg << ", " << dest_reg << ", " << index_offset_reg; - node->inst = ss_inst.str(); - return; // 指令已生成,提前返回 - } - } - } - - std::string lhs_reg = get_preg_or_temp(lhs_node->result_vreg); - std::string rhs_reg = get_preg_or_temp(rhs_node->result_vreg); - - // [V2 特性] & [上次修复] 在生成二元运算指令前,检查操作数是否为常数(立即数或全局地址),并按需加载 - if (lhs_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(lhs_node->value)) { - ss_inst << "li " << lhs_reg << ", " << c->getInt() << "\n\t"; - } else if (auto g = dynamic_cast(lhs_node->value)) { - ss_inst << "la " << lhs_reg << ", " << g->getName() << "\n\t"; - } - } - if (rhs_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(rhs_node->value)) { - // 优化:对于加法和小立即数,直接使用 addi 指令 - if (bin->getKind() == BinaryInst::kAdd && c->getInt() >= -2048 && c->getInt() < 2048) { - ss_inst << "addi " << dest_reg << ", " << lhs_reg << ", " << c->getInt(); - node->inst = ss_inst.str(); - return; // 指令已生成,提前返回 - } - // 否则,正常加载 - ss_inst << "li " << rhs_reg << ", " << c->getInt() << "\n\t"; - } else if (auto g = dynamic_cast(rhs_node->value)) { - ss_inst << "la " << rhs_reg << ", " << g->getName() << "\n\t"; - } - } - - std::string opcode; - switch (bin->getKind()) { - // [V1 特性] & [V3 修复] RV64 修改: 整数运算使用带 'w' 后缀的32位指令,地址运算使用64位指令 - case BinaryInst::kAdd: - // 通过检查操作数类型来决定使用 64位(add) 还是 32位(addw) 加法 - if (bin->getLhs()->getType()->isPointer() || bin->getRhs()->getType()->isPointer()) { - opcode = "add"; // 指针/地址运算 - } else { - opcode = "addw"; // 普通整数运算 - } - break; - case BinaryInst::kSub: opcode = "subw"; break; - case BinaryInst::kMul: opcode = "mulw"; break; - case Instruction::kDiv: opcode = "divw"; break; - case Instruction::kRem: opcode = "remw"; break; - case BinaryInst::kICmpEQ: - ss_inst << "subw " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\tseqz " << dest_reg << ", " << dest_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpGE: - ss_inst << "slt " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\txori " << dest_reg << ", " << dest_reg << ", 1"; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpGT: - ss_inst << "slt " << dest_reg << ", " << rhs_reg << ", " << lhs_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpLE: - ss_inst << "slt " << dest_reg << ", " << rhs_reg << ", " << lhs_reg << "\n"; - ss_inst << "\txori " << dest_reg << ", " << dest_reg << ", 1"; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpLT: - ss_inst << "slt " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; - node->inst = ss_inst.str(); - return; - case BinaryInst::kICmpNE: - ss_inst << "subw " << dest_reg << ", " << lhs_reg << ", " << rhs_reg << "\n"; - ss_inst << "\tsnez " << dest_reg << ", " << dest_reg; - node->inst = ss_inst.str(); - return; - default: - // [V1 特性] 保留对未实现指令的定义和报错 - throw std::runtime_error("不支持的二元指令类型: " + bin->getKindString()); - } - if (!opcode.empty()) { - ss_inst << opcode << " " << dest_reg << ", " << lhs_reg << ", " << rhs_reg; - } - break; - } - case DAGNode::UNARY: { - if (node->operands.empty() || !node->operands[0]) break; - auto unary = dynamic_cast(node->value); - if (!unary) break; - - std::string dest_reg = get_preg_or_temp(node->result_vreg); - std::string src_reg = get_preg_or_temp(node->operands[0]->result_vreg); - - switch (unary->getKind()) { - case UnaryInst::kNeg: - // RV64 修改: 使用 subw 实现32位取负 (negw 伪指令) - ss_inst << "subw " << dest_reg << ", x0, " << src_reg; - break; - case UnaryInst::kNot: - // 整数逻辑非:seqz rd, rs (rs == 0 时 rd = 1,否则 rd = 0) - ss_inst << "seqz " << dest_reg << ", " << src_reg; - break; - // [V1 特性] 保留对未实现指令的定义和报错 - case UnaryInst::kFNeg: - case UnaryInst::kFNot: - case UnaryInst::kFtoI: - case UnaryInst::kItoF: - case UnaryInst::kBitFtoI: - case UnaryInst::kBitItoF: - throw std::runtime_error("不支持的浮点一元指令类型: " + unary->getKindString()); - default: - throw std::runtime_error("不支持的一元指令类型: " + unary->getKindString()); - } - break; - } - case DAGNode::CALL: { - // 处理函数调用指令 - if (!node->value) break; - auto call = dynamic_cast(node->value); - if (!call) break; - - // [V2/V3 特性] 修正参数处理逻辑 - for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { - DAGNode* arg_node = node->operands[i]; - if (!arg_node) continue; - - std::string arg_preg = reg_to_string(static_cast(static_cast(PhysicalReg::A0) + i)); - - // 优先检查参数节点是否为常量,并直接加载 - if (arg_node->kind == DAGNode::CONSTANT) { - if (auto const_val = dynamic_cast(arg_node->value)) { - ss_inst << "li " << arg_preg << ", " << const_val->getInt() << "\n\t"; - } else if (auto global_val = dynamic_cast(arg_node->value)) { - ss_inst << "la " << arg_preg << ", " << global_val->getName() << "\n\t"; - } - } - // 如果不是常量,说明是一个计算出的值,使用 mv 指令移动 - else { - std::string src_reg = get_preg_or_temp(arg_node->result_vreg); - ss_inst << "mv " << arg_preg << ", " << src_reg << "\n\t"; - } - } - ss_inst << "call " << call->getCallee()->getName(); - - // 处理返回值 - if ((call->getType()->isInt() || call->getType()->isFloat()) && !node->result_vreg.empty()) { - ss_inst << "\nmv " << get_preg_or_temp(node->result_vreg) << ", a0"; - } - break; - } - case DAGNode::RETURN: { - // 处理返回指令 - if (!node->operands.empty() && node->operands[0]) { - DAGNode* ret_val_node = node->operands[0]; - - // [V2 特性] & [上次修复] 如果返回值是常量(立即数或全局地址),直接加载到 a0 - if (ret_val_node->kind == DAGNode::CONSTANT) { - if (auto c = dynamic_cast(ret_val_node->value)) { - ss_inst << "li a0, " << c->getInt() << "\n"; - } else if (auto g = dynamic_cast(ret_val_node->value)) { - ss_inst << "la a0, " << g->getName() << "\n"; - } - } else { - std::string return_val_reg = get_preg_or_temp(ret_val_node->result_vreg); - ss_inst << "mv a0, " << return_val_reg << "\n"; - } - } - - // 恢复栈帧 - if (alloc.stack_size > 0) { - int aligned_stack_size = (alloc.stack_size + 15) & ~15; - // RV64 修改: 使用 ld (load doubleword) 恢复 8 字节的 ra 和 s0 - ss_inst << "\tld ra, " << (aligned_stack_size - 8) << "(sp)\n"; - ss_inst << "\tld s0, " << (aligned_stack_size - 16) << "(sp)\n"; - ss_inst << "\taddi sp, sp, " << aligned_stack_size << "\n"; - } - ss_inst << "\tret"; - break; - } - case DAGNode::BRANCH: { - // 处理分支指令 - auto br = dynamic_cast(node->value); - auto uncond_br = dynamic_cast(node->value); - - if (br) { // 条件分支 - if (node->operands.empty() || !node->operands[0]) break; - std::string cond_reg = get_preg_or_temp(node->operands[0]->result_vreg); - std::string then_block = br->getThenBlock()->getName(); - std::string else_block = br->getElseBlock()->getName(); - - // [V1 特性] 如果基本块没有名字(例如,匿名块),给它一个伪名称 - if (then_block.empty()) { - then_block = ENTRY_BLOCK_PSEUDO_NAME + "_then_" + std::to_string(this->local_label_counter++); - } - if (else_block.empty()) { - else_block = ENTRY_BLOCK_PSEUDO_NAME + "_else_" + std::to_string(this->local_label_counter++); - } - - ss_inst << "bnez " << cond_reg << ", " << then_block << "\n"; - ss_inst << "\tj " << else_block; - } else if (uncond_br) { // 无条件分支 - std::string target_block = uncond_br->getBlock()->getName(); - if (target_block.empty()) { - target_block = ENTRY_BLOCK_PSEUDO_NAME + "_target_" + std::to_string((uintptr_t)uncond_br); - } - ss_inst << "j " << target_block; - } - break; - } - case DAGNode::MEMSET: { - if (node->operands.size() < 4) break; - - // 1. 获取操作数被分配到的物理寄存器 - // 您的 IR 中 pointer 和 begin 可能是同一个值,这里我们假设 pointer 是基地址 - DAGNode* ptr_node = node->operands[0]; - DAGNode* size_node = node->operands[2]; - DAGNode* value_node = node->operands[3]; - - std::string R_DEST_ADDR = get_preg_or_temp(ptr_node->result_vreg); - std::string R_NUM_BYTES = get_preg_or_temp(size_node->result_vreg); - std::string R_VALUE_BYTE = get_preg_or_temp(value_node->result_vreg); - - // 2. 定义我们将要使用的临时寄存器 - // 由于我们在冲突图中添加了规则,可以安全地使用这些调用者保存寄存器 - std::string R_COUNTER = "t3"; // 循环计数器 (字节) - std::string R_END_ADDR = "t4"; // 结束地址 - std::string R_CURRENT_ADDR = "t5"; // 当前写入地址 - std::string R_TEMP_VAL = "t6"; // 64位填充值 - - // 使用 local_label_counter 为 memset 循环生成唯一的、整洁的标签 - int unique_id = this->local_label_counter++; - std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); - std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); - std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); - std::string done_label = "memset_done_" + std::to_string(unique_id); - - // 3. 生成汇编代码 - ss_inst << "# --- Memset Start ---\n"; - ss_inst << " andi " << R_TEMP_VAL << ", " << R_VALUE_BYTE << ", 255\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 8\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 16\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " slli " << R_VALUE_BYTE << ", " << R_TEMP_VAL << ", 32\n"; - ss_inst << " or " << R_TEMP_VAL << ", " << R_TEMP_VAL << ", " << R_VALUE_BYTE << "\n"; - ss_inst << " add " << R_END_ADDR << ", " << R_DEST_ADDR << ", " << R_NUM_BYTES << "\n"; - ss_inst << " mv " << R_CURRENT_ADDR << ", " << R_DEST_ADDR << "\n"; - ss_inst << " andi " << R_COUNTER << ", " << R_NUM_BYTES << ", -8\n"; - ss_inst << " add " << R_COUNTER << ", " << R_DEST_ADDR << ", " << R_COUNTER << "\n"; - ss_inst << loop_start_label << ":\n"; - ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_COUNTER << ", " << loop_end_label << "\n"; - ss_inst << " sd " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; - ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 8\n"; - ss_inst << " j " << loop_start_label << "\n"; - ss_inst << loop_end_label << ":\n"; - ss_inst << remainder_label << ":\n"; - ss_inst << " bgeu " << R_CURRENT_ADDR << ", " << R_END_ADDR << ", " << done_label << "\n"; - ss_inst << " sb " << R_TEMP_VAL << ", 0(" << R_CURRENT_ADDR << ")\n"; - ss_inst << " addi " << R_CURRENT_ADDR << ", " << R_CURRENT_ADDR << ", 1\n"; - ss_inst << " j " << remainder_label << "\n"; - ss_inst << done_label << ":\n"; - ss_inst << "# --- Memset End ---"; - break; - } - default: - throw std::runtime_error("不支持的节点类型: " + node->getNodeKindString()); - } - node->inst = ss_inst.str(); // 存储生成的指令 -} - -// 指令发射 -void RISCv64CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set& emitted_nodes) { - if (!node || emitted_nodes.count(node)) { - return; // 如果节点为空或已经发射过,则返回 - } - - // 递归地发射操作数,以确保满足指令依赖 - for (auto operand : node->operands) { - if (operand) { - emit_instructions(operand, ss, alloc, emitted_nodes); - } - } - - // 标记当前节点为已发射,防止重复 - emitted_nodes.insert(node); - - // node->inst 中可能包含由 \n 分隔的多行指令和标签 - std::stringstream node_inst_ss(node->inst); - std::string line; - - while (std::getline(node_inst_ss, line, '\n')) { - // 首先,移除行首和行尾的空白字符,方便后续判断 - line = std::regex_replace(line, std::regex("^\\s+|\\s+$"), ""); - - if (line.empty()) { - continue; // 跳过空行 - } - - // ====================== 核心修正逻辑 ====================== - // 判断当前行是否是一个标签(即,非空且以':'结尾) - if (!line.empty() && line.back() == ':') { - // 如果是标签,直接打印,不加前导缩进 - ss << line << "\n"; - } else { - // 如果是常规指令,添加4个空格的前导缩进后再打印 - ss << " " << line << "\n"; - } - // ======================================================== - } -} - -// 辅助函数:将集合打印为字符串 -std::string print_set(const std::set& s) { - std::stringstream ss; - ss << "{"; - bool first = true; - for (const auto& elem : s) { - if (!first) { - ss << ", "; - } - ss << elem; - first = false; - } - ss << "}"; - return ss.str(); -} - -// 活跃性分析(更新以返回 live_in 和 live_out) -LivenessResult RISCv64CodeGen::liveness_analysis(Function* func) { - LivenessResult result; - bool changed = true; - - // 初始化 live_in 和 live_out - for (const auto& bb : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb->getInstructions()) { - result.live_in[inst_ptr.get()] = {}; - result.live_out[inst_ptr.get()] = {}; - } - } - - int iteration_count = 0; - while (changed) { - changed = false; - iteration_count++; - if (DEEPDEBUG) std::cerr << "\n--- 活跃性分析迭代: " << iteration_count << " ---" << std::endl; - - // 逆序遍历基本块 - for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) { - auto bb = it->get(); - if (DEEPDEBUG) std::cerr << " 基本块: " << bb->getName() << std::endl; - - // 计算基本块末尾的 live_out 集合,即所有后继基本块 live_in 的并集 - std::set live_out_for_bb; - for (const auto& succ_bb : bb->getSuccessors()) { - if (!succ_bb->getInstructions().empty()) { - Instruction* first_inst_in_succ = succ_bb->getInstructions().front().get(); - if (result.live_in.count(first_inst_in_succ)) { - const auto& succ_live_in = result.live_in.at(first_inst_in_succ); - live_out_for_bb.insert(succ_live_in.begin(), succ_live_in.end()); - } - } - } - - // 逆序遍历指令 - for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) { - auto inst = inst_it->get(); - if (DEEPDEBUG) std::cerr << " 指令 (BB: " << bb->getName() << ", 地址: " << static_cast(inst) << ")" << std::endl; - - std::set current_live_in = result.live_in[inst]; - std::set current_live_out = result.live_out[inst]; - std::set new_live_out; - - // 计算当前指令的 live_out - if (inst_it == bb->getInstructions().rbegin()) { - // 对于块中的最后一条指令,其 live_out 是块的 live_out - new_live_out = live_out_for_bb; - } else { - // 否则,其 live_out 是其后继指令的 live_in - auto next_inst_it = std::prev(inst_it); - new_live_out = result.live_in[next_inst_it->get()]; - } - - std::set use_set, def_set; - - // 定义 (Def) - if (value_vreg_map.count(inst) && !inst->getType()->isVoid() && !dynamic_cast(inst) && !dynamic_cast(inst) && - !dynamic_cast(inst) && !dynamic_cast(inst) && !dynamic_cast(inst)) { - def_set.insert(value_vreg_map.at(inst)); - } - - // 使用 (Use) - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand) && !dynamic_cast(operand)) { - use_set.insert(value_vreg_map.at(operand)); - } - } - - // 数据流方程: live_in[i] = use[i] U (live_out[i] - def[i]) - std::set new_live_in = use_set; - for (const auto& vreg : new_live_out) { - if (def_set.find(vreg) == def_set.end()) { - new_live_in.insert(vreg); - } - } - - // 如果活跃性集合发生变化,更新并继续迭代 - if (new_live_in != current_live_in || new_live_out != current_live_out) { - result.live_in[inst] = new_live_in; - result.live_out[inst] = new_live_out; - changed = true; - } - } - } - } - return result; -} - - -// 干扰图构建 (使用正确的 live_out 集合) -std::map> RISCv64CodeGen::build_interference_graph( - const LivenessResult& liveness) { - std::map> graph; - - // 初始化图,确保每个虚拟寄存器都有一个节点 - for (const auto& pair : value_vreg_map) { - graph[pair.second] = {}; - } - - // 遍历每条指令来构建干扰 - for (const auto& pair : liveness.live_out) { - Instruction* inst = pair.first; - const auto& live_out_set = pair.second; - - // 获取该指令定义的 vreg - std::string defined_vreg; - if (value_vreg_map.count(inst) && !inst->getType()->isVoid() && !dynamic_cast(inst) && !dynamic_cast(inst) && - !dynamic_cast(inst) && !dynamic_cast(inst) && !dynamic_cast(inst)) { - defined_vreg = value_vreg_map.at(inst); - } - - // --- 规则一 和 新增的规则三 --- - if (!defined_vreg.empty()) { - // 您的“规则一”:定义与出口活跃寄存器冲突 (保留) - for (const auto& live_vreg : live_out_set) { - if (defined_vreg != live_vreg) { - graph[defined_vreg].insert(live_vreg); - graph[live_vreg].insert(defined_vreg); - } - } - - // ====================== 新增的、缺失的规则三 ====================== - // 定义与在同一指令中并发使用的寄存器冲突 - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand)) { - const std::string& use_vreg = value_vreg_map.at(operand); - if (defined_vreg != use_vreg) { - graph[defined_vreg].insert(use_vreg); - graph[use_vreg].insert(defined_vreg); - } - } - } - // ================================================================= - } - - // --- 您的“规则二”:特殊指令内部的操作数之间也存在干扰 (保留) --- - if (auto store = dynamic_cast(inst)) { - Value* val_operand = store->getValue(); - Value* ptr_operand = store->getPointer(); - if (value_vreg_map.count(val_operand) && value_vreg_map.count(ptr_operand)) { - const std::string& val_vreg = value_vreg_map.at(val_operand); - const std::string& ptr_vreg = value_vreg_map.at(ptr_operand); - if (val_vreg != ptr_vreg) { - graph[val_vreg].insert(ptr_vreg); - graph[ptr_vreg].insert(val_vreg); - } - } - } else if (auto bin = dynamic_cast(inst)) { - Value* lhs_operand = bin->getLhs(); - Value* rhs_operand = bin->getRhs(); - if (value_vreg_map.count(lhs_operand) && value_vreg_map.count(rhs_operand)) { - const std::string& lhs_vreg = value_vreg_map.at(lhs_operand); - const std::string& rhs_vreg = value_vreg_map.at(rhs_operand); - if (lhs_vreg != rhs_vreg) { - graph[lhs_vreg].insert(rhs_vreg); - graph[rhs_vreg].insert(lhs_vreg); - } - } - } else if (auto call = dynamic_cast(inst)) { - std::vector arg_vregs; - for (auto arg_use : call->getArguments()) { - Value* arg_val = arg_use->getValue(); - if (value_vreg_map.count(arg_val)) { - arg_vregs.push_back(value_vreg_map.at(arg_val)); - } - } - for (size_t i = 0; i < arg_vregs.size(); ++i) { - for (size_t j = i + 1; j < arg_vregs.size(); ++j) { - graph[arg_vregs[i]].insert(arg_vregs[j]); - graph[arg_vregs[j]].insert(arg_vregs[i]); - } - } - } else if (auto memset = dynamic_cast(inst)) { - // 规则:MemsetInst 像一个函数调用,它会污染临时寄存器。 - // 因此,所有跨越这条指令的活跃变量(live_out), - // 都应该与这条指令的操作数(use)互相冲突。 - // 这会强制分配器将它们放入不同的寄存器中,或者安全地保存/恢复。 - - std::set use_set; - for (const auto& operand_use : memset->getOperands()) { - Value* operand = operand_use->getValue(); - if (value_vreg_map.count(operand)) { - use_set.insert(value_vreg_map.at(operand)); - } - } - - for (const auto& live_vreg : live_out_set) { - for (const auto& use_vreg : use_set) { - if (live_vreg != use_vreg) { - graph[live_vreg].insert(use_vreg); - graph[use_vreg].insert(live_vreg); - } - } - } - } - } - return graph; -} - -// 图着色(支持浮点寄存器) -void RISCv64CodeGen::color_graph(std::map& vreg_to_preg, - const std::map>& interference_graph) { - vreg_to_preg.clear(); - - // 分离整数和浮点寄存器池 - std::vector int_regs = { - PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, - PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, - PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, - PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, - PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, - PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, - PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11 - }; - std::vector float_regs = { - PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, - PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7, - PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F10, PhysicalReg::F11, - PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, - PhysicalReg::F16, PhysicalReg::F17, PhysicalReg::F18, PhysicalReg::F19, - PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22, PhysicalReg::F23, - PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27, - PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31 - }; - - // 确定虚拟寄存器类型(整数或浮点) - auto is_float_vreg = [&](const std::string& vreg) -> bool { - for (const auto& pair : value_vreg_map) { - if (pair.second == vreg) { - if (auto inst = dynamic_cast(pair.first)) { - if (inst->isUnary()) { - switch (inst->getKind()) { - case Instruction::kFNeg: - case Instruction::kFNot: - case Instruction::kFtoI: - case Instruction::kItoF: - case Instruction::kBitFtoI: - case Instruction::kBitItoF: - return true; // 浮点相关指令 - default: - return inst->getType()->isFloat(); - } - } - return inst->getType()->isFloat(); - } else if (auto constant = dynamic_cast(pair.first)) { - return constant->isFloat(); - } - } - } - return false; // 默认整数 - }; - - // 按度数排序虚拟寄存器 - std::vector> vreg_degrees; - for (const auto& entry : interference_graph) { - vreg_degrees.push_back({entry.first, (int)entry.second.size()}); - } - std::sort(vreg_degrees.begin(), vreg_degrees.end(), - [](const auto& a, const auto& b) { return a.second > b.second; }); - - for (const auto& vreg_deg_pair : vreg_degrees) { - const std::string& vreg = vreg_deg_pair.first; - std::set used_colors; - bool is_float = is_float_vreg(vreg); - - // 收集邻居使用的颜色 - if (interference_graph.count(vreg)) { - for (const auto& neighbor_vreg : interference_graph.at(vreg)) { - if (vreg_to_preg.count(neighbor_vreg)) { - used_colors.insert(vreg_to_preg.at(neighbor_vreg)); - } - } - } - - // 选择合适的寄存器池 - const auto& available_regs = is_float ? float_regs : int_regs; - - // 查找第一个可用的寄存器 - bool colored = false; - for (PhysicalReg preg : available_regs) { - if (used_colors.find(preg) == used_colors.end()) { - vreg_to_preg[vreg] = preg; - colored = true; - break; - } - } - - if (!colored) { - std::cerr << "警告: 无法为 " << vreg << " 分配" << (is_float ? "浮点" : "整数") << "寄存器,将溢出到栈。\n"; - // 溢出处理:在 stack_map 中分配栈空间 - // 这里假设每个溢出变量占用 4 字节 - // 注意:实际中需要区分整数和浮点溢出的存储指令(如 sw vs fsw) - } - } -} - -// 寄存器分配 -RISCv64CodeGen::RegAllocResult RISCv64CodeGen::register_allocation(Function* func) { - eliminate_phi(func); - vreg_counter = 0; - value_vreg_map.clear(); - - // 为所有产生值的指令和操作数分配虚拟寄存器 - for (const auto& bb_ptr : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - Instruction* inst = inst_ptr.get(); - if (!inst->getType()->isVoid() && !dynamic_cast(inst)) { - if (value_vreg_map.find(inst) == value_vreg_map.end()) { - value_vreg_map[inst] = "v" + std::to_string(vreg_counter++); - } - } - for (const auto& operand_use : inst->getOperands()) { - Value* operand = operand_use->getValue(); - if (dynamic_cast(operand) || dynamic_cast(operand)) { - if (value_vreg_map.find(operand) == value_vreg_map.end()) { - value_vreg_map[operand] = "v" + std::to_string(vreg_counter++); - } - } else if (auto op_inst = dynamic_cast(operand)) { - if (!op_inst->getType()->isVoid() && !dynamic_cast(operand)) { - if (value_vreg_map.find(operand) == value_vreg_map.end()) { - value_vreg_map[operand] = "v" + std::to_string(vreg_counter++); - } - } - } - } - } - } - - RegAllocResult alloc_result; - int current_stack_offset = 0; - std::set allocas_in_func; - - // 为 AllocaInst 计算栈空间并分配偏移量 - for (const auto& bb_ptr : func->getBasicBlocks()) { - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - if (auto alloca = dynamic_cast(inst_ptr.get())) { - allocas_in_func.insert(alloca); - } - } - } - for (auto alloca : allocas_in_func) { - int total_size = 4; - auto dims = alloca->getDims(); - if (!dims.empty()) { - int num_elements = 1; - for (const auto& dim_use : dims) { - Value* dim_value = dim_use->getValue(); - if (auto const_dim = dynamic_cast(dim_value)) { - num_elements *= const_dim->getInt(); - } else { - throw std::runtime_error("数组维度必须是编译时常量"); - } - } - total_size *= num_elements; - } - alloc_result.stack_map[alloca] = current_stack_offset; - current_stack_offset += total_size; - } - // 为保存的 ra 和 s0 (各8字节) 预留16字节空间 - alloc_result.stack_size = current_stack_offset + 16; - - // 活跃性分析 - LivenessResult liveness = liveness_analysis(func); - - // 构建干扰图 - std::map> interference_graph = build_interference_graph(liveness); - - // 图着色 - color_graph(alloc_result.vreg_to_preg, interference_graph); - - if (DEBUG) { - std::cerr << "=== 寄存器分配结果 (vreg_to_preg) ===\n"; - for (const auto& pair : alloc_result.vreg_to_preg) { - std::cerr << " " << pair.first << " -> " << reg_to_string(pair.second) << "\n"; - } - std::cerr << "=== 寄存器分配结果结束 ===\n\n"; - - std::cerr << "=== 活跃性分析结果 (live_in sets) ===\n"; - for (const auto& bb_ptr : func->getBasicBlocks()) { - std::cerr << "Basic Block: " << bb_ptr->getName() << "\n"; - for (const auto& inst_ptr : bb_ptr->getInstructions()) { - std::cerr << " Inst: " << inst_ptr->getKindString(); - if (!inst_ptr->getName().empty()) std::cerr << "(" << inst_ptr->getName() << ")"; - if (value_vreg_map.count(inst_ptr.get())) std::cerr << " (Def vreg: " << value_vreg_map.at(inst_ptr.get()) << ")"; - - std::cerr << " (Live In: "; - if (liveness.live_in.count(inst_ptr.get())) { - std::cerr << print_set(liveness.live_in.at(inst_ptr.get())); - } else { - std::cerr << "{}"; - } - std::cerr << ")\n"; - } - } - std::cerr << "=== 活跃性分析结果结束 ===\n\n"; - - std::cerr << "=== 干扰图 ===\n"; - for (const auto& pair : interference_graph) { - std::cerr << " " << pair.first << ": " << print_set(pair.second) << "\n"; - } - std::cerr << "=== 干扰图结束 ===\n\n"; - } - - return alloc_result; -} - -// Phi 消除 (简化版,将 Phi 的结果直接复制到每个前驱基本块的末尾) -void RISCv64CodeGen::eliminate_phi(Function* func) { - // 这是一个占位符。适当的 phi 消除将涉及 - // 在每个前驱基本块的末尾插入 `mov` 指令,用于每个 phi 操作数。 - // 对于给定的 IR 示例,没有 phi 节点,所以这可能不是严格必要的, - // 但如果前端生成 phi 节点,则有此阶段是好的做法。 - // 目前,我们假设没有生成 phi 节点或者它们已在前端处理。 -} - } // namespace sysy \ No newline at end of file diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp new file mode 100644 index 0000000..005ba96 --- /dev/null +++ b/src/RISCv64ISel.cpp @@ -0,0 +1,635 @@ +#include "RISCv64ISel.h" +#include +#include +#include +#include // For std::fabs +#include // For std::numeric_limits + +namespace sysy { + +// DAG节点定义 (内部实现) +struct RISCv64ISel::DAGNode { + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; + NodeKind kind; + Value* value = nullptr; + std::vector operands; + std::vector users; + DAGNode(NodeKind k) : kind(k) {} +}; + +RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {} + +// 为一个IR Value获取或分配一个新的虚拟寄存器 +unsigned RISCv64ISel::getVReg(Value* val) { + if (!val) { + throw std::runtime_error("Cannot get vreg for a null Value."); + } + if (vreg_map.find(val) == vreg_map.end()) { + if (vreg_counter == 0) { + vreg_counter = 1; // vreg 0 保留 + } + vreg_map[val] = vreg_counter++; + } + return vreg_map.at(val); +} + +// 主入口函数 +std::unique_ptr RISCv64ISel::runOnFunction(Function* func) { + F = func; + if (!F) return nullptr; + MFunc = std::make_unique(F, this); + vreg_map.clear(); + bb_map.clear(); + vreg_counter = 0; + local_label_counter = 0; + + select(); + + return std::move(MFunc); +} + +// 指令选择主流程 +void RISCv64ISel::select() { + for (const auto& bb_ptr : F->getBasicBlocks()) { + auto mbb = std::make_unique(bb_ptr->getName(), MFunc.get()); + bb_map[bb_ptr.get()] = mbb.get(); + MFunc->addBlock(std::move(mbb)); + } + + if (F->getEntryBlock()) { + for (auto* arg_alloca : F->getEntryBlock()->getArguments()) { + getVReg(arg_alloca); + } + } + + for (const auto& bb_ptr : F->getBasicBlocks()) { + selectBasicBlock(bb_ptr.get()); + } + + for (const auto& bb_ptr : F->getBasicBlocks()) { + CurMBB = bb_map.at(bb_ptr.get()); + for (auto succ : bb_ptr->getSuccessors()) { + CurMBB->successors.push_back(bb_map.at(succ)); + } + for (auto pred : bb_ptr->getPredecessors()) { + CurMBB->predecessors.push_back(bb_map.at(pred)); + } + } +} + +// 处理单个基本块 +void RISCv64ISel::selectBasicBlock(BasicBlock* bb) { + CurMBB = bb_map.at(bb); + auto dag = build_dag(bb); + + std::map value_to_node; + for(const auto& node : dag) { + if (node->value) { + value_to_node[node->value] = node.get(); + } + } + + std::set selected_nodes; + std::function select_recursive = + [&](DAGNode* node) { + if (!node || selected_nodes.count(node)) return; + for (auto operand : node->operands) { + select_recursive(operand); + } + selectNode(node); + selected_nodes.insert(node); + }; + + for (const auto& inst_ptr : bb->getInstructions()) { + DAGNode* node_to_select = nullptr; + if (value_to_node.count(inst_ptr.get())) { + node_to_select = value_to_node.at(inst_ptr.get()); + } else { + for(const auto& node : dag) { + if(node->value == inst_ptr.get()) { + node_to_select = node.get(); + break; + } + } + } + if(node_to_select) { + select_recursive(node_to_select); + } + } +} + +// 核心函数:为DAG节点选择并生成MachineInstr (忠实移植版) +void RISCv64ISel::selectNode(DAGNode* node) { + switch (node->kind) { + case DAGNode::CONSTANT: + case DAGNode::ALLOCA_ADDR: + if (node->value) getVReg(node->value); + break; + + case DAGNode::LOAD: { + auto dest_vreg = getVReg(node->value); + Value* ptr_val = node->operands[0]->value; + + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_LOAD); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(lw)); + } + break; + } + + case DAGNode::STORE: { + Value* val_to_store = node->operands[0]->value; + Value* ptr_val = node->operands[1]->value; + + if (auto val_const = dynamic_cast(val_to_store)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(getVReg(val_const))); + li->addOperand(std::make_unique(val_const->getInt())); + CurMBB->addInstruction(std::move(li)); + } + auto val_vreg = getVReg(val_to_store); + + if (auto alloca = dynamic_cast(ptr_val)) { + auto instr = std::make_unique(RVOpcodes::FRAME_STORE); + instr->addOperand(std::make_unique(val_vreg)); + instr->addOperand(std::make_unique(getVReg(alloca))); + CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { + auto addr_vreg = getNewVReg(); + auto la = std::make_unique(RVOpcodes::LA); + la->addOperand(std::make_unique(addr_vreg)); + la->addOperand(std::make_unique(global->getName())); + CurMBB->addInstruction(std::move(la)); + + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } else { + auto ptr_vreg = getVReg(ptr_val); + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(val_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(ptr_vreg), + std::make_unique(0) + )); + CurMBB->addInstruction(std::move(sw)); + } + break; + } + + case DAGNode::BINARY: { + auto bin = dynamic_cast(node->value); + Value* lhs = bin->getLhs(); + Value* rhs = bin->getRhs(); + + auto load_val_if_const = [&](Value* val) { + if (auto c = dynamic_cast(val)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(getVReg(c))); + li->addOperand(std::make_unique(c->getInt())); + CurMBB->addInstruction(std::move(li)); + } + }; + load_val_if_const(lhs); + load_val_if_const(rhs); + + auto dest_vreg = getVReg(bin); + auto lhs_vreg = getVReg(lhs); + auto rhs_vreg = getVReg(rhs); + + if (bin->getKind() == BinaryInst::kAdd) { + if (auto rhs_const = dynamic_cast(rhs)) { + if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) { + auto instr = std::make_unique(RVOpcodes::ADDIW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_const->getInt())); + CurMBB->addInstruction(std::move(instr)); + return; + } + } + } + + switch (bin->getKind()) { + case BinaryInst::kAdd: { + RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW; + auto instr = std::make_unique(opcode); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kSub: { + auto instr = std::make_unique(RVOpcodes::SUBW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kMul: { + auto instr = std::make_unique(RVOpcodes::MULW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case Instruction::kDiv: { + auto instr = std::make_unique(RVOpcodes::DIVW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case Instruction::kRem: { + auto instr = std::make_unique(RVOpcodes::REMW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpEQ: { + auto sub = std::make_unique(RVOpcodes::SUBW); + sub->addOperand(std::make_unique(dest_vreg)); + sub->addOperand(std::make_unique(lhs_vreg)); + sub->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(sub)); + + auto seqz = std::make_unique(RVOpcodes::SEQZ); + seqz->addOperand(std::make_unique(dest_vreg)); + seqz->addOperand(std::make_unique(dest_vreg)); + CurMBB->addInstruction(std::move(seqz)); + break; + } + case BinaryInst::kICmpNE: { + auto sub = std::make_unique(RVOpcodes::SUBW); + sub->addOperand(std::make_unique(dest_vreg)); + sub->addOperand(std::make_unique(lhs_vreg)); + sub->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(sub)); + + auto snez = std::make_unique(RVOpcodes::SNEZ); + snez->addOperand(std::make_unique(dest_vreg)); + snez->addOperand(std::make_unique(dest_vreg)); + CurMBB->addInstruction(std::move(snez)); + break; + } + case BinaryInst::kICmpLT: { + auto instr = std::make_unique(RVOpcodes::SLT); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpGT: { + auto instr = std::make_unique(RVOpcodes::SLT); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(rhs_vreg)); + instr->addOperand(std::make_unique(lhs_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case BinaryInst::kICmpLE: { + auto slt = std::make_unique(RVOpcodes::SLT); + slt->addOperand(std::make_unique(dest_vreg)); + slt->addOperand(std::make_unique(rhs_vreg)); + slt->addOperand(std::make_unique(lhs_vreg)); + CurMBB->addInstruction(std::move(slt)); + + auto xori = std::make_unique(RVOpcodes::XORI); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(1)); + CurMBB->addInstruction(std::move(xori)); + break; + } + case BinaryInst::kICmpGE: { + auto slt = std::make_unique(RVOpcodes::SLT); + slt->addOperand(std::make_unique(dest_vreg)); + slt->addOperand(std::make_unique(lhs_vreg)); + slt->addOperand(std::make_unique(rhs_vreg)); + CurMBB->addInstruction(std::move(slt)); + + auto xori = std::make_unique(RVOpcodes::XORI); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(dest_vreg)); + xori->addOperand(std::make_unique(1)); + CurMBB->addInstruction(std::move(xori)); + break; + } + default: + throw std::runtime_error("Unsupported binary instruction in ISel"); + } + break; + } + + case DAGNode::UNARY: { + auto unary = dynamic_cast(node->value); + auto dest_vreg = getVReg(unary); + auto src_vreg = getVReg(unary->getOperand()); + + switch (unary->getKind()) { + case UnaryInst::kNeg: { + auto instr = std::make_unique(RVOpcodes::SUBW); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(PhysicalReg::ZERO)); + instr->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + case UnaryInst::kNot: { + auto instr = std::make_unique(RVOpcodes::SEQZ); + instr->addOperand(std::make_unique(dest_vreg)); + instr->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(instr)); + break; + } + default: + throw std::runtime_error("Unsupported unary instruction in ISel"); + } + break; + } + + case DAGNode::CALL: { + auto call = dynamic_cast(node->value); + for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { + DAGNode* arg_node = node->operands[i]; + auto arg_preg = static_cast(static_cast(PhysicalReg::A0) + i); + + if (arg_node->kind == DAGNode::CONSTANT) { + if (auto const_val = dynamic_cast(arg_node->value)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(arg_preg)); + li->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li)); + } + } else { + auto src_vreg = getVReg(arg_node->value); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(arg_preg)); + mv->addOperand(std::make_unique(src_vreg)); + CurMBB->addInstruction(std::move(mv)); + } + } + + auto call_instr = std::make_unique(RVOpcodes::CALL); + call_instr->addOperand(std::make_unique(call->getCallee()->getName())); + CurMBB->addInstruction(std::move(call_instr)); + + if (!call->getType()->isVoid()) { + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(getVReg(call))); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + CurMBB->addInstruction(std::move(mv_instr)); + } + break; + } + + case DAGNode::RETURN: { + auto ret_inst_ir = dynamic_cast(node->value); + if (ret_inst_ir && ret_inst_ir->hasReturnValue()) { + Value* ret_val = ret_inst_ir->getReturnValue(); + if (auto const_val = dynamic_cast(ret_val)) { + auto li_instr = std::make_unique(RVOpcodes::LI); + li_instr->addOperand(std::make_unique(PhysicalReg::A0)); + li_instr->addOperand(std::make_unique(const_val->getInt())); + CurMBB->addInstruction(std::move(li_instr)); + } else { + auto mv_instr = std::make_unique(RVOpcodes::MV); + mv_instr->addOperand(std::make_unique(PhysicalReg::A0)); + mv_instr->addOperand(std::make_unique(getVReg(ret_val))); + CurMBB->addInstruction(std::move(mv_instr)); + } + } + auto ret_mi = std::make_unique(RVOpcodes::RET); + CurMBB->addInstruction(std::move(ret_mi)); + break; + } + + case DAGNode::BRANCH: { + if (auto cond_br = dynamic_cast(node->value)) { + auto br_instr = std::make_unique(RVOpcodes::BNE); + br_instr->addOperand(std::make_unique(getVReg(cond_br->getCondition()))); + br_instr->addOperand(std::make_unique(PhysicalReg::ZERO)); + br_instr->addOperand(std::make_unique(cond_br->getThenBlock()->getName())); + CurMBB->addInstruction(std::move(br_instr)); + } else if (auto uncond_br = dynamic_cast(node->value)) { + auto j_instr = std::make_unique(RVOpcodes::J); + j_instr->addOperand(std::make_unique(uncond_br->getBlock()->getName())); + CurMBB->addInstruction(std::move(j_instr)); + } + break; + } + + case DAGNode::MEMSET: { + auto memset = dynamic_cast(node->value); + auto r_dest_addr = getVReg(memset->getPointer()); + auto r_num_bytes = getVReg(memset->getSize()); + auto r_value_byte = getVReg(memset->getValue()); + auto r_counter = getNewVReg(); + auto r_end_addr = getNewVReg(); + auto r_current_addr = getNewVReg(); + auto r_temp_val = getNewVReg(); + + auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rd)); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(rs2)); + CurMBB->addInstruction(std::move(i)); + }; + auto addi_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, int64_t imm) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rd)); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(imm)); + CurMBB->addInstruction(std::move(i)); + }; + auto store_instr = [&](RVOpcodes op, unsigned src, unsigned base, int64_t off) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(src)); + i->addOperand(std::make_unique(std::make_unique(base), std::make_unique(off))); + CurMBB->addInstruction(std::move(i)); + }; + auto branch_instr = [&](RVOpcodes op, unsigned rs1, unsigned rs2, const std::string& label) { + auto i = std::make_unique(op); + i->addOperand(std::make_unique(rs1)); + i->addOperand(std::make_unique(rs2)); + i->addOperand(std::make_unique(label)); + CurMBB->addInstruction(std::move(i)); + }; + auto jump_instr = [&](const std::string& label) { + auto i = std::make_unique(RVOpcodes::J); + i->addOperand(std::make_unique(label)); + CurMBB->addInstruction(std::move(i)); + }; + auto label_instr = [&](const std::string& name) { + auto i = std::make_unique(RVOpcodes::LABEL); + i->addOperand(std::make_unique(name)); + CurMBB->addInstruction(std::move(i)); + }; + + int unique_id = this->local_label_counter++; + std::string loop_start_label = MFunc->getName() + "_memset_loop_start_" + std::to_string(unique_id); + std::string loop_end_label = MFunc->getName() + "_memset_loop_end_" + std::to_string(unique_id); + std::string remainder_label = MFunc->getName() + "_memset_remainder_" + std::to_string(unique_id); + std::string done_label = MFunc->getName() + "_memset_done_" + std::to_string(unique_id); + + addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 16); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32); + add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); + add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(r_current_addr)); + mv->addOperand(std::make_unique(r_dest_addr)); + CurMBB->addInstruction(std::move(mv)); + addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8); + add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter); + label_instr(loop_start_label); + branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label); + store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0); + addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8); + jump_instr(loop_start_label); + label_instr(loop_end_label); + label_instr(remainder_label); + branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label); + store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0); + addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 1); + jump_instr(remainder_label); + label_instr(done_label); + break; + } + + default: + throw std::runtime_error("Unsupported DAGNode kind in ISel"); + } +} + +// 以下是忠实移植的DAG构建函数 +RISCv64ISel::DAGNode* RISCv64ISel::create_node(int kind_int, Value* val, std::map& value_to_node, std::vector>& nodes_storage) { + auto kind = static_cast(kind_int); + if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) { + return value_to_node[val]; + } + auto node = std::make_unique(kind); + node->value = val; + DAGNode* raw_node_ptr = node.get(); + nodes_storage.push_back(std::move(node)); + if (val && !val->getType()->isVoid() && (dynamic_cast(val) || dynamic_cast(val))) { + value_to_node[val] = raw_node_ptr; + } + return raw_node_ptr; +} + +RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map& value_to_node, std::vector>& nodes_storage) { + if (value_to_node.count(val_ir)) { + return value_to_node[val_ir]; + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage); + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage); + } else if (dynamic_cast(val_ir)) { + return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage); + } + return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); +} + +std::vector> RISCv64ISel::build_dag(BasicBlock* bb) { + std::vector> nodes_storage; + std::map value_to_node; + + for (const auto& inst_ptr : bb->getInstructions()) { + Instruction* inst = inst_ptr.get(); + if (auto alloca = dynamic_cast(inst)) { + create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage); + } else if (auto store = dynamic_cast(inst)) { + auto store_node = create_node(DAGNode::STORE, store, value_to_node, nodes_storage); + store_node->operands.push_back(get_operand_node(store->getValue(), value_to_node, nodes_storage)); + store_node->operands.push_back(get_operand_node(store->getPointer(), value_to_node, nodes_storage)); + } else if (auto memset = dynamic_cast(inst)) { + auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage); + memset_node->operands.push_back(get_operand_node(memset->getPointer(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage)); + memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage)); + } else if (auto load = dynamic_cast(inst)) { + auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); + load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); + } else if (auto bin = dynamic_cast(inst)) { + if(value_to_node.count(bin)) continue; + if (bin->getKind() == BinaryInst::kSub) { + if (auto const_lhs = dynamic_cast(bin->getLhs())) { + if (const_lhs->getInt() == 0) { + auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage); + unary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); + continue; + } + } + } + auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); + bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); + bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); + } else if (auto un = dynamic_cast(inst)) { + if(value_to_node.count(un)) continue; + auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage); + unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage)); + } else if (auto call = dynamic_cast(inst)) { + if(value_to_node.count(call)) continue; + auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); + for (auto arg : call->getArguments()) { + call_node->operands.push_back(get_operand_node(arg->getValue(), value_to_node, nodes_storage)); + } + } else if (auto ret = dynamic_cast(inst)) { + auto ret_node = create_node(DAGNode::RETURN, ret, value_to_node, nodes_storage); + if (ret->hasReturnValue()) { + ret_node->operands.push_back(get_operand_node(ret->getReturnValue(), value_to_node, nodes_storage)); + } + } else if (auto cond_br = dynamic_cast(inst)) { + auto br_node = create_node(DAGNode::BRANCH, cond_br, value_to_node, nodes_storage); + br_node->operands.push_back(get_operand_node(cond_br->getCondition(), value_to_node, nodes_storage)); + } else if (auto uncond_br = dynamic_cast(inst)) { + create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage); + } + } + return nodes_storage; +} + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64Passes.cpp b/src/RISCv64Passes.cpp new file mode 100644 index 0000000..40aff21 --- /dev/null +++ b/src/RISCv64Passes.cpp @@ -0,0 +1,8 @@ +// RISCv64Passes.cpp +#include "RISCv64Passes.h" + +namespace sysy { + +// 此处为未来优化Pass的实现 + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp new file mode 100644 index 0000000..2695f3b --- /dev/null +++ b/src/RISCv64RegAlloc.cpp @@ -0,0 +1,322 @@ +#include "RISCv64RegAlloc.h" +#include "RISCv64ISel.h" +#include +#include + +namespace sysy { + +RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { + allocable_int_regs = { + PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, + PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, + PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, + PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, + PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, + PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, + PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11, + }; +} + +void RISCv64RegAlloc::run() { + eliminateFrameIndices(); + analyzeLiveness(); + buildInterferenceGraph(); + colorGraph(); + rewriteFunction(); +} + +void RISCv64RegAlloc::eliminateFrameIndices() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int current_offset = 0; + Function* F = MFunc->getFunc(); + RISCv64ISel* isel = MFunc->getISel(); + + for (auto& bb : F->getBasicBlocks()) { + for (auto& inst : bb->getInstructions()) { + if (auto alloca = dynamic_cast(inst.get())) { + int size = 4; + if (!alloca->getDims().empty()) { + int num_elements = 1; + for (const auto& dim_use : alloca->getDims()) { + if (auto const_dim = dynamic_cast(dim_use->getValue())) { + num_elements *= const_dim->getInt(); + } + } + size *= num_elements; + } + current_offset += size; + unsigned alloca_vreg = isel->getVReg(alloca); + frame_info.alloca_offsets[alloca_vreg] = -current_offset; + } + } + } + frame_info.locals_size = current_offset; + + for (auto& mbb : MFunc->getBlocks()) { + std::vector> new_instructions; + for (auto& instr_ptr : mbb->getInstructions()) { + if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) { + auto& operands = instr_ptr->getOperands(); + unsigned dest_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto lw = std::make_unique(RVOpcodes::LW); + lw->addOperand(std::make_unique(dest_vreg)); + lw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(lw)); + + } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) { + auto& operands = instr_ptr->getOperands(); + unsigned src_vreg = static_cast(operands[0].get())->getVRegNum(); + unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); + int offset = frame_info.alloca_offsets.at(alloca_vreg); + auto addr_vreg = isel->getNewVReg(); + + auto addi = std::make_unique(RVOpcodes::ADDI); + addi->addOperand(std::make_unique(addr_vreg)); + addi->addOperand(std::make_unique(PhysicalReg::S0)); + addi->addOperand(std::make_unique(offset)); + new_instructions.push_back(std::move(addi)); + + auto sw = std::make_unique(RVOpcodes::SW); + sw->addOperand(std::make_unique(src_vreg)); + sw->addOperand(std::make_unique( + std::make_unique(addr_vreg), + std::make_unique(0))); + new_instructions.push_back(std::move(sw)); + } else { + new_instructions.push_back(std::move(instr_ptr)); + } + } + mbb->getInstructions() = std::move(new_instructions); + } +} + +void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) { + bool is_def = true; + auto opcode = instr->getOpcode(); + + // 预定义def和use规则 + if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || + opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE || + opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE || + opcode == RVOpcodes::RET || opcode == RVOpcodes::J) { + is_def = false; + } + if (opcode == RVOpcodes::CALL) { + // CALL会杀死所有调用者保存寄存器,这是一个简化处理 + // 同时也使用了传入a0-a7的参数 + } + + for (const auto& op : instr->getOperands()) { + if (op->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op.get()); + if (reg_op->isVirtual()) { + if (is_def) { + def.insert(reg_op->getVRegNum()); + is_def = false; + } else { + use.insert(reg_op->getVRegNum()); + } + } + } else if (op->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op.get()); + if (mem_op->getBase()->isVirtual()) { + use.insert(mem_op->getBase()->getVRegNum()); + } + } + } +} + +void RISCv64RegAlloc::analyzeLiveness() { + bool changed = true; + while (changed) { + changed = false; + for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { + auto& mbb = *it; + LiveSet live_out; + for (auto succ : mbb->successors) { + if (!succ->getInstructions().empty()) { + auto first_instr = succ->getInstructions().front().get(); + if (live_in_map.count(first_instr)) { + live_out.insert(live_in_map.at(first_instr).begin(), live_in_map.at(first_instr).end()); + } + } + } + + for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) { + MachineInstr* instr = instr_it->get(); + LiveSet old_live_in = live_in_map[instr]; + live_out_map[instr] = live_out; + + LiveSet use, def; + getInstrUseDef(instr, use, def); + + LiveSet live_in = use; + LiveSet diff = live_out; + for (auto vreg : def) { + diff.erase(vreg); + } + live_in.insert(diff.begin(), diff.end()); + live_in_map[instr] = live_in; + + live_out = live_in; + + if (live_in_map[instr] != old_live_in) { + changed = true; + } + } + } + } +} + +void RISCv64RegAlloc::buildInterferenceGraph() { + std::set all_vregs; + for (auto& mbb : MFunc->getBlocks()) { + for(auto& instr : mbb->getInstructions()) { + LiveSet use, def; + getInstrUseDef(instr.get(), use, def); + for(auto u : use) all_vregs.insert(u); + for(auto d : def) all_vregs.insert(d); + } + } + + for (auto vreg : all_vregs) { interference_graph[vreg] = {}; } + + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr : mbb->getInstructions()) { + LiveSet def, use; + getInstrUseDef(instr.get(), use, def); + const LiveSet& live_out = live_out_map.at(instr.get()); + + for (unsigned d : def) { + for (unsigned l : live_out) { + if (d != l) { + interference_graph[d].insert(l); + interference_graph[l].insert(d); + } + } + } + } + } +} + +void RISCv64RegAlloc::colorGraph() { + std::vector sorted_vregs; + for (auto const& [vreg, neighbors] : interference_graph) { + sorted_vregs.push_back(vreg); + } + + std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) { + return interference_graph[a].size() > interference_graph[b].size(); + }); + + for (unsigned vreg : sorted_vregs) { + std::set used_colors; + for (unsigned neighbor : interference_graph.at(vreg)) { + if (color_map.count(neighbor)) { + used_colors.insert(color_map.at(neighbor)); + } + } + + bool colored = false; + for (PhysicalReg preg : allocable_int_regs) { + if (used_colors.find(preg) == used_colors.end()) { + color_map[vreg] = preg; + colored = true; + break; + } + } + if (!colored) { + spilled_vregs.insert(vreg); + } + } +} + +void RISCv64RegAlloc::rewriteFunction() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int current_offset = frame_info.locals_size; + for (unsigned vreg : spilled_vregs) { + current_offset += 4; + frame_info.spill_offsets[vreg] = -current_offset; + } + frame_info.spill_size = current_offset - frame_info.locals_size; + + for (auto& mbb : MFunc->getBlocks()) { + std::vector> new_instructions; + for (auto& instr_ptr : mbb->getInstructions()) { + LiveSet use, def; + getInstrUseDef(instr_ptr.get(), use, def); + + for (unsigned vreg : use) { + if (spilled_vregs.count(vreg)) { + int offset = frame_info.spill_offsets.at(vreg); + auto load = std::make_unique(RVOpcodes::LW); + load->addOperand(std::make_unique(vreg)); + load->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), + std::make_unique(offset) + )); + new_instructions.push_back(std::move(load)); + } + } + + new_instructions.push_back(std::move(instr_ptr)); + + for (unsigned vreg : def) { + if (spilled_vregs.count(vreg)) { + int offset = frame_info.spill_offsets.at(vreg); + auto store = std::make_unique(RVOpcodes::SW); + store->addOperand(std::make_unique(vreg)); + store->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), + std::make_unique(offset) + )); + new_instructions.push_back(std::move(store)); + } + } + } + mbb->getInstructions() = std::move(new_instructions); + } + + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr_ptr : mbb->getInstructions()) { + for (auto& op_ptr : instr_ptr->getOperands()) { + if(op_ptr->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op_ptr.get()); + if (reg_op->isVirtual()) { + unsigned vreg = reg_op->getVRegNum(); + if (color_map.count(vreg)) { + reg_op->setPReg(color_map.at(vreg)); + } else if (spilled_vregs.count(vreg)) { + reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6 + } + } + } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op_ptr.get()); + auto base_reg_op = mem_op->getBase(); + if(base_reg_op->isVirtual()){ + unsigned vreg = base_reg_op->getVRegNum(); + if(color_map.count(vreg)) { + base_reg_op->setPReg(color_map.at(vreg)); + } else if (spilled_vregs.count(vreg)) { + base_reg_op->setPReg(PhysicalReg::T6); + } + } + } + } + } + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/include/RISCv64AsmPrinter.h b/src/include/RISCv64AsmPrinter.h new file mode 100644 index 0000000..3ea71f6 --- /dev/null +++ b/src/include/RISCv64AsmPrinter.h @@ -0,0 +1,32 @@ +#ifndef RISCV64_ASMPRINTER_H +#define RISCV64_ASMPRINTER_H + +#include "RISCv64LLIR.h" +#include + +namespace sysy { + +class RISCv64AsmPrinter { +public: + RISCv64AsmPrinter(MachineFunction* mfunc); + // 主入口 + void run(std::ostream& os); + +private: + // 打印各个部分 + void printPrologue(); + void printEpilogue(); + void printBasicBlock(MachineBasicBlock* mbb); + void printInstruction(MachineInstr* instr); + + // 辅助函数 + std::string regToString(PhysicalReg reg); + void printOperand(MachineOperand* op); + + MachineFunction* MFunc; + std::ostream* OS; +}; + +} // namespace sysy + +#endif // RISCV64_ASMPRINTER_H \ No newline at end of file diff --git a/src/include/RISCv64Backend.h b/src/include/RISCv64Backend.h index 429aba2..33f7831 100644 --- a/src/include/RISCv64Backend.h +++ b/src/include/RISCv64Backend.h @@ -3,128 +3,23 @@ #include "IR.h" #include -#include -#include -#include -#include -#include -#include // For std::function - -extern int DEBUG; -extern int DEEPDEBUG; - - namespace sysy { -// 为活跃性分析的结果定义一个结构体,以同时持有 live_in 和 live_out 集合 -struct LivenessResult { - std::map> live_in; - std::map> live_out; -}; - +// RISCv64CodeGen 现在是一个高层驱动器 class RISCv64CodeGen { public: - enum class PhysicalReg { - ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, - F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 - }; - - // Move DAGNode and RegAllocResult to public section - struct DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; - NodeKind kind; - Value* value = nullptr; // For IR Value - std::string inst; // Generated RISC-V instruction(s) for this node - std::string result_vreg; // Virtual register assigned to this node's result - std::vector operands; - std::vector users; // For debugging and potentially optimizations - DAGNode(NodeKind k) : kind(k) {} - - // Debugging / helper - std::string getNodeKindString() const { - switch (kind) { - case CONSTANT: return "CONSTANT"; - case LOAD: return "LOAD"; - case STORE: return "STORE"; - case BINARY: return "BINARY"; - case CALL: return "CALL"; - case RETURN: return "RETURN"; - case BRANCH: return "BRANCH"; - case ALLOCA_ADDR: return "ALLOCA_ADDR"; - case UNARY: return "UNARY"; - case MEMSET: return "MEMSET"; - default: return "UNKNOWN"; - } - } - }; - - struct RegAllocResult { - std::map vreg_to_preg; // Virtual register to Physical Register mapping - std::map stack_map; // Value (AllocaInst) to stack offset - int stack_size = 0; // Total stack frame size for locals and spills - }; - RISCv64CodeGen(Module* mod) : module(mod) {} - + // 唯一的公共入口点 std::string code_gen(); - std::string module_gen(); - std::string function_gen(Function* func); - // 修改 basicBlock_gen 的声明,添加 int block_idx 参数 - std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx); - - // DAG related - std::vector> build_dag(BasicBlock* bb); - void select_instructions(DAGNode* node, const RegAllocResult& alloc); - // 改变 emit_instructions 的参数,使其可以直接添加汇编指令到 main ss - void emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set& emitted_nodes); - - // Register Allocation related - LivenessResult liveness_analysis(Function* func); - std::map> build_interference_graph(const LivenessResult& liveness); - void color_graph(std::map& vreg_to_preg, - const std::map>& interference_graph); - RegAllocResult register_allocation(Function* func); - void eliminate_phi(Function* func); // Phi elimination is typically done before DAG building - - // Utility - std::string reg_to_string(PhysicalReg reg); - void print_dag(const std::vector>& dag, const std::string& bb_name); private: - static const std::vector allocable_regs; - std::map value_vreg_map; // Maps IR Value* to its virtual register name + // 模块级代码生成 + std::string module_gen(); + // 函数级代码生成 (实现新的流水线) + std::string function_gen(Function* func); + Module* module; - int vreg_counter = 0; // Counter for unique virtual register names - int alloca_offset_counter = 0; // Counter for alloca offsets - - // 新增一个成员变量来存储当前函数的所有 DAGNode,以确保其生命周期贯穿整个函数代码生成 - // 这样可以在多个 BasicBlock_gen 调用中访问到完整的 DAG 节点 - std::vector> current_function_dag_nodes; - - // 为空标签定义一个伪名称前缀,加上块索引以确保唯一性 - const std::string ENTRY_BLOCK_PSEUDO_NAME = "entry_block_"; - - int local_label_counter = 0; // 用于生成唯一的本地标签 (如 memset 循环, 匿名块跳转等) - - // !!! 修改:get_operand_node 辅助函数现在需要传入 value_to_node 和 nodes_storage 的引用 - // 因为它们是 build_dag 局部管理的 - DAGNode* get_operand_node( - Value* val_ir, - std::map& value_to_node, - std::vector>& nodes_storage - ); - - // !!! 新增:create_node 辅助函数也需要传入 value_to_node 和 nodes_storage 的引用 - // 并且它应该不再是 lambda,而是一个真正的成员函数 - DAGNode* create_node( - DAGNode::NodeKind kind, - Value* val, - std::map& value_to_node, - std::vector>& nodes_storage - ); - - std::vector> temp_instructions_storage; // 用于存储 build_dag 中创建的临时 BinaryInst }; } // namespace sysy diff --git a/src/include/RISCv64ISel.h b/src/include/RISCv64ISel.h new file mode 100644 index 0000000..795b2b8 --- /dev/null +++ b/src/include/RISCv64ISel.h @@ -0,0 +1,49 @@ +#ifndef RISCV64_ISEL_H +#define RISCV64_ISEL_H + +#include "RISCv64LLIR.h" + +namespace sysy { + +class RISCv64ISel { +public: + RISCv64ISel(); + // 模块主入口:将一个高层IR函数转换为底层LLIR函数 + std::unique_ptr runOnFunction(Function* func); + + // 公开接口,以便后续模块(如RegAlloc)可以查询或创建vreg + unsigned getVReg(Value* val); + unsigned getNewVReg() { return vreg_counter++; } + +private: + // DAG节点定义,作为ISel的内部实现细节 + struct DAGNode; + + // 指令选择主流程 + void select(); + // 为单个基本块生成指令 + void selectBasicBlock(BasicBlock* bb); + // 核心函数:为DAG节点选择并生成MachineInstr + void selectNode(DAGNode* node); + + // DAG 构建相关函数 (从原RISCv64Backend迁移) + std::vector> build_dag(BasicBlock* bb); + DAGNode* get_operand_node(Value* val_ir, std::map&, std::vector>&); + DAGNode* create_node(int kind, Value* val, std::map&, std::vector>&); + + // 状态 + Function* F; // 当前处理的高层IR函数 + std::unique_ptr MFunc; // 正在构建的底层LLIR函数 + MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块 + + // 映射关系 + std::map vreg_map; + std::map bb_map; + + unsigned vreg_counter; + int local_label_counter; +}; + +} // namespace sysy + +#endif // RISCV64_ISEL_H \ No newline at end of file diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h new file mode 100644 index 0000000..6310741 --- /dev/null +++ b/src/include/RISCv64LLIR.h @@ -0,0 +1,200 @@ +#ifndef RISCV64_LLIR_H +#define RISCV64_LLIR_H + +#include "IR.h" // 确保包含了您自己的IR头文件 +#include +#include +#include +#include +#include + +// 前向声明,避免循环引用 +namespace sysy { +class Function; +class RISCv64ISel; +} + +namespace sysy { + +// 物理寄存器定义 +enum class PhysicalReg { + ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, + F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 +}; + +// RISC-V 指令操作码枚举 +enum class RVOpcodes { + // 算术指令 + ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW, + // 逻辑指令 + XOR, XORI, OR, ORI, AND, ANDI, + // 移位指令 + SLL, SLLI, SLLW, SLLIW, SRL, SRLI, SRLW, SRLIW, SRA, SRAI, SRAW, SRAIW, + // 比较指令 + SLT, SLTI, SLTU, SLTIU, + // 内存访问指令 + LW, LH, LB, LWU, LHU, LBU, SW, SH, SB, LD, SD, + // 控制流指令 + J, JAL, JALR, RET, + BEQ, BNE, BLT, BGE, BLTU, BGEU, + // 伪指令 + LI, LA, MV, NEG, NEGW, SEQZ, SNEZ, + // 函数调用 + CALL, + // 特殊标记,非指令 + LABEL, + // 新增伪指令,用于解耦栈帧处理 + FRAME_LOAD, // 从栈帧加载 (AllocaInst) + FRAME_STORE, // 保存到栈帧 (AllocaInst) +}; + +class MachineOperand; +class RegOperand; +class ImmOperand; +class LabelOperand; +class MemOperand; +class MachineInstr; +class MachineBasicBlock; +class MachineFunction; + +// 操作数基类 +class MachineOperand { +public: + enum OperandKind { KIND_REG, KIND_IMM, KIND_LABEL, KIND_MEM }; + MachineOperand(OperandKind kind) : kind(kind) {} + virtual ~MachineOperand() = default; + OperandKind getKind() const { return kind; } +private: + OperandKind kind; +}; + +// 寄存器操作数 +class RegOperand : public MachineOperand { +public: + // 构造虚拟寄存器 + RegOperand(unsigned vreg_num) + : MachineOperand(KIND_REG), vreg_num(vreg_num), is_virtual(true) {} + + // 构造物理寄存器 + RegOperand(PhysicalReg preg) + : MachineOperand(KIND_REG), preg(preg), is_virtual(false) {} + + bool isVirtual() const { return is_virtual; } + unsigned getVRegNum() const { return vreg_num; } + PhysicalReg getPReg() const { return preg; } + + void setPReg(PhysicalReg new_preg) { + preg = new_preg; + is_virtual = false; + } +private: + unsigned vreg_num = 0; + PhysicalReg preg = PhysicalReg::ZERO; + bool is_virtual; +}; + +// 立即数操作数 +class ImmOperand : public MachineOperand { +public: + ImmOperand(int64_t value) : MachineOperand(KIND_IMM), value(value) {} + int64_t getValue() const { return value; } +private: + int64_t value; +}; + +// 标签操作数 +class LabelOperand : public MachineOperand { +public: + LabelOperand(const std::string& name) : MachineOperand(KIND_LABEL), name(name) {} + const std::string& getName() const { return name; } +private: + std::string name; +}; + +// 内存操作数, 表示 offset(base_reg) +class MemOperand : public MachineOperand { +public: + MemOperand(std::unique_ptr base, std::unique_ptr offset) + : MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {} + RegOperand* getBase() const { return base.get(); } + ImmOperand* getOffset() const { return offset.get(); } +private: + std::unique_ptr base; + std::unique_ptr offset; +}; + +// 机器指令 +class MachineInstr { +public: + MachineInstr(RVOpcodes opcode) : opcode(opcode) {} + + RVOpcodes getOpcode() const { return opcode; } + const std::vector>& getOperands() const { return operands; } + std::vector>& getOperands() { return operands; } + + void addOperand(std::unique_ptr operand) { + operands.push_back(std::move(operand)); + } +private: + RVOpcodes opcode; + std::vector> operands; +}; + +// 机器基本块 +class MachineBasicBlock { +public: + MachineBasicBlock(const std::string& name, MachineFunction* parent) + : name(name), parent(parent) {} + + const std::string& getName() const { return name; } + MachineFunction* getParent() const { return parent; } + const std::vector>& getInstructions() const { return instructions; } + std::vector>& getInstructions() { return instructions; } + + void addInstruction(std::unique_ptr instr) { + instructions.push_back(std::move(instr)); + } + + std::vector successors; + std::vector predecessors; +private: + std::string name; + std::vector> instructions; + MachineFunction* parent; +}; + +// 栈帧信息 +struct StackFrameInfo { + int locals_size = 0; // 仅为AllocaInst分配的大小 + int spill_size = 0; // 仅为溢出分配的大小 + int total_size = 0; // 总大小 + std::map alloca_offsets; // + std::map spill_offsets; // <溢出vreg, 栈偏移> +}; + +// 机器函数 +class MachineFunction { +public: + MachineFunction(Function* func, RISCv64ISel* isel) : F(func), name(func->getName()), isel(isel) {} + + Function* getFunc() const { return F; } + RISCv64ISel* getISel() const { return isel; } + const std::string& getName() const { return name; } + StackFrameInfo& getFrameInfo() { return frame_info; } + const std::vector>& getBlocks() const { return blocks; } + std::vector>& getBlocks() { return blocks; } + + void addBlock(std::unique_ptr block) { + blocks.push_back(std::move(block)); + } +private: + Function* F; + RISCv64ISel* isel; // 指向创建它的ISel,用于获取vreg映射等信息 + std::string name; + std::vector> blocks; + StackFrameInfo frame_info; +}; + +} // namespace sysy + +#endif // RISCV64_LLIR_H \ No newline at end of file diff --git a/src/include/RISCv64Passes.h b/src/include/RISCv64Passes.h new file mode 100644 index 0000000..3a4bcd1 --- /dev/null +++ b/src/include/RISCv64Passes.h @@ -0,0 +1,18 @@ +// RISCv64Passes.h +#ifndef RISCV64_PASSES_H +#define RISCV64_PASSES_H + +#include "RISCv64LLIR.h" + +namespace sysy { + +// 此处为未来优化Pass的基类或独立类定义 +// 例如: +// class PeepholeOptimizer { +// public: +// void runOnMachineFunction(MachineFunction* mfunc); +// }; + +} // namespace sysy + +#endif // RISCV64_PASSES_H \ No newline at end of file diff --git a/src/include/RISCv64RegAlloc.h b/src/include/RISCv64RegAlloc.h new file mode 100644 index 0000000..c786bde --- /dev/null +++ b/src/include/RISCv64RegAlloc.h @@ -0,0 +1,56 @@ +#ifndef RISCV64_REGALLOC_H +#define RISCV64_REGALLOC_H + +#include "RISCv64LLIR.h" + +namespace sysy { + +class RISCv64RegAlloc { +public: + RISCv64RegAlloc(MachineFunction* mfunc); + + // 模块主入口 + void run(); + +private: + using LiveSet = std::set; // 活跃虚拟寄存器集合 + using InterferenceGraph = std::map>; + + // 栈帧管理 + void eliminateFrameIndices(); + + // 活跃性分析 + void analyzeLiveness(); + + // 构建干扰图 + void buildInterferenceGraph(); + + // 图着色分配寄存器 + void colorGraph(); + + // 重写函数,替换vreg并插入溢出代码 + void rewriteFunction(); + + // 辅助函数,获取指令的Use/Def集合 + void getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def); + + MachineFunction* MFunc; + + // 活跃性分析结果 + std::map live_in_map; + std::map live_out_map; + + // 干扰图 + InterferenceGraph interference_graph; + + // 图着色结果 + std::map color_map; // vreg -> preg + std::set spilled_vregs; // 被溢出的vreg集合 + + // 可用的物理寄存器池 + std::vector allocable_int_regs; +}; + +} // namespace sysy + +#endif // RISCV64_REGALLOC_H \ No newline at end of file