diff --git a/src/backend/RISCv64/CMakeLists.txt b/src/backend/RISCv64/CMakeLists.txt index 4330e40..c86645e 100644 --- a/src/backend/RISCv64/CMakeLists.txt +++ b/src/backend/RISCv64/CMakeLists.txt @@ -5,6 +5,7 @@ add_library(riscv64_backend_lib STATIC RISCv64ISel.cpp RISCv64LLIR.cpp RISCv64RegAlloc.cpp + RISCv64LinearScan.cpp Handler/CalleeSavedHandler.cpp Handler/LegalizeImmediates.cpp Handler/PrologueEpilogueInsertion.cpp diff --git a/src/backend/RISCv64/RISCv64AsmPrinter.cpp b/src/backend/RISCv64/RISCv64AsmPrinter.cpp index 4dd8fd8..5a01b8e 100644 --- a/src/backend/RISCv64/RISCv64AsmPrinter.cpp +++ b/src/backend/RISCv64/RISCv64AsmPrinter.cpp @@ -82,7 +82,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break; case RVOpcodes::SD: *OS << "sd "; break; case RVOpcodes::FLW: *OS << "flw "; break; case RVOpcodes::FSW: *OS << "fsw "; break; case RVOpcodes::FLD: *OS << "fld "; break; - case RVOpcodes::FSD: *OS << "fsd "; break; + case RVOpcodes::FSD: *OS << "fsd "; break; case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break; case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break; case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break; @@ -102,6 +102,7 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::FLE_S: *OS << "fle.s "; break; case RVOpcodes::FCVT_S_W: *OS << "fcvt.s.w "; break; case RVOpcodes::FCVT_W_S: *OS << "fcvt.w.s "; break; + case RVOpcodes::FCVT_W_S_RTZ: *OS << "fcvt.w.s "; break; case RVOpcodes::FMV_S: *OS << "fmv.s "; break; case RVOpcodes::FMV_W_X: *OS << "fmv.w.x "; break; case RVOpcodes::FMV_X_W: *OS << "fmv.x.w "; break; diff --git a/src/backend/RISCv64/RISCv64Backend.cpp b/src/backend/RISCv64/RISCv64Backend.cpp index 7e08102..b78d02e 100644 --- a/src/backend/RISCv64/RISCv64Backend.cpp +++ b/src/backend/RISCv64/RISCv64Backend.cpp @@ -1,10 +1,13 @@ #include "RISCv64Backend.h" #include "RISCv64ISel.h" #include "RISCv64RegAlloc.h" +#include "RISCv64LinearScan.h" // <--- 新增此行 #include "RISCv64AsmPrinter.h" #include "RISCv64Passes.h" #include - +#include // <--- 新增此行 +#include // <--- 新增此行 +#include // <--- 新增此行,用于打印超时警告 namespace sysy { // 顶层入口 @@ -196,9 +199,6 @@ std::string RISCv64CodeGen::function_gen(Function* func) { // === 完整的后端处理流水线 === // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) - DEBUG = 0; - DEEPDEBUG = 0; - RISCv64ISel isel; std::unique_ptr mfunc = isel.runOnFunction(func); @@ -206,9 +206,7 @@ std::string RISCv64CodeGen::function_gen(Function* func) { std::stringstream ss_after_isel; RISCv64AsmPrinter printer_isel(mfunc.get()); printer_isel.run(ss_after_isel, true); - if (DEBUG) { - std::cout << ss_after_isel.str(); - } + if (DEBUG) { std::cerr << "====== Intermediate Representation after Instruction Selection ======\n" << ss_after_isel.str(); @@ -228,13 +226,13 @@ std::string RISCv64CodeGen::function_gen(Function* func) { << ss_after_eli.str(); } - // 阶段 2: 除法强度削弱优化 (Division Strength Reduction) - DivStrengthReduction div_strength_reduction; - div_strength_reduction.runOnMachineFunction(mfunc.get()); + // // 阶段 2: 除法强度削弱优化 (Division Strength Reduction) + // DivStrengthReduction div_strength_reduction; + // div_strength_reduction.runOnMachineFunction(mfunc.get()); - // 阶段 2.1: 指令调度 (Instruction Scheduling) - PreRA_Scheduler scheduler; - scheduler.runOnMachineFunction(mfunc.get()); + // // 阶段 2.1: 指令调度 (Instruction Scheduling) + // PreRA_Scheduler scheduler; + // scheduler.runOnMachineFunction(mfunc.get()); // 阶段 3: 物理寄存器分配 (Register Allocation) RISCv64RegAlloc reg_alloc(mfunc.get()); @@ -254,9 +252,9 @@ std::string RISCv64CodeGen::function_gen(Function* func) { mfunc->dumpStackFrameInfo(std::cerr); } - // 阶段 4: 窥孔优化 (Peephole Optimization) - PeepholeOptimizer peephole; - peephole.runOnMachineFunction(mfunc.get()); + // // 阶段 4: 窥孔优化 (Peephole Optimization) + // PeepholeOptimizer peephole; + // peephole.runOnMachineFunction(mfunc.get()); // 阶段 5: 局部指令调度 (Local Scheduling) PostRA_Scheduler local_scheduler; @@ -276,7 +274,6 @@ std::string RISCv64CodeGen::function_gen(Function* func) { printer.run(ss); return ss.str(); - } } // namespace sysy \ No newline at end of file diff --git a/src/backend/RISCv64/RISCv64ISel.cpp b/src/backend/RISCv64/RISCv64ISel.cpp index dad1bbb..21e8e3c 100644 --- a/src/backend/RISCv64/RISCv64ISel.cpp +++ b/src/backend/RISCv64/RISCv64ISel.cpp @@ -745,83 +745,12 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(instr)); break; } - case Instruction::kFtoI: { // 浮点 to 整数 (带向下取整) - // 目标:实现 floor(x) 的效果, C/C++中浮点转整数是截断(truncate) - // 对于正数,floor(x) == truncate(x) - // RISC-V的 fcvt.w.s 默认是“四舍五入到偶数” - // 我们需要手动实现截断逻辑 - // 逻辑: - // temp_i = fcvt.w.s(x) // 四舍五入 - // temp_f = fcvt.s.w(temp_i) // 转回浮点 - // if (x < temp_f) { // 如果原数更小,说明被“五入”了 - // result = temp_i - 1 - // } else { - // result = temp_i - // } - - auto temp_i_vreg = getNewVReg(Type::getIntType()); - auto temp_f_vreg = getNewVReg(Type::getFloatType()); - auto cmp_vreg = getNewVReg(Type::getIntType()); - - // 1. fcvt.w.s temp_i_vreg, src_vreg - auto fcvt_w = std::make_unique(RVOpcodes::FCVT_W_S); - fcvt_w->addOperand(std::make_unique(temp_i_vreg)); - fcvt_w->addOperand(std::make_unique(src_vreg)); - CurMBB->addInstruction(std::move(fcvt_w)); - - // 2. fcvt.s.w temp_f_vreg, temp_i_vreg - auto fcvt_s = std::make_unique(RVOpcodes::FCVT_S_W); - fcvt_s->addOperand(std::make_unique(temp_f_vreg)); - fcvt_s->addOperand(std::make_unique(temp_i_vreg)); - CurMBB->addInstruction(std::move(fcvt_s)); - - // 3. flt.s cmp_vreg, src_vreg, temp_f_vreg - auto flt = std::make_unique(RVOpcodes::FLT_S); - flt->addOperand(std::make_unique(cmp_vreg)); - flt->addOperand(std::make_unique(src_vreg)); - flt->addOperand(std::make_unique(temp_f_vreg)); - CurMBB->addInstruction(std::move(flt)); - - // 创建标签 - int unique_id = this->local_label_counter++; - std::string rounded_up_label = MFunc->getName() + "_ftoi_rounded_up_" + std::to_string(unique_id); - std::string done_label = MFunc->getName() + "_ftoi_done_" + std::to_string(unique_id); - - // 4. bne cmp_vreg, x0, rounded_up_label - auto bne = std::make_unique(RVOpcodes::BNE); - bne->addOperand(std::make_unique(cmp_vreg)); - bne->addOperand(std::make_unique(PhysicalReg::ZERO)); - bne->addOperand(std::make_unique(rounded_up_label)); - CurMBB->addInstruction(std::move(bne)); - - // 5. else 分支: mv dest_vreg, temp_i_vreg - auto mv = std::make_unique(RVOpcodes::MV); - mv->addOperand(std::make_unique(dest_vreg)); - mv->addOperand(std::make_unique(temp_i_vreg)); - CurMBB->addInstruction(std::move(mv)); - - // 6. j done_label - auto j = std::make_unique(RVOpcodes::J); - j->addOperand(std::make_unique(done_label)); - CurMBB->addInstruction(std::move(j)); - - // 7. rounded_up_label: - auto label_up = std::make_unique(RVOpcodes::LABEL); - label_up->addOperand(std::make_unique(rounded_up_label)); - CurMBB->addInstruction(std::move(label_up)); - - // 8. addiw dest_vreg, temp_i_vreg, -1 - auto addi = std::make_unique(RVOpcodes::ADDIW); - addi->addOperand(std::make_unique(dest_vreg)); - addi->addOperand(std::make_unique(temp_i_vreg)); - addi->addOperand(std::make_unique(-1)); - CurMBB->addInstruction(std::move(addi)); - - // 9. done_label: - auto label_done = std::make_unique(RVOpcodes::LABEL); - label_done->addOperand(std::make_unique(done_label)); - CurMBB->addInstruction(std::move(label_done)); - + case Instruction::kFtoI: { // 浮点 to 整数 (使用硬件指令进行向零截断) + // 直接生成一条带有 rtz 舍入模式的转换指令 + auto instr = std::make_unique(RVOpcodes::FCVT_W_S_RTZ); + instr->addOperand(std::make_unique(dest_vreg)); // 目标是整数vreg + instr->addOperand(std::make_unique(src_vreg)); // 源是浮点vreg + CurMBB->addInstruction(std::move(instr)); break; } case Instruction::kFNeg: { // 浮点取负 @@ -1202,10 +1131,11 @@ void RISCv64ISel::selectNode(DAGNode* node) { auto r_value_byte = getVReg(memset->getValue()); // 为memset内部逻辑创建新的临时虚拟寄存器 - auto r_counter = getNewVReg(); - auto r_end_addr = getNewVReg(); - auto r_current_addr = getNewVReg(); - auto r_temp_val = getNewVReg(); + Type* ptr_type = Type::getPointerType(Type::getIntType()); + auto r_counter = getNewVReg(ptr_type); + auto r_end_addr = getNewVReg(ptr_type); + auto r_current_addr = getNewVReg(ptr_type); + auto r_temp_val = getNewVReg(Type::getIntType()); // 定义一系列lambda表达式来简化指令创建 auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { @@ -1296,7 +1226,7 @@ void RISCv64ISel::selectNode(DAGNode* node) { // --- Step 1: 获取基地址 (此部分逻辑正确,保持不变) --- auto base_ptr_node = node->operands[0]; - auto current_addr_vreg = getNewVReg(); + auto current_addr_vreg = getNewVReg(gep->getType()); if (auto alloca_base = dynamic_cast(base_ptr_node->value)) { auto frame_addr_instr = std::make_unique(RVOpcodes::FRAME_ADDR); @@ -1338,13 +1268,13 @@ void RISCv64ISel::selectNode(DAGNode* node) { // 如果步长为0(例如对一个void类型或空结构体索引),则不产生任何偏移 if (stride != 0) { // --- 为当前索引和步长生成偏移计算指令 --- - auto offset_vreg = getNewVReg(); + auto offset_vreg = getNewVReg(Type::getIntType()); // 处理索引 - 区分常量与动态值 unsigned index_vreg; if (auto const_index = dynamic_cast(indexValue)) { // 对于常量索引,直接创建新的虚拟寄存器 - index_vreg = getNewVReg(); + index_vreg = getNewVReg(Type::getIntType()); auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(index_vreg)); li->addOperand(std::make_unique(const_index->getInt())); @@ -1362,7 +1292,7 @@ void RISCv64ISel::selectNode(DAGNode* node) { CurMBB->addInstruction(std::move(mv)); } else { // 步长不为1,需要生成乘法指令 - auto size_vreg = getNewVReg(); + auto size_vreg = getNewVReg(Type::getIntType()); auto li_size = std::make_unique(RVOpcodes::LI); li_size->addOperand(std::make_unique(size_vreg)); li_size->addOperand(std::make_unique(stride)); diff --git a/src/backend/RISCv64/RISCv64LinearScan.cpp b/src/backend/RISCv64/RISCv64LinearScan.cpp new file mode 100644 index 0000000..43e253b --- /dev/null +++ b/src/backend/RISCv64/RISCv64LinearScan.cpp @@ -0,0 +1,517 @@ +#include "RISCv64LinearScan.h" +#include "RISCv64LLIR.h" +#include "RISCv64ISel.h" +#include +#include + +extern int DEBUG; + +namespace sysy { + +RISCv64LinearScan::RISCv64LinearScan(MachineFunction* mfunc) + : MFunc(mfunc), + ISel(mfunc->getISel()), + vreg_type_map(ISel->getVRegTypeMap()) { + + // 初始化可用的物理寄存器池,与图着色版本保持一致 + // 整数寄存器 + allocable_int_regs = { + PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, /*T5保留作为大立即数加载寄存器*/ PhysicalReg::T6, + PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7, + PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, + PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11, + }; + // 浮点寄存器 + allocable_fp_regs = { + PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3, PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7, + PhysicalReg::F10, PhysicalReg::F11, PhysicalReg::F12, PhysicalReg::F13, PhysicalReg::F14, PhysicalReg::F15, PhysicalReg::F16, PhysicalReg::F17, + PhysicalReg::F8, PhysicalReg::F9, PhysicalReg::F18, PhysicalReg::F19, PhysicalReg::F20, PhysicalReg::F21, PhysicalReg::F22, + PhysicalReg::F23, PhysicalReg::F24, PhysicalReg::F25, PhysicalReg::F26, PhysicalReg::F27, + PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31, + }; + // 新增:识别所有通过寄存器传递的参数,并建立vreg到物理寄存器(preg)的映射 + // 这等同于图着色算法中的“预着色”步骤。 + if (MFunc->getFunc()) { + int int_arg_idx = 0; + int fp_arg_idx = 0; + for (Argument* arg : MFunc->getFunc()->getArguments()) { + unsigned arg_vreg = ISel->getVReg(arg); + if (arg->getType()->isFloat()) { + if (fp_arg_idx < 8) { // fa0-fa7 + auto preg = static_cast(static_cast(PhysicalReg::F10) + fp_arg_idx); + abi_vreg_map[arg_vreg] = preg; + fp_arg_idx++; + } + } else { // 整数或指针 + if (int_arg_idx < 8) { // a0-a7 + auto preg = static_cast(static_cast(PhysicalReg::A0) + int_arg_idx); + abi_vreg_map[arg_vreg] = preg; + int_arg_idx++; + } + } + } + } +} + +void RISCv64LinearScan::run() { + if (DEBUG) std::cerr << "===== Running Linear Scan Register Allocation for function: " << MFunc->getName() << " =====\n"; + + bool changed = true; + while(changed) { + // 1. 准备阶段 + linearizeBlocks(); + computeLiveIntervals(); + + // 2. 执行线性扫描 + changed = linearScan(); + + // 3. 如果有溢出,重写代码,然后下一轮重新开始 + if (changed) { + rewriteProgram(); + if (DEBUG) std::cerr << "--- Spilling detected, re-running linear scan ---\n"; + } + } + + // 4. 将最终分配结果应用到机器指令 + applyAllocation(); + // 5. 收集用到的被调用者保存寄存器 + MFunc->getFrameInfo().vreg_to_preg_map = this->vreg_to_preg_map; + collectUsedCalleeSavedRegs(); + + if (DEBUG) std::cerr << "===== Finished Linear Scan Register Allocation =====\n\n"; +} + +// 步骤 1.1: 对基本块进行线性化,这里我们简单地按现有顺序排列 +void RISCv64LinearScan::linearizeBlocks() { + linear_order_blocks.clear(); + for (auto& mbb : MFunc->getBlocks()) { + linear_order_blocks.push_back(mbb.get()); + } +} + +// RISCv64LinearScan.cpp + +void RISCv64LinearScan::computeLiveIntervals() { + instr_numbering.clear(); + live_intervals.clear(); + unhandled.clear(); + + // a. 对所有指令进行线性编号,并记录CALL指令的位置 + int num = 0; + std::set call_locations; + for (auto* mbb : linear_order_blocks) { + for (auto& instr : mbb->getInstructions()) { + instr_numbering[instr.get()] = num; + if (instr->getOpcode() == RVOpcodes::CALL) { + call_locations.insert(num); + } + num += 2; // 指令编号间隔为2,方便在溢出重写时插入指令 + } + } + + // b. 遍历所有指令,记录每个vreg首次和末次出现的位置 + std::map> vreg_ranges; // vreg -> {first_instr_num, last_instr_num} + + for (auto* mbb : linear_order_blocks) { + for (auto& instr_ptr : mbb->getInstructions()) { + const MachineInstr* instr = instr_ptr.get(); + int instr_num = instr_numbering.at(instr); + std::set use, def; + getInstrUseDef(instr, use, def); + + auto all_vregs = use; + all_vregs.insert(def.begin(), def.end()); + + for (unsigned vreg : all_vregs) { + if (vreg_ranges.find(vreg) == vreg_ranges.end()) { + vreg_ranges[vreg] = {instr_num, instr_num}; + } else { + vreg_ranges[vreg].second = std::max(vreg_ranges[vreg].second, instr_num); + } + } + } + } + + // c. 根据记录的边界,创建LiveInterval对象,并检查是否跨越CALL + for (auto const& [vreg, range] : vreg_ranges) { + live_intervals.emplace(vreg, LiveInterval(vreg)); + auto& interval = live_intervals.at(vreg); + interval.start = range.first; + interval.end = range.second; + + // 检查此区间是否跨越了任何CALL指令 + auto it = call_locations.lower_bound(interval.start); + if (it != call_locations.end() && *it < interval.end) { + interval.crosses_call = true; + } + } + + // d. 将所有计算出的活跃区间放入 unhandled 列表 + for (auto& pair : live_intervals) { + unhandled.push_back(&pair.second); + } + std::sort(unhandled.begin(), unhandled.end(), [](const LiveInterval* a, const LiveInterval* b){ + return a->start < b->start; + }); +} + +// RISCv64LinearScan.cpp + +// 在类的定义中添加一个辅助函数来判断寄存器类型 +bool isCalleeSaved(PhysicalReg preg) { + if (preg >= PhysicalReg::S1 && preg <= PhysicalReg::S11) return true; + if (preg == PhysicalReg::S0) return true; // s0 通常也作为被调用者保存 + // 浮点寄存器 + if (preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) return true; + if (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27) return true; + return false; +} + +// 线性扫描主算法 +bool RISCv64LinearScan::linearScan() { + active.clear(); + spilled_vregs.clear(); + vreg_to_preg_map.clear(); + + // 将寄存器池分为调用者保存和被调用者保存两类 + std::set free_caller_int_regs, free_callee_int_regs; + std::set free_caller_fp_regs, free_callee_fp_regs; + + for (auto preg : allocable_int_regs) { + if (isCalleeSaved(preg)) free_callee_int_regs.insert(preg); + else free_caller_int_regs.insert(preg); + } + for (auto preg : allocable_fp_regs) { + if (isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); + else free_caller_fp_regs.insert(preg); + } + + // 预处理ABI参数寄存器 + vreg_to_preg_map.insert(abi_vreg_map.begin(), abi_vreg_map.end()); + std::vector normal_unhandled; + for(LiveInterval* interval : unhandled) { + if(abi_vreg_map.count(interval->vreg)) { + active.push_back(interval); + PhysicalReg preg = abi_vreg_map.at(interval->vreg); + if (isFPVReg(interval->vreg)) { + if(isCalleeSaved(preg)) free_callee_fp_regs.erase(preg); else free_caller_fp_regs.erase(preg); + } else { + if(isCalleeSaved(preg)) free_callee_int_regs.erase(preg); else free_caller_int_regs.erase(preg); + } + } else { + normal_unhandled.push_back(interval); + } + } + unhandled = normal_unhandled; + std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; }); + + // 主循环 + for (LiveInterval* current : unhandled) { + // a. 释放active列表中已结束的区间 + std::vector new_active; + for (LiveInterval* active_interval : active) { + if (active_interval->end < current->start) { + PhysicalReg preg = vreg_to_preg_map.at(active_interval->vreg); + if (isFPVReg(active_interval->vreg)) { + if(isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); else free_caller_fp_regs.insert(preg); + } else { + if(isCalleeSaved(preg)) free_callee_int_regs.insert(preg); else free_caller_int_regs.insert(preg); + } + } else { + new_active.push_back(active_interval); + } + } + active = new_active; + + // b. 约束化地为当前区间分配寄存器 + bool is_fp = isFPVReg(current->vreg); + auto& free_caller = is_fp ? free_caller_fp_regs : free_caller_int_regs; + auto& free_callee = is_fp ? free_callee_fp_regs : free_callee_int_regs; + + PhysicalReg allocated_preg = PhysicalReg::INVALID; + + if (current->crosses_call) { + // 跨调用区间:必须使用被调用者保存寄存器 + if (!free_callee.empty()) { + allocated_preg = *free_callee.begin(); + free_callee.erase(allocated_preg); + } + } else { + // 非跨调用区间:优先使用调用者保存寄存器 + if (!free_caller.empty()) { + allocated_preg = *free_caller.begin(); + free_caller.erase(allocated_preg); + } else if (!free_callee.empty()) { + allocated_preg = *free_callee.begin(); + free_callee.erase(allocated_preg); + } + } + + if (allocated_preg != PhysicalReg::INVALID) { + vreg_to_preg_map[current->vreg] = allocated_preg; + active.push_back(current); + std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; }); + } else { + // c. 没有可用寄存器,需要溢出 + spillAtInterval(current); + } + } + return !spilled_vregs.empty(); +} + +void RISCv64LinearScan::chooseRegForInterval(LiveInterval* current) { + bool is_fp = isFPVReg(current->vreg); + auto& free_regs = is_fp ? free_fp_regs : free_int_regs; + + if (!free_regs.empty()) { + // 有可用寄存器 + PhysicalReg preg = *free_regs.begin(); + free_regs.erase(free_regs.begin()); + vreg_to_preg_map[current->vreg] = preg; + active.push_back(current); + // 保持 active 列表按结束点排序 + std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ + return a->end < b->end; + }); + } else { + // 没有可用寄存器,需要溢出 + spillAtInterval(current); + } +} + +void RISCv64LinearScan::spillAtInterval(LiveInterval* current) { + LiveInterval* spill_candidate = nullptr; + // 启发式溢出: + // 如果current需要callee-saved,则从active中找一个占用callee-saved且结束最晚的区间比较 + // 否则,找active中结束最晚的区间 + // 这里简化处理:总是找active中结束最晚的区间 + auto last_active = active.back(); + + if (last_active->end > current->end) { + // 溢出active中的区间 + spill_candidate = last_active; + PhysicalReg preg = vreg_to_preg_map.at(spill_candidate->vreg); + vreg_to_preg_map[current->vreg] = preg; // 把换出的寄存器给current + // 更新active列表 + active.pop_back(); + active.push_back(current); + std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; }); + spilled_vregs.insert(spill_candidate->vreg); + } else { + // 溢出当前区间 + spilled_vregs.insert(current->vreg); + } +} + +// 步骤 3: 重写程序,插入溢出代码 +void RISCv64LinearScan::rewriteProgram() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + int spill_offset = frame_info.locals_size; // 溢出区域接在局部变量之后 + + for (unsigned vreg : spilled_vregs) { + if (frame_info.spill_offsets.count(vreg)) continue; // 避免重复分配 + + int size = isFPVReg(vreg) ? 4 : (vreg_type_map.at(vreg)->isPointer() ? 8 : 4); + spill_offset += size; + spill_offset = (spill_offset + 7) & ~7; // 8字节对齐 + frame_info.spill_offsets[vreg] = -(16 + spill_offset); + } + frame_info.spill_size = spill_offset - frame_info.locals_size; + + for (auto& mbb : MFunc->getBlocks()) { + auto& instrs = mbb->getInstructions(); + std::vector> new_instrs; + + for (auto it = instrs.begin(); it != instrs.end(); ++it) { + auto& instr = *it; + std::set use_vregs, def_vregs; + getInstrUseDef(instr.get(), use_vregs, def_vregs); + + // 建立溢出vreg到新临时vreg的映射 + std::map use_remap; + std::map def_remap; + + // 1. 为所有溢出的USE创建LOAD指令和映射 + for (unsigned old_vreg : use_vregs) { + if (spilled_vregs.count(old_vreg) && use_remap.find(old_vreg) == use_remap.end()) { + Type* type = vreg_type_map.at(old_vreg); + unsigned new_temp_vreg = ISel->getNewVReg(type); + use_remap[old_vreg] = new_temp_vreg; + + RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW); + auto load = std::make_unique(load_op); + load->addOperand(std::make_unique(new_temp_vreg)); + load->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), + std::make_unique(frame_info.spill_offsets.at(old_vreg)) + )); + new_instrs.push_back(std::move(load)); + } + } + + // 2. 为所有溢出的DEF创建映射 + for (unsigned old_vreg : def_vregs) { + if (spilled_vregs.count(old_vreg) && def_remap.find(old_vreg) == def_remap.end()) { + Type* type = vreg_type_map.at(old_vreg); + unsigned new_temp_vreg = ISel->getNewVReg(type); + def_remap[old_vreg] = new_temp_vreg; + } + } + + // 3. 基于角色精确地替换原指令中的操作数 + auto opcode = instr->getOpcode(); + auto& operands = instr->getOperands(); + + auto replace_reg_op = [](RegOperand* reg_op, const std::map& remap) { + if (reg_op->isVirtual() && remap.count(reg_op->getVRegNum())) { + reg_op->setVRegNum(remap.at(reg_op->getVRegNum())); + } + }; + + if (op_info.count(opcode)) { + const auto& info = op_info.at(opcode); + // 替换 Defs + for (int idx : info.first) { + if (idx < operands.size() && operands[idx]->getKind() == MachineOperand::KIND_REG) { + replace_reg_op(static_cast(operands[idx].get()), def_remap); + } + } + // 替换 Uses + for (int idx : info.second) { + if (idx < operands.size()) { + if (operands[idx]->getKind() == MachineOperand::KIND_REG) { + replace_reg_op(static_cast(operands[idx].get()), use_remap); + } else if (operands[idx]->getKind() == MachineOperand::KIND_MEM) { + replace_reg_op(static_cast(operands[idx].get())->getBase(), use_remap); + } + } + } + } else if (opcode == RVOpcodes::CALL) { + // 特殊处理 CALL 指令 + if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) { + replace_reg_op(static_cast(operands[0].get()), def_remap); + } + for (size_t i = 1; i < operands.size(); ++i) { + if (operands[i]->getKind() == MachineOperand::KIND_REG) { + replace_reg_op(static_cast(operands[i].get()), use_remap); + } + } + } + + // 4. 将修改后的指令放入新列表 + new_instrs.push_back(std::move(instr)); + + // 5. 为所有溢出的DEF创建STORE指令 + for(const auto& pair : def_remap) { + unsigned old_vreg = pair.first; + unsigned new_temp_vreg = pair.second; + Type* type = vreg_type_map.at(old_vreg); + RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW); + auto store = std::make_unique(store_op); + store->addOperand(std::make_unique(new_temp_vreg)); + store->addOperand(std::make_unique( + std::make_unique(PhysicalReg::S0), + std::make_unique(frame_info.spill_offsets.at(old_vreg)) + )); + new_instrs.push_back(std::move(store)); + } + } + instrs = std::move(new_instrs); + } +} + + +// 步骤 4: 应用最终分配结果 +void RISCv64LinearScan::applyAllocation() { + for (auto& mbb : MFunc->getBlocks()) { + for (auto& instr_ptr : mbb->getInstructions()) { + for (auto& op_ptr : instr_ptr->getOperands()) { + if (op_ptr->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op_ptr.get()); + if (reg_op->isVirtual()) { + unsigned vreg = reg_op->getVRegNum(); + if (vreg_to_preg_map.count(vreg)) { + reg_op->setPReg(vreg_to_preg_map.at(vreg)); + } else { + // 如果一个vreg最终没有颜色,这通常意味着它是一个短生命周期的临时变量 + // 在溢出重写中产生,但在下一轮分配前就被优化掉了。 + // 或者是一个从未被使用的定义。 + // 给他一个临时寄存器以防万一。 + reg_op->setPReg(PhysicalReg::T5); + } + } + } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op_ptr.get()); + auto reg_op = mem_op->getBase(); + if (reg_op->isVirtual()) { + unsigned vreg = reg_op->getVRegNum(); + if (vreg_to_preg_map.count(vreg)) { + reg_op->setPReg(vreg_to_preg_map.at(vreg)); + } else { + reg_op->setPReg(PhysicalReg::T5); + } + } + } + } + } + } +} + +void RISCv64LinearScan::getInstrUseDef(const MachineInstr* instr, std::set& use, std::set& def) { + // 这个函数与图着色版本中的 getInstrUseDef 逻辑完全相同,此处直接复用 + auto opcode = instr->getOpcode(); + const auto& operands = instr->getOperands(); + + // op_info 的定义已被移到函数外部的命名空间中 + + auto get_vreg_id_if_virtual = [&](const MachineOperand* op, std::set& s) { + if (op->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(op); + if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum()); + } else if (op->getKind() == MachineOperand::KIND_MEM) { + auto mem_op = static_cast(op); + auto reg_op = mem_op->getBase(); + if (reg_op->isVirtual()) s.insert(reg_op->getVRegNum()); + } + }; + + if (op_info.count(opcode)) { + const auto& info = op_info.at(opcode); + for (int idx : info.first) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), def); + for (int idx : info.second) if (idx < operands.size()) get_vreg_id_if_virtual(operands[idx].get(), use); + // MemOperand 的基址寄存器总是一个 use + for (const auto& op : operands) if (op->getKind() == MachineOperand::KIND_MEM) get_vreg_id_if_virtual(op.get(), use); + } else if (opcode == RVOpcodes::CALL) { + // CALL指令的特殊处理 + // 第一个操作数(如果有)是def(返回值) + if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[0].get(), def); + // 后续的寄存器操作数是use(参数) + for (size_t i = 1; i < operands.size(); ++i) if (operands[i]->getKind() == MachineOperand::KIND_REG) get_vreg_id_if_virtual(operands[i].get(), use); + } +} + +// 辅助函数: 判断是否为浮点vreg +bool RISCv64LinearScan::isFPVReg(unsigned vreg) const { + return vreg_type_map.count(vreg) && vreg_type_map.at(vreg)->isFloat(); +} + +// 辅助函数: 收集被使用的被调用者保存寄存器 +void RISCv64LinearScan::collectUsedCalleeSavedRegs() { + StackFrameInfo& frame_info = MFunc->getFrameInfo(); + frame_info.used_callee_saved_regs.clear(); + + const auto& callee_saved_int = getCalleeSavedIntRegs(); + const auto& callee_saved_fp = getCalleeSavedFpRegs(); + std::set callee_saved_set(callee_saved_int.begin(), callee_saved_int.end()); + callee_saved_set.insert(callee_saved_fp.begin(), callee_saved_fp.end()); + callee_saved_set.insert(PhysicalReg::S0); // s0总是被用作帧指针 + + for(const auto& pair : vreg_to_preg_map) { + PhysicalReg preg = pair.second; + if(callee_saved_set.count(preg)) { + frame_info.used_callee_saved_regs.insert(preg); + } + } +} + +} // namespace sysy \ No newline at end of file diff --git a/src/backend/RISCv64/RISCv64RegAlloc.cpp b/src/backend/RISCv64/RISCv64RegAlloc.cpp index 6b2c341..633ecfb 100644 --- a/src/backend/RISCv64/RISCv64RegAlloc.cpp +++ b/src/backend/RISCv64/RISCv64RegAlloc.cpp @@ -55,41 +55,12 @@ void RISCv64RegAlloc::run() { if (DEBUG) std::cerr << "===== Running Graph Coloring Register Allocation for function: " << MFunc->getName() << " =====\n"; - const int MAX_ITERATIONS = 50; - int iteration = 0; - - while (iteration++ < MAX_ITERATIONS) { + while (true) { if (doAllocation()) { break; } else { rewriteProgram(); - if (DEBUG) std::cerr << "--- Spilling detected, re-running allocation (iteration " << iteration << ") ---\n"; - - if (iteration >= MAX_ITERATIONS) { - std::cerr << "ERROR: Register allocation failed to converge after " << MAX_ITERATIONS << " iterations\n"; - std::cerr << " Spill worklist size: " << spillWorklist.size() << "\n"; - std::cerr << " Total nodes: " << (initial.size() + coloredNodes.size()) << "\n"; - - // Emergency spill remaining nodes to break the loop - std::cerr << " Emergency spilling remaining spill worklist nodes...\n"; - for (unsigned node : spillWorklist) { - spilledNodes.insert(node); - } - - // Also spill any nodes that didn't get colors - std::set uncolored; - for (unsigned node : initial) { - if (color_map.find(node) == color_map.end()) { - uncolored.insert(node); - } - } - for (unsigned node : uncolored) { - spilledNodes.insert(node); - } - - // Force completion - break; - } + if (DEBUG) std::cerr << "--- Spilling detected, re-running allocation ---\n"; } } diff --git a/src/include/backend/RISCv64/RISCv64ISel.h b/src/include/backend/RISCv64/RISCv64ISel.h index e9bb27c..35fb7a7 100644 --- a/src/include/backend/RISCv64/RISCv64ISel.h +++ b/src/include/backend/RISCv64/RISCv64ISel.h @@ -22,7 +22,6 @@ public: // 公开接口,以便后续模块(如RegAlloc)可以查询或创建vreg unsigned getVReg(Value* val); - unsigned getNewVReg() { return vreg_counter++; } unsigned getNewVReg(Type* type); unsigned getVRegCounter() const; // 获取 vreg_map 的公共接口 diff --git a/src/include/backend/RISCv64/RISCv64LLIR.h b/src/include/backend/RISCv64/RISCv64LLIR.h index b2111ff..977b348 100644 --- a/src/include/backend/RISCv64/RISCv64LLIR.h +++ b/src/include/backend/RISCv64/RISCv64LLIR.h @@ -41,6 +41,8 @@ enum class PhysicalReg { // 假设 vreg_counter 不会达到这么大的值 PHYS_REG_START_ID = 1000000, PHYS_REG_END_ID = PHYS_REG_START_ID + 320, // 预留足够的空间 + + INVALID, ///< 无效寄存器标记 }; // RISC-V 指令操作码枚举 @@ -86,6 +88,7 @@ enum class RVOpcodes { // 浮点转换 FCVT_S_W, // fcvt.s.w rd, rs1 (有符号整数 -> 单精度浮点) FCVT_W_S, // fcvt.w.s rd, rs1 (单精度浮点 -> 有符号整数) + FCVT_W_S_RTZ, // fcvt.w.s rd, rs1, rtz (使用向零截断模式) // 浮点传送/移动 FMV_S, // fmv.s rd, rs1 (浮点寄存器之间) diff --git a/src/include/backend/RISCv64/RISCv64LinearScan.h b/src/include/backend/RISCv64/RISCv64LinearScan.h new file mode 100644 index 0000000..96cf5f6 --- /dev/null +++ b/src/include/backend/RISCv64/RISCv64LinearScan.h @@ -0,0 +1,104 @@ +#ifndef RISCV64_LINEARSCAN_H +#define RISCV64_LINEARSCAN_H + +#include "RISCv64LLIR.h" +#include "RISCv64ISel.h" +#include +#include +#include +#include + +namespace sysy { + +// 前向声明 +class MachineBasicBlock; +class MachineFunction; +class RISCv64ISel; + +/** + * @brief 表示一个虚拟寄存器的活跃区间。 + * 包含起始和结束指令编号。为了简化,我们不处理有“洞”的区间。 + */ +struct LiveInterval { + unsigned vreg = 0; + int start = -1; + int end = -1; + bool crosses_call = false; + + LiveInterval(unsigned vreg) : vreg(vreg) {} + + // 用于排序,按起始点从小到大 + bool operator<(const LiveInterval& other) const { + return start < other.start; + } +}; + +class RISCv64LinearScan { +public: + RISCv64LinearScan(MachineFunction* mfunc); + void run(); + +private: + // --- 核心算法流程 --- + void linearizeBlocks(); + void computeLiveIntervals(); + bool linearScan(); + void rewriteProgram(); + void applyAllocation(); + void chooseRegForInterval(LiveInterval* current); + void spillAtInterval(LiveInterval* current); + + // --- 辅助函数 --- + void getInstrUseDef(const MachineInstr* instr, std::set& use, std::set& def); + bool isFPVReg(unsigned vreg) const; + void collectUsedCalleeSavedRegs(); + + MachineFunction* MFunc; + RISCv64ISel* ISel; + + // --- 线性扫描数据结构 --- + std::vector linear_order_blocks; + std::map instr_numbering; + std::map live_intervals; + + std::vector unhandled; + std::vector active; // 活跃且已分配物理寄存器的区间 + + std::set spilled_vregs; // 记录在本轮被决定溢出的vreg + + // --- 寄存器池和分配结果 --- + std::vector allocable_int_regs; + std::vector allocable_fp_regs; + std::set free_int_regs; + std::set free_fp_regs; + std::map vreg_to_preg_map; + std::map abi_vreg_map; + + const std::map& vreg_type_map; +}; + +static const std::map, std::vector>> op_info = { + {RVOpcodes::ADD, {{0}, {1, 2}}}, {RVOpcodes::SUB, {{0}, {1, 2}}}, {RVOpcodes::MUL, {{0}, {1, 2}}}, + {RVOpcodes::DIV, {{0}, {1, 2}}}, {RVOpcodes::REM, {{0}, {1, 2}}}, {RVOpcodes::ADDW, {{0}, {1, 2}}}, + {RVOpcodes::SUBW, {{0}, {1, 2}}}, {RVOpcodes::MULW, {{0}, {1, 2}}}, {RVOpcodes::DIVW, {{0}, {1, 2}}}, + {RVOpcodes::REMW, {{0}, {1, 2}}}, {RVOpcodes::SLT, {{0}, {1, 2}}}, {RVOpcodes::SLTU, {{0}, {1, 2}}}, + {RVOpcodes::ADDI, {{0}, {1}}}, {RVOpcodes::ADDIW, {{0}, {1}}}, {RVOpcodes::XORI, {{0}, {1}}}, + {RVOpcodes::SLTI, {{0}, {1}}}, {RVOpcodes::SLTIU, {{0}, {1}}}, {RVOpcodes::LB, {{0}, {}}}, + {RVOpcodes::LH, {{0}, {}}}, {RVOpcodes::LW, {{0}, {}}}, {RVOpcodes::LD, {{0}, {}}}, + {RVOpcodes::LBU, {{0}, {}}}, {RVOpcodes::LHU, {{0}, {}}}, {RVOpcodes::LWU, {{0}, {}}}, + {RVOpcodes::FLW, {{0}, {}}}, {RVOpcodes::FLD, {{0}, {}}}, {RVOpcodes::SB, {{}, {0, 1}}}, + {RVOpcodes::SH, {{}, {0, 1}}}, {RVOpcodes::SW, {{}, {0, 1}}}, {RVOpcodes::SD, {{}, {0, 1}}}, + {RVOpcodes::FSW, {{}, {0, 1}}}, {RVOpcodes::FSD, {{}, {0, 1}}}, {RVOpcodes::BEQ, {{}, {0, 1}}}, + {RVOpcodes::BNE, {{}, {0, 1}}}, {RVOpcodes::BLT, {{}, {0, 1}}}, {RVOpcodes::BGE, {{}, {0, 1}}}, + {RVOpcodes::JALR, {{0}, {1}}}, {RVOpcodes::LI, {{0}, {}}}, {RVOpcodes::LA, {{0}, {}}}, + {RVOpcodes::MV, {{0}, {1}}}, {RVOpcodes::SEQZ, {{0}, {1}}}, {RVOpcodes::SNEZ, {{0}, {1}}}, + {RVOpcodes::RET, {{}, {}}}, {RVOpcodes::FADD_S, {{0}, {1, 2}}}, {RVOpcodes::FSUB_S, {{0}, {1, 2}}}, + {RVOpcodes::FMUL_S, {{0}, {1, 2}}}, {RVOpcodes::FDIV_S, {{0}, {1, 2}}}, {RVOpcodes::FEQ_S, {{0}, {1, 2}}}, + {RVOpcodes::FLT_S, {{0}, {1, 2}}}, {RVOpcodes::FLE_S, {{0}, {1, 2}}}, {RVOpcodes::FCVT_S_W, {{0}, {1}}}, + {RVOpcodes::FCVT_W_S, {{0}, {1}}}, {RVOpcodes::FMV_S, {{0}, {1}}}, {RVOpcodes::FMV_W_X, {{0}, {1}}}, + {RVOpcodes::FMV_X_W, {{0}, {1}}}, {RVOpcodes::FNEG_S, {{0}, {1}}} +}; + +} // namespace sysy + +#endif // RISCV64_LINEARSCAN_H \ No newline at end of file