#include "RISCv64RegAlloc.h"
#include "RISCv64ISel.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <vector>

// NOTE(review): the explicit template arguments in this file (AllocaInst,
// ConstantValue, MachineInstr, RegOperand, ImmOperand, MemOperand) are
// inferred from how the objects are used below -- confirm against the
// project's IR / machine-operand headers.

namespace sysy {

namespace {

// Frame pointer. Every alloca slot and every spill slot is addressed as a
// negative offset from it, so it must never be handed out to a virtual reg.
constexpr PhysicalReg kFramePointer = PhysicalReg::S0;

// Scratch registers reserved for reloading / storing spilled vregs around a
// single instruction. Two suffice for any instruction with at most two
// distinct register sources (ALU ops, loads/stores incl. base, branches);
// a spilled destination may reuse either one because sources are read before
// the destination is written.
constexpr std::array<PhysicalReg, 2> kSpillScratch{PhysicalReg::T5, PhysicalReg::T6};

// Spill slots hold a full 64-bit register so spilled addresses survive the
// round trip through memory (LD/SD rather than LW/SW).
constexpr int kSpillSlotSize = 8;

// Builds `opcode value, offset(fp)` -- a load or store against the frame
// pointer. `value` is the destination for loads and the source for stores.
// NOTE(review): `offset` is emitted as a 12-bit signed immediate; frames larger
// than 2048 bytes would need an LI+ADD sequence instead.
std::unique_ptr<MachineInstr> makeFrameAccess(RVOpcodes opcode,
                                              std::unique_ptr<RegOperand> value,
                                              int offset) {
    auto instr = std::make_unique<MachineInstr>(opcode);
    instr->addOperand(std::move(value));
    instr->addOperand(std::make_unique<MemOperand>(
        std::make_unique<RegOperand>(kFramePointer),
        std::make_unique<ImmOperand>(offset)));
    return instr;
}

}  // namespace

/// Initializes the pool of integer registers the colorer may assign.
/// S0 (frame pointer) and T5/T6 (spill scratch) are intentionally excluded.
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
    allocable_int_regs = {
        PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
        PhysicalReg::T4,
        PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
        PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
        PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4,
        PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7, PhysicalReg::S8,
        PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
    };
}

/// Runs the whole allocation pipeline on MFunc, in order.
void RISCv64RegAlloc::run() {
    eliminateFrameIndices();
    analyzeLiveness();
    buildInterferenceGraph();
    colorGraph();
    rewriteFunction();
}

/// Assigns a frame-pointer-relative slot to every IR alloca, records it in
/// the frame info, and lowers FRAME_LOAD / FRAME_STORE pseudos into LW / SW
/// addressed directly off the frame pointer.
void RISCv64RegAlloc::eliminateFrameIndices() {
    StackFrameInfo& frame_info = MFunc->getFrameInfo();
    Function* F = MFunc->getFunc();
    RISCv64ISel* isel = MFunc->getISel();

    // Pass 1: lay out allocas downward from the frame pointer.
    int current_offset = 0;
    for (auto& bb : F->getBasicBlocks()) {
        for (auto& inst : bb->getInstructions()) {
            auto* alloca = dynamic_cast<AllocaInst*>(inst.get());
            if (alloca == nullptr) {
                continue;
            }
            // Each element is 4 bytes; arrays multiply by their constant dims
            // (a scalar has no dims, so the product is 1).
            // NOTE(review): non-constant dimensions are silently skipped, which
            // under-allocates the slot -- confirm SysY guarantees constant dims.
            int num_elements = 1;
            for (const auto& dim_use : alloca->getDims()) {
                if (auto* const_dim = dynamic_cast<ConstantValue*>(dim_use->getValue())) {
                    num_elements *= const_dim->getInt();
                }
            }
            current_offset += 4 * num_elements;
            frame_info.alloca_offsets[isel->getVReg(alloca)] = -current_offset;
        }
    }
    frame_info.locals_size = current_offset;

    // Pass 2: replace frame pseudos with real memory accesses. Both pseudos
    // carry (value vreg, alloca vreg) as operands 0 and 1.
    for (auto& mbb : MFunc->getBlocks()) {
        std::vector<std::unique_ptr<MachineInstr>> new_instructions;
        new_instructions.reserve(mbb->getInstructions().size());
        for (auto& instr_ptr : mbb->getInstructions()) {
            const auto opcode = instr_ptr->getOpcode();
            if (opcode != RVOpcodes::FRAME_LOAD && opcode != RVOpcodes::FRAME_STORE) {
                new_instructions.push_back(std::move(instr_ptr));
                continue;
            }
            auto& operands = instr_ptr->getOperands();
            const unsigned value_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
            const unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
            const int offset = frame_info.alloca_offsets.at(alloca_vreg);
            const RVOpcodes lowered =
                (opcode == RVOpcodes::FRAME_LOAD) ? RVOpcodes::LW : RVOpcodes::SW;
            // Fold the slot offset into the addressing mode instead of
            // materialising the address with a separate ADDI.
            new_instructions.push_back(
                makeFrameAccess(lowered, std::make_unique<RegOperand>(value_vreg), offset));
        }
        mbb->getInstructions() = std::move(new_instructions);
    }
}

/// Collects the virtual registers `instr` reads (use) and writes (def).
/// For every opcode except stores, branches, J and RET, the first register
/// operand is the destination; all other register operands and every memory
/// operand base are sources. Physical registers are ignored.
void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) {
    const auto opcode = instr->getOpcode();
    const bool has_def_slot =
        !(opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
          opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
          opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
          opcode == RVOpcodes::RET || opcode == RVOpcodes::J);

    // TODO(review): CALL clobbers all caller-saved registers and reads the
    // arguments in a0-a7; neither effect is modelled, so vregs live across a
    // call may be assigned a register the callee overwrites.

    bool seen_first_reg = false;
    for (const auto& op : instr->getOperands()) {
        if (op->getKind() == MachineOperand::KIND_REG) {
            auto* reg_op = static_cast<RegOperand*>(op.get());
            // The def slot is consumed by the first register operand even when
            // it is physical (e.g. `mv a0, vN`): vN is then a use, not a def.
            const bool in_def_slot = has_def_slot && !seen_first_reg;
            seen_first_reg = true;
            if (!reg_op->isVirtual()) {
                continue;
            }
            if (in_def_slot) {
                def.insert(reg_op->getVRegNum());
            } else {
                use.insert(reg_op->getVRegNum());
            }
        } else if (op->getKind() == MachineOperand::KIND_MEM) {
            auto* mem_op = static_cast<MemOperand*>(op.get());
            if (mem_op->getBase()->isVirtual()) {
                use.insert(mem_op->getBase()->getVRegNum());
            }
        }
    }
}

/// Iterative backward liveness at instruction granularity. Fills
/// live_in_map / live_out_map until no live-in set changes.
void RISCv64RegAlloc::analyzeLiveness() {
    bool changed = true;
    while (changed) {
        changed = false;
        auto& blocks = MFunc->getBlocks();
        for (auto it = blocks.rbegin(); it != blocks.rend(); ++it) {
            auto& mbb = *it;

            // Block live-out = union of the live-in of each successor's first instruction.
            LiveSet live_out;
            for (auto succ : mbb->successors) {
                if (succ->getInstructions().empty()) {
                    continue;
                }
                auto* first_instr = succ->getInstructions().front().get();
                auto found = live_in_map.find(first_instr);
                if (found != live_in_map.end()) {
                    live_out.insert(found->second.begin(), found->second.end());
                }
            }

            auto& instrs = mbb->getInstructions();
            for (auto instr_it = instrs.rbegin(); instr_it != instrs.rend(); ++instr_it) {
                MachineInstr* instr = instr_it->get();
                live_out_map[instr] = live_out;

                LiveSet use, def;
                getInstrUseDef(instr, use, def);

                // live_in = use U (live_out - def)
                LiveSet live_in = live_out;
                for (auto vreg : def) {
                    live_in.erase(vreg);
                }
                live_in.insert(use.begin(), use.end());

                LiveSet& stored = live_in_map[instr];
                if (stored != live_in) {
                    stored = live_in;
                    changed = true;
                }
                live_out = std::move(live_in);
            }
        }
    }
}

/// Creates a node for every vreg mentioned by any instruction, then adds an
/// edge between each def and every other vreg live out of that instruction.
void RISCv64RegAlloc::buildInterferenceGraph() {
    std::set<unsigned> all_vregs;
    for (auto& mbb : MFunc->getBlocks()) {
        for (auto& instr : mbb->getInstructions()) {
            LiveSet use, def;
            getInstrUseDef(instr.get(), use, def);
            all_vregs.insert(use.begin(), use.end());
            all_vregs.insert(def.begin(), def.end());
        }
    }
    for (auto vreg : all_vregs) {
        interference_graph[vreg] = {};
    }

    for (auto& mbb : MFunc->getBlocks()) {
        for (auto& instr : mbb->getInstructions()) {
            LiveSet use, def;
            getInstrUseDef(instr.get(), use, def);
            const LiveSet& live_out = live_out_map.at(instr.get());
            for (unsigned d : def) {
                for (unsigned l : live_out) {
                    if (d != l) {
                        interference_graph[d].insert(l);
                        interference_graph[l].insert(d);
                    }
                }
            }
        }
    }
}

/// Greedy coloring, highest degree first (ties broken by vreg number so the
/// result is deterministic). Vregs with no free register are spilled.
void RISCv64RegAlloc::colorGraph() {
    std::vector<unsigned> order;
    order.reserve(interference_graph.size());
    for (const auto& entry : interference_graph) {
        order.push_back(entry.first);
    }
    std::sort(order.begin(), order.end(), [this](unsigned a, unsigned b) {
        const auto deg_a = interference_graph.at(a).size();
        const auto deg_b = interference_graph.at(b).size();
        if (deg_a != deg_b) {
            return deg_a > deg_b;
        }
        return a < b;
    });

    for (unsigned vreg : order) {
        std::set<PhysicalReg> used_colors;
        for (unsigned neighbor : interference_graph.at(vreg)) {
            auto colored = color_map.find(neighbor);
            if (colored != color_map.end()) {
                used_colors.insert(colored->second);
            }
        }

        bool assigned = false;
        for (PhysicalReg preg : allocable_int_regs) {
            if (used_colors.count(preg) == 0) {
                color_map[vreg] = preg;
                assigned = true;
                break;
            }
        }
        if (!assigned) {
            spilled_vregs.insert(vreg);
        }
    }
}

/// Lays out 8-byte spill slots below the locals, inserts LD/SD around each
/// instruction that touches a spilled vreg (using per-instruction scratch
/// registers), and replaces every virtual register operand with its
/// physical register.
/// @throws std::runtime_error if an instruction reads more distinct spilled
///         vregs than there are reserved scratch registers.
void RISCv64RegAlloc::rewriteFunction() {
    StackFrameInfo& frame_info = MFunc->getFrameInfo();

    int current_offset = frame_info.locals_size;
    if (!spilled_vregs.empty()) {
        // Align the spill area so every 8-byte slot is naturally aligned.
        current_offset = (current_offset + kSpillSlotSize - 1) / kSpillSlotSize * kSpillSlotSize;
        for (unsigned vreg : spilled_vregs) {
            current_offset += kSpillSlotSize;
            frame_info.spill_offsets[vreg] = -current_offset;
        }
    }
    frame_info.spill_size = current_offset - frame_info.locals_size;

    for (auto& mbb : MFunc->getBlocks()) {
        std::vector<std::unique_ptr<MachineInstr>> new_instructions;
        for (auto& instr_ptr : mbb->getInstructions()) {
            LiveSet use, def;
            getInstrUseDef(instr_ptr.get(), use, def);

            // Give each distinct spilled source its own scratch register and
            // reload it just before the instruction.
            std::map<unsigned, PhysicalReg> scratch_of;
            std::size_t next_scratch = 0;
            for (unsigned vreg : use) {
                if (spilled_vregs.count(vreg) == 0) {
                    continue;
                }
                if (next_scratch == kSpillScratch.size()) {
                    throw std::runtime_error(
                        "RISCv64RegAlloc: instruction reads more spilled registers "
                        "than reserved scratch registers");
                }
                const PhysicalReg scratch = kSpillScratch[next_scratch++];
                scratch_of[vreg] = scratch;
                new_instructions.push_back(makeFrameAccess(
                    RVOpcodes::LD, std::make_unique<RegOperand>(scratch),
                    frame_info.spill_offsets.at(vreg)));
            }
            // A spilled destination that is not also a source may reuse any
            // scratch register: sources are read before the result is written.
            for (unsigned vreg : def) {
                if (spilled_vregs.count(vreg) != 0 && scratch_of.count(vreg) == 0) {
                    scratch_of[vreg] = kSpillScratch.front();
                }
            }

            auto assign_preg = [&](RegOperand* reg) {
                if (!reg->isVirtual()) {
                    return;
                }
                const unsigned vreg = reg->getVRegNum();
                if (auto s = scratch_of.find(vreg); s != scratch_of.end()) {
                    reg->setPReg(s->second);
                } else if (auto c = color_map.find(vreg); c != color_map.end()) {
                    reg->setPReg(c->second);
                }
            };
            for (auto& op_ptr : instr_ptr->getOperands()) {
                if (op_ptr->getKind() == MachineOperand::KIND_REG) {
                    assign_preg(static_cast<RegOperand*>(op_ptr.get()));
                } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
                    assign_preg(static_cast<MemOperand*>(op_ptr.get())->getBase());
                }
            }
            new_instructions.push_back(std::move(instr_ptr));

            // Write a spilled destination back to its slot right after the instruction.
            for (unsigned vreg : def) {
                auto s = scratch_of.find(vreg);
                if (s != scratch_of.end()) {
                    new_instructions.push_back(makeFrameAccess(
                        RVOpcodes::SD, std::make_unique<RegOperand>(s->second),
                        frame_info.spill_offsets.at(vreg)));
                }
            }
        }
        mbb->getInstructions() = std::move(new_instructions);
    }
}

} // namespace sysy