325 lines
12 KiB
C++
325 lines
12 KiB
C++
#include "mir/MIR.h"
|
|
#include "ir/IR.h"
|
|
|
|
#include <ostream>
|
|
#include <stdexcept>
|
|
#include <cstdint>
|
|
#include <vector>
|
|
#include <cstring>
|
|
|
|
#include "utils/Log.h"
|
|
|
|
namespace mir {
|
|
namespace {
|
|
|
|
const FrameSlot& GetFrameSlot(const MachineFunction& function,
|
|
const Operand& operand) {
|
|
if (operand.GetKind() != Operand::Kind::FrameIndex) {
|
|
throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
|
|
}
|
|
return function.GetFrameSlot(operand.GetFrameIndex());
|
|
}
|
|
|
|
bool IsFloatReg(PhysReg reg) {
|
|
return reg >= PhysReg::S0 && reg <= PhysReg::S15;
|
|
}
|
|
|
|
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
|
|
int offset) {
|
|
bool is_float = IsFloatReg(reg);
|
|
const char* ldr_cmd = is_float ? "ldr" : "ldr";
|
|
const char* str_cmd = is_float ? "str" : "str";
|
|
const char* base_mnemonic = (std::strcmp(mnemonic, "ldur") == 0) ? ldr_cmd : str_cmd;
|
|
|
|
if (offset >= -256 && offset <= 255) {
|
|
if (is_float) {
|
|
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
|
|
} else {
|
|
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
|
|
}
|
|
} else {
|
|
os << " mov x10, #" << offset << "\n";
|
|
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n";
|
|
}
|
|
}
|
|
|
|
std::string GetBlockLabel(const std::string& func_name, const std::string& block_name) {
|
|
if (block_name == "entry" || block_name.empty()) {
|
|
return func_name;
|
|
}
|
|
return ".L_" + func_name + "_" + block_name;
|
|
}
|
|
|
|
} // namespace
|
|
|
|
void PrintAsm(const MachineFunction& function, std::ostream& os) {
|
|
os << ".text\n";
|
|
os << ".global " << function.GetName() << "\n";
|
|
os << ".type " << function.GetName() << ", %function\n";
|
|
|
|
struct FloatConstant {
|
|
std::string label;
|
|
float value;
|
|
};
|
|
std::vector<FloatConstant> float_constants;
|
|
|
|
for (size_t b = 0; b < function.GetBlocks().size(); ++b) {
|
|
const auto& block = function.GetBlocks()[b];
|
|
|
|
// Print the block label
|
|
if (b == 0) {
|
|
os << function.GetName() << ":\n";
|
|
} else {
|
|
os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n";
|
|
}
|
|
|
|
for (const auto& inst : block.GetInstructions()) {
|
|
const auto& ops = inst.GetOperands();
|
|
switch (inst.GetOpcode()) {
|
|
case Opcode::Prologue:
|
|
os << " stp x29, x30, [sp, #-16]!\n";
|
|
os << " mov x29, sp\n";
|
|
if (function.GetFrameSize() > 0) {
|
|
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
|
|
}
|
|
break;
|
|
case Opcode::Epilogue:
|
|
if (function.GetFrameSize() > 0) {
|
|
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
|
|
}
|
|
os << " ldp x29, x30, [sp], #16\n";
|
|
break;
|
|
case Opcode::MovImm: {
|
|
PhysReg dst = ops.at(0).GetReg();
|
|
if (IsFloatReg(dst)) {
|
|
// Load float constant
|
|
int bits = ops.at(1).GetImm();
|
|
float val;
|
|
std::memcpy(&val, &bits, sizeof(float));
|
|
std::string flabel = ".LC_" + function.GetName() + "_" + std::to_string(float_constants.size());
|
|
float_constants.push_back({flabel, val});
|
|
|
|
os << " adrp x8, " << flabel << "\n";
|
|
os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel << "]\n";
|
|
} else {
|
|
os << " mov " << PhysRegName(dst) << ", #" << ops.at(1).GetImm() << "\n";
|
|
}
|
|
break;
|
|
}
|
|
case Opcode::LoadStack: {
|
|
const auto& slot = GetFrameSlot(function, ops.at(1));
|
|
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
|
|
break;
|
|
}
|
|
case Opcode::StoreStack: {
|
|
const auto& slot = GetFrameSlot(function, ops.at(1));
|
|
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
|
|
break;
|
|
}
|
|
case Opcode::AddRR:
|
|
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::SubRR:
|
|
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::MulRR:
|
|
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::SDivRR:
|
|
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::MSubRRRR:
|
|
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(3).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FAddRRR:
|
|
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FSubRRR:
|
|
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FMulRRR:
|
|
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FDivRRR:
|
|
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(2).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::CmpRR:
|
|
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FCmpRR:
|
|
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::Cset:
|
|
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< ops.at(1).GetCondCode() << "\n";
|
|
break;
|
|
case Opcode::B:
|
|
os << " b " << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName()) << "\n";
|
|
break;
|
|
case Opcode::BCond:
|
|
os << " b." << ops.at(0).GetCondCode() << " "
|
|
<< GetBlockLabel(function.GetName(), ops.at(1).GetLabelName()) << "\n";
|
|
break;
|
|
case Opcode::Call:
|
|
os << " bl " << ops.at(0).GetGlobalName() << "\n";
|
|
break;
|
|
case Opcode::Ret:
|
|
os << " ret\n";
|
|
break;
|
|
case Opcode::MovReg:
|
|
if (IsFloatReg(ops.at(0).GetReg()) || IsFloatReg(ops.at(1).GetReg())) {
|
|
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
} else {
|
|
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
}
|
|
break;
|
|
case Opcode::Adrp:
|
|
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< ops.at(1).GetGlobalName() << "\n";
|
|
break;
|
|
case Opcode::AddRegImm: {
|
|
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << ", ";
|
|
if (ops.at(2).GetKind() == Operand::Kind::FrameIndex) {
|
|
const auto& slot = function.GetFrameSlot(ops.at(2).GetFrameIndex());
|
|
os << "#" << slot.offset << "\n";
|
|
} else if (ops.at(2).GetKind() == Operand::Kind::Global) {
|
|
os << ":lo12:" << ops.at(2).GetGlobalName() << "\n";
|
|
} else {
|
|
os << "#" << ops.at(2).GetImm() << "\n";
|
|
}
|
|
break;
|
|
}
|
|
case Opcode::LdrRegReg: {
|
|
PhysReg reg = ops.at(0).GetReg();
|
|
const char* ldr_cmd = IsFloatReg(reg) ? "ldr" : "ldr";
|
|
os << " " << ldr_cmd << " " << PhysRegName(reg) << ", ["
|
|
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
|
break;
|
|
}
|
|
case Opcode::StrRegReg: {
|
|
PhysReg reg = ops.at(0).GetReg();
|
|
const char* str_cmd = IsFloatReg(reg) ? "str" : "str";
|
|
os << " " << str_cmd << " " << PhysRegName(reg) << ", ["
|
|
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
|
|
break;
|
|
}
|
|
case Opcode::SIToFP:
|
|
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::FPToSI:
|
|
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
|
|
<< PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
break;
|
|
case Opcode::ZExt:
|
|
if (ops.at(0).GetReg() >= PhysReg::X0 && ops.at(0).GetReg() <= PhysReg::X28) {
|
|
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n";
|
|
} else {
|
|
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n";
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
|
|
|
|
// Print read-only data segment if there are float constants
|
|
if (!float_constants.empty()) {
|
|
os << ".section .rodata\n";
|
|
os << ".align 2\n";
|
|
for (const auto& fc : float_constants) {
|
|
os << fc.label << ":\n";
|
|
uint32_t bits;
|
|
std::memcpy(&bits, &fc.value, sizeof(float));
|
|
os << " .word " << bits << " // float " << fc.value << "\n";
|
|
}
|
|
}
|
|
}
|
|
|
|
// Storage size in bytes of an IR type as laid out in the emitted data
// sections: 4 for i32 and float, 8 for int*/float* pointers, element size
// times element count for arrays, and 4 as the fallback for anything else.
static uint32_t GetTypeSize(const ir::Type* type) {
  const bool is_scalar = type->IsInt32() || type->IsFloat();
  if (is_scalar) return 4;

  const bool is_pointer = type->IsPtrInt32() || type->IsPtrFloat();
  if (is_pointer) return 8;

  if (!type->IsArray()) return 4;

  // GetAsArrayType is only callable on a non-const Type; keep the returned
  // handle alive while its fields are read.
  const auto array_type = const_cast<ir::Type*>(type)->GetAsArrayType();
  const ir::Type* element = array_type->GetElementType().get();
  return array_type->GetNumElements() * GetTypeSize(element);
}
|
|
|
|
// Emits an assembler definition for every global value in `module`.
//
// Initialized globals are placed in .data, uninitialized ones in .bss
// (zero-filled). A global whose type is int*/float* is laid out as a single
// int/float. Only scalar ConstantInt / ConstantFloat initializers are
// recognized; any other initializer kind is emitted as 0.
void PrintGlobals(const ir::Module& module, std::ostream& os) {
  for (const auto& gv : module.GetGlobalValues()) {
    os << ".global " << gv->GetName() << "\n";

    std::shared_ptr<ir::Type> actual_ty = gv->GetType();
    if (actual_ty->IsPtrInt32()) {
      actual_ty = ir::Type::GetInt32Type();
    } else if (actual_ty->IsPtrFloat()) {
      actual_ty = ir::Type::GetFloatType();
    }

    const uint32_t actual_size = GetTypeSize(actual_ty.get());

    if (gv->GetInitializer()) {
      os << ".data\n";
      os << ".align 2\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";

      if (actual_ty->IsFloat()) {
        float val = 0.0f;
        if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
          val = cf->GetValue();
        } else if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
          val = static_cast<float>(ci->GetValue());
        }
        uint32_t bits;
        std::memcpy(&bits, &val, sizeof(float));
        os << " .word " << bits << " // float " << val << "\n";
      } else {
        int val = 0;
        if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
          val = ci->GetValue();
        } else if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
          val = static_cast<int>(cf->GetValue());
        }
        os << " .word " << val << "\n";
      }

      // Only one 4-byte word was emitted above. For larger types (e.g. arrays)
      // reserve the rest of the declared .size so the next global does not
      // overlap this one's storage.
      // NOTE(review): aggregate initializer values are not emitted here.
      if (actual_size > 4) {
        os << " .zero " << (actual_size - 4) << "\n";
      }
    } else {
      os << ".bss\n";
      os << ".align 4\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";
      os << " .zero " << actual_size << "\n";
    }
    os << "\n";
  }
}
|
|
|
|
} // namespace mir
|