diff --git a/src/ir/passes/CSE.cpp b/src/ir/passes/CSE.cpp index 2684fd7..f057d7e 100644 --- a/src/ir/passes/CSE.cpp +++ b/src/ir/passes/CSE.cpp @@ -1,5 +1,6 @@ #include "ir/PassManager.h" #include +#include #include #include @@ -56,10 +57,30 @@ bool RunCSE(Function* func) { for (const auto& bbPtr : func->GetBlocks()) { std::vector seen_instructions; + std::unordered_map available_loads; std::vector to_erase; for (const auto& instPtr : bbPtr->GetInstructions()) { auto* inst = instPtr.get(); + + if (inst->GetOpcode() == Opcode::Load) { + auto* load = static_cast(inst); + auto it = available_loads.find(load->GetPtr()); + if (it != available_loads.end()) { + inst->ReplaceAllUsesWith(it->second); + to_erase.push_back(inst); + changed = true; + continue; + } + available_loads[load->GetPtr()] = inst; + continue; + } + + if (inst->GetOpcode() == Opcode::Store || + inst->GetOpcode() == Opcode::Call) { + available_loads.clear(); + } + Instruction* match = nullptr; for (auto* seen : seen_instructions) { if (IsEquivalent(inst, seen)) { diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 3f730ab..b7f55c7 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -25,7 +25,7 @@ bool IsFloatReg(PhysReg reg) { } void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, - int offset) { + int offset, int frame_size) { bool is_float = IsFloatReg(reg); const char* ldr_cmd = is_float ? "ldr" : "ldr"; const char* str_cmd = is_float ? "str" : "str"; @@ -38,8 +38,22 @@ void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n"; } } else { - os << " ldr x10, =" << offset << "\n"; - os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n"; + int sp_offset = frame_size + offset; + int access_size = 4; + if ((reg >= PhysReg::X0 && reg <= PhysReg::X28) || + reg == PhysReg::X29 || reg == PhysReg::X30 || + reg == PhysReg::SP) { + access_size = 8; + } + int max_offset = access_size == 8 ? 32760 : 16380; + if (sp_offset >= 0 && sp_offset <= max_offset && + sp_offset % access_size == 0) { + os << " " << base_mnemonic << " " << PhysRegName(reg) + << ", [sp, #" << sp_offset << "]\n"; + } else { + os << " ldr x10, =" << offset << "\n"; + os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n"; + } } } @@ -125,12 +139,14 @@ void PrintAsm(const MachineFunction& function, std::ostream& os) { } case Opcode::LoadStack: { const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); + PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset, + function.GetFrameSize()); break; } case Opcode::StoreStack: { const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); + PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset, + function.GetFrameSize()); break; } case Opcode::AddRR: diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index a8bc248..97af910 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -7,6 +7,10 @@ namespace mir { namespace { +int AlignTo(int value, int align) { + return ((value + align - 1) / align) * align; +} + PhysReg NormalizeReg(PhysReg reg) { int r = static_cast(reg); // Map 64-bit X0-X28 registers to 32-bit W0-W28 registers to handle aliasing @@ -96,6 +100,29 @@ std::vector SimplifyCompareToBranch( return simplified; } +void CompactFrameSlots(MachineFunction& function) { + std::unordered_set used_slots; + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block.GetInstructions()) { + for (const auto& opnd : inst.GetOperands()) { + if (opnd.GetKind() == Operand::Kind::FrameIndex) { + used_slots.insert(opnd.GetFrameIndex()); + } + } + } + } + + int cursor = 0; + for (const auto& slot : function.GetFrameSlots()) { + if (used_slots.find(slot.index) == used_slots.end()) { + continue; + } + cursor += slot.size; + function.GetFrameSlot(slot.index).offset = -cursor; + } + function.SetFrameSize(AlignTo(cursor, 16)); +} + } // namespace void RunPeephole(MachineFunction& function) { @@ -285,6 +312,8 @@ void RunPeephole(MachineFunction& function) { } insts = std::move(optimized); } + + CompactFrameSlots(function); } } // namespace mir