#include "mir/MIR.h"

#include "ir/IR.h"

// NOTE(review): the names of the system headers were not legible in the
// reviewed copy of this file; the list below is what this translation unit
// actually uses. Confirm nothing that was previously included got dropped.
#include <cstdint>
#include <cstring>
#include <memory>
#include <ostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "utils/Log.h"

namespace mir {
namespace {

// Largest value an AArch64 ADD/SUB (immediate) can encode without a shift.
constexpr long long kAddSubImmMax = 0xFFF;
// Range that a single MOVZ / MOVN (and therefore the `mov` alias) always covers.
constexpr long long kMovImmMin = -65536;
constexpr long long kMovImmMax = 65535;

/// Returns the frame slot referenced by `operand`.
/// @throws std::runtime_error if `operand` is not a FrameIndex operand.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  if (operand.GetKind() != Operand::Kind::FrameIndex) {
    throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
  }
  return function.GetFrameSlot(operand.GetFrameIndex());
}

/// True when `reg` lies in the S0..S15 range of the PhysReg enumeration.
bool IsFloatReg(PhysReg reg) {
  return reg >= PhysReg::S0 && reg <= PhysReg::S15;
}

/// Emits code that materialises `value` in register `reg`.
///
/// Values in [-65536, 65535] are printed as a single `mov`, exactly as before.
/// Anything else is built from MOVZ/MOVN followed by MOVK, because the `mov`
/// alias is rejected by the assembler for constants that do not fit a single
/// wide-move or bitmask immediate. For values that fit in 32 bits only the
/// `lsl #16` MOVK is emitted, so the sequence is valid for W and X registers.
template <typename RegName>
void PrintMovImm(std::ostream& os, RegName reg, long long value) {
  if (value >= kMovImmMin && value <= kMovImmMax) {
    os << " mov " << reg << ", #" << value << "\n";
    return;
  }

  const bool negative = value < 0;
  const unsigned long long bits = static_cast<unsigned long long>(value);
  // Chunks equal to `fill` are already produced by the MOVZ/MOVN below.
  const unsigned long long fill = negative ? 0xFFFFULL : 0ULL;
  const unsigned long long low = bits & 0xFFFFULL;

  if (negative) {
    // MOVN writes the bitwise inverse of its immediate, setting all upper bits.
    os << " movn " << reg << ", #" << (~low & 0xFFFFULL) << "\n";
  } else {
    os << " movz " << reg << ", #" << low << "\n";
  }
  for (int shift = 16; shift < 64; shift += 16) {
    const unsigned long long chunk = (bits >> shift) & 0xFFFFULL;
    if (chunk != fill) {
      os << " movk " << reg << ", #" << chunk << ", lsl #" << shift << "\n";
    }
  }
}

/// Emits `op dst, src, #magnitude` (op is "add" or "sub"), splitting the
/// immediate into `#imm12, lsl #12` and `#imm12` pieces when it exceeds the
/// 12-bit range of a single ADD/SUB (immediate). For magnitudes up to 4095 the
/// output is a single instruction.
template <typename RegName>
void PrintAddSubChunks(std::ostream& os, const char* op, RegName dst,
                       RegName src, unsigned long long magnitude) {
  const unsigned long long imm12_max =
      static_cast<unsigned long long>(kAddSubImmMax);
  bool first = true;  // Only the first instruction reads `src`.

  while (magnitude > imm12_max) {
    unsigned long long pages = magnitude >> 12;
    if (pages > imm12_max) {
      pages = imm12_max;
    }
    os << " " << op << " " << dst << ", " << (first ? src : dst) << ", #"
       << pages << ", lsl #12\n";
    magnitude -= pages << 12;
    first = false;
  }
  if (magnitude != 0 || first) {
    os << " " << op << " " << dst << ", " << (first ? src : dst) << ", #"
       << magnitude << "\n";
  }
}

/// Emits `add dst, src, #imm`. Immediates within [-4095, 4095] are printed
/// verbatim (as before); larger ones are split into encodable add/sub steps.
template <typename RegName>
void PrintAddImm(std::ostream& os, RegName dst, RegName src, long long imm) {
  if (imm >= -kAddSubImmMax && imm <= kAddSubImmMax) {
    os << " add " << dst << ", " << src << ", #" << imm << "\n";
    return;
  }
  const bool negative = imm < 0;
  // Unsigned negation avoids overflow on the most negative value.
  const unsigned long long magnitude =
      negative ? 0ULL - static_cast<unsigned long long>(imm)
               : static_cast<unsigned long long>(imm);
  PrintAddSubChunks(os, negative ? "sub" : "add", dst, src, magnitude);
}

/// Emits a load or store of `reg` at `[x29, #offset]`.
///
/// `mnemonic` is "ldur" for a load; any other value is treated as a store.
/// Offsets in [-256, 255] use an immediate form (the given mnemonic for
/// non-float registers, ldr/str for float registers). Other offsets are first
/// materialised in x10 and accessed through `[x29, x10]`.
/// NOTE(review): this overwrites x10; presumably x10 is reserved as a scratch
/// register by the register allocator -- confirm.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  const bool is_load = std::strcmp(mnemonic, "ldur") == 0;
  const char* indexed_op = is_load ? "ldr" : "str";

  if (offset >= -256 && offset <= 255) {
    const char* op = IsFloatReg(reg) ? indexed_op : mnemonic;
    os << " " << op << " " << PhysRegName(reg) << ", [x29, #" << offset
       << "]\n";
    return;
  }

  PrintMovImm(os, "x10", offset);
  os << " " << indexed_op << " " << PhysRegName(reg) << ", [x29, x10]\n";
}

/// Maps a basic-block name to its assembly label. The block named "entry" (or
/// an unnamed block) shares the function's own label.
std::string GetBlockLabel(const std::string& func_name,
                          const std::string& block_name) {
  if (block_name == "entry" || block_name.empty()) {
    return func_name;
  }
  return ".L_" + func_name + "_" + block_name;
}

}  // namespace

/// Prints `function` as AArch64 assembly text to `os`.
///
/// Output consists of the `.text` header, one label per block, one or more
/// instructions per machine instruction, a `.size` directive and, if any float
/// immediates were used, a `.rodata` section holding them.
/// Opcodes without a case in the switch below produce no output.
void PrintAsm(const MachineFunction& function, std::ostream& os) {
  os << ".text\n";
  os << ".global " << function.GetName() << "\n";
  os << ".type " << function.GetName() << ", %function\n";

  struct FloatConstant {
    std::string label;
    float value;
  };
  std::vector<FloatConstant> float_constants;

  for (size_t b = 0; b < function.GetBlocks().size(); ++b) {
    const auto& block = function.GetBlocks()[b];

    // The first block always carries the function's own label.
    if (b == 0) {
      os << function.GetName() << ":\n";
    } else {
      os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n";
    }

    for (const auto& inst : block.GetInstructions()) {
      const auto& ops = inst.GetOperands();

      // Name of the register held by operand `index`.
      const auto reg_name = [&ops](size_t index) {
        return PhysRegName(ops.at(index).GetReg());
      };
      // Prints " <mnemonic> r0, r1, ..." using the first `count` operands.
      const auto emit_regs = [&os, &reg_name](const char* mnemonic,
                                              size_t count) {
        os << " " << mnemonic << " ";
        for (size_t i = 0; i < count; ++i) {
          if (i != 0) {
            os << ", ";
          }
          os << reg_name(i);
        }
        os << "\n";
      };

      switch (inst.GetOpcode()) {
        case Opcode::Prologue:
          os << " stp x29, x30, [sp, #-16]!\n";
          os << " mov x29, sp\n";
          if (function.GetFrameSize() > 0) {
            // Split so frames larger than 4095 bytes still assemble.
            PrintAddSubChunks(
                os, "sub", "sp", "sp",
                static_cast<unsigned long long>(function.GetFrameSize()));
          }
          break;

        case Opcode::Epilogue:
          if (function.GetFrameSize() > 0) {
            PrintAddSubChunks(
                os, "add", "sp", "sp",
                static_cast<unsigned long long>(function.GetFrameSize()));
          }
          os << " ldp x29, x30, [sp], #16\n";
          break;

        case Opcode::MovImm: {
          const PhysReg dst = ops.at(0).GetReg();
          if (IsFloatReg(dst)) {
            // The immediate carries the raw bit pattern of the float; it is
            // placed in .rodata and loaded through x8.
            // NOTE(review): x8 is overwritten here -- confirm it is a scratch
            // register at this point.
            const int bits = ops.at(1).GetImm();
            float val = 0.0f;
            std::memcpy(&val, &bits, sizeof(float));
            const std::string flabel = ".LC_" + function.GetName() + "_" +
                                       std::to_string(float_constants.size());
            float_constants.push_back({flabel, val});
            os << " adrp x8, " << flabel << "\n";
            os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel
               << "]\n";
          } else {
            PrintMovImm(os, PhysRegName(dst), ops.at(1).GetImm());
          }
          break;
        }

        case Opcode::LoadStack: {
          const auto& slot = GetFrameSlot(function, ops.at(1));
          PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
          break;
        }

        case Opcode::StoreStack: {
          const auto& slot = GetFrameSlot(function, ops.at(1));
          PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
          break;
        }

        case Opcode::AddRR:
          emit_regs("add", 3);
          break;
        case Opcode::SubRR:
          emit_regs("sub", 3);
          break;
        case Opcode::MulRR:
          emit_regs("mul", 3);
          break;
        case Opcode::SDivRR:
          emit_regs("sdiv", 3);
          break;
        case Opcode::MSubRRRR:
          emit_regs("msub", 4);
          break;
        case Opcode::FAddRRR:
          emit_regs("fadd", 3);
          break;
        case Opcode::FSubRRR:
          emit_regs("fsub", 3);
          break;
        case Opcode::FMulRRR:
          emit_regs("fmul", 3);
          break;
        case Opcode::FDivRRR:
          emit_regs("fdiv", 3);
          break;
        case Opcode::CmpRR:
          emit_regs("cmp", 2);
          break;
        case Opcode::FCmpRR:
          emit_regs("fcmp", 2);
          break;

        case Opcode::Cset:
          os << " cset " << reg_name(0) << ", " << ops.at(1).GetCondCode()
             << "\n";
          break;

        case Opcode::B:
          os << " b "
             << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName())
             << "\n";
          break;

        case Opcode::BCond:
          os << " b." << ops.at(0).GetCondCode() << " "
             << GetBlockLabel(function.GetName(), ops.at(1).GetLabelName())
             << "\n";
          break;

        case Opcode::Call:
          os << " bl " << ops.at(0).GetGlobalName() << "\n";
          break;

        case Opcode::Ret:
          os << " ret\n";
          break;

        case Opcode::MovReg: {
          // Any float register on either side requires fmov.
          const bool involves_float = IsFloatReg(ops.at(0).GetReg()) ||
                                      IsFloatReg(ops.at(1).GetReg());
          emit_regs(involves_float ? "fmov" : "mov", 2);
          break;
        }

        case Opcode::Adrp:
          os << " adrp " << reg_name(0) << ", " << ops.at(1).GetGlobalName()
             << "\n";
          break;

        case Opcode::AddRegImm: {
          const auto& addend = ops.at(2);
          if (addend.GetKind() == Operand::Kind::FrameIndex) {
            const auto& slot = function.GetFrameSlot(addend.GetFrameIndex());
            PrintAddImm(os, reg_name(0), reg_name(1), slot.offset);
          } else if (addend.GetKind() == Operand::Kind::Global) {
            os << " add " << reg_name(0) << ", " << reg_name(1) << ", "
               << ":lo12:" << addend.GetGlobalName() << "\n";
          } else {
            PrintAddImm(os, reg_name(0), reg_name(1), addend.GetImm());
          }
          break;
        }

        case Opcode::LdrRegReg:
          os << " ldr " << reg_name(0) << ", [" << reg_name(1) << "]\n";
          break;

        case Opcode::StrRegReg:
          os << " str " << reg_name(0) << ", [" << reg_name(1) << "]\n";
          break;

        case Opcode::SIToFP:
          emit_regs("scvtf", 2);
          break;

        case Opcode::FPToSI:
          emit_regs("fcvtzs", 2);
          break;

        case Opcode::ZExt:
          if (ops.at(0).GetReg() >= PhysReg::X0 &&
              ops.at(0).GetReg() <= PhysReg::X28) {
            // NOTE(review): sxtw sign-extends, although the opcode is named
            // ZExt. Left unchanged because callers may rely on it (e.g. for
            // signed indices) -- confirm the intended semantics.
            emit_regs("sxtw", 2);
          } else {
            os << " and " << reg_name(0) << ", " << reg_name(1) << ", #1\n";
          }
          break;
      }
    }
  }

  os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";

  // Float immediates collected above are emitted as read-only data.
  if (!float_constants.empty()) {
    os << ".section .rodata\n";
    os << ".align 2\n";
    for (const auto& fc : float_constants) {
      os << fc.label << ":\n";
      uint32_t bits;
      std::memcpy(&bits, &fc.value, sizeof(float));
      os << " .word " << bits << " // float " << fc.value << "\n";
    }
  }
}

/// Size in bytes of `type`: 4 for int32/float, 8 for the two pointer types,
/// element count times element size for arrays, and 4 for anything else.
static uint32_t GetTypeSize(const ir::Type* type) {
  if (type->IsInt32() || type->IsFloat()) {
    return 4;
  }
  if (type->IsPtrInt32() || type->IsPtrFloat()) {
    return 8;
  }
  if (type->IsArray()) {
    // NOTE(review): the cast target was not legible in the reviewed copy;
    // `ir::Type*` is a reconstruction -- confirm against ir/IR.h.
    auto* arr_ty = const_cast<ir::Type*>(type)->GetAsArrayType().get();
    return arr_ty->GetNumElements() *
           GetTypeSize(arr_ty->GetElementType().get());
  }
  return 4;
}

/// Prints every global value of `module` as an assembly data definition.
///
/// Globals with an initializer go to `.data` as one `.word`; globals without
/// one go to `.bss` as `.zero <size>`. A global whose type is pointer-to-int32
/// or pointer-to-float is emitted with the pointee's type.
void PrintGlobals(const ir::Module& module, std::ostream& os) {
  for (const auto& gv : module.GetGlobalValues()) {
    os << ".global " << gv->GetName() << "\n";

    // NOTE(review): template and cast arguments in this function were not
    // legible in the reviewed copy; `ir::Type`, `ir::ConstantFloat` and
    // `ir::ConstantInt` are reconstructions -- confirm against ir/IR.h.
    std::shared_ptr<ir::Type> actual_ty = gv->GetType();
    if (actual_ty->IsPtrInt32()) {
      actual_ty = ir::Type::GetInt32Type();
    } else if (actual_ty->IsPtrFloat()) {
      actual_ty = ir::Type::GetFloatType();
    }
    const uint32_t actual_size = GetTypeSize(actual_ty.get());

    if (gv->GetInitializer()) {
      os << ".data\n";
      os << ".align 2\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";

      if (actual_ty->IsFloat()) {
        float val = 0.0f;
        if (auto* cf =
                dynamic_cast<ir::ConstantFloat*>(gv->GetInitializer())) {
          val = cf->GetValue();
        } else if (auto* ci =
                       dynamic_cast<ir::ConstantInt*>(gv->GetInitializer())) {
          val = static_cast<float>(ci->GetValue());
        }
        uint32_t bits;
        std::memcpy(&bits, &val, sizeof(float));
        os << " .word " << bits << " // float " << val << "\n";
      } else {
        int val = 0;
        if (auto* ci = dynamic_cast<ir::ConstantInt*>(gv->GetInitializer())) {
          val = ci->GetValue();
        } else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(
                       gv->GetInitializer())) {
          val = static_cast<int>(cf->GetValue());
        }
        os << " .word " << val << "\n";
      }

      // Only one word is written above. Pad the rest so the storage actually
      // reserved matches the declared .size and the next symbol cannot
      // overlap this one.
      // NOTE(review): element values of a non-scalar initializer are still
      // not emitted; only the first word is.
      if (actual_size > 4) {
        os << " .zero " << (actual_size - 4) << "\n";
      }
    } else {
      os << ".bss\n";
      os << ".align 4\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";
      os << " .zero " << actual_size << "\n";
    }
    os << "\n";
  }
}

}  // namespace mir