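// Source file: nudt-compiler-cpp/src/mir/AsmPrinter.cpp (C++)
// Note: the lines above the first #include were viewer metadata
// ("325 lines, 12 KiB") captured with the file, not part of the program.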

#include "mir/MIR.h"
#include "ir/IR.h"
#include <ostream>
#include <stdexcept>
#include <cstdint>
#include <vector>
#include <cstring>
#include "utils/Log.h"
namespace mir {
namespace {
// Looks up the frame slot that a FrameIndex operand refers to in `function`.
//
// Throws std::runtime_error when `operand` is of any other kind.
const FrameSlot& GetFrameSlot(const MachineFunction& function,
                              const Operand& operand) {
  const bool is_frame_index = operand.GetKind() == Operand::Kind::FrameIndex;
  if (is_frame_index) {
    return function.GetFrameSlot(operand.GetFrameIndex());
  }
  throw std::runtime_error(FormatError("mir", "期望 FrameIndex 操作数"));
}
// Reports whether `reg` lies in the inclusive enumerator range S0..S15.
bool IsFloatReg(PhysReg reg) {
  const bool below_first = reg < PhysReg::S0;
  const bool above_last = reg > PhysReg::S15;
  return !(below_first || above_last);
}
// Prints one load or store of `reg` at byte `offset` from the frame pointer
// (x29).
//
// `mnemonic` is the unscaled form chosen by the caller; "ldur" selects a load
// and any other value is treated as a store.
//
// Offsets in [-256, 255] are emitted with an immediate address. Larger
// offsets are first materialised into x10 and the access uses the
// register-offset form `[x29, x10]`.
// NOTE(review): this assumes x10 is reserved as a scratch register and never
// holds a live value at this point - confirm against the register allocator.
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
                      int offset) {
  const bool is_load = std::strcmp(mnemonic, "ldur") == 0;
  const char* base_mnemonic = is_load ? "ldr" : "str";

  if (offset >= -256 && offset <= 255) {
    // Float registers are printed with ldr/str, integer registers with the
    // caller-supplied mnemonic (same text as before this rewrite).
    const char* chosen = IsFloatReg(reg) ? base_mnemonic : mnemonic;
    os << " " << chosen << " " << PhysRegName(reg) << ", [x29, #" << offset
       << "]\n";
    return;
  }

  if (offset >= -65536 && offset <= 65535) {
    // Fits a single MOVZ/MOVN, so the plain `mov` alias assembles.
    os << " mov x10, #" << offset << "\n";
  } else {
    // A single `mov` cannot encode this value; build it from 16-bit halves.
    const std::uint32_t bits = static_cast<std::uint32_t>(offset);
    const std::uint32_t low_half = bits & 0xFFFFu;
    const std::uint32_t high_half = bits >> 16;
    if (offset < 0) {
      // MOVN writes the inverted immediate and sets every other bit, which
      // gives the 64-bit sign extension a negative offset needs.
      os << " movn x10, #" << (~low_half & 0xFFFFu) << "\n";
    } else {
      os << " movz x10, #" << low_half << "\n";
    }
    os << " movk x10, #" << high_half << ", lsl #16\n";
  }
  os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n";
}
// Builds the assembly label for a basic block of `func_name`.
//
// A block named "entry" (or with an empty name) shares the function's own
// symbol; every other block becomes ".L_<func_name>_<block_name>".
std::string GetBlockLabel(const std::string& func_name, const std::string& block_name) {
  const bool is_entry_block = block_name.empty() || block_name == "entry";
  if (!is_entry_block) {
    std::string label = ".L_";
    label += func_name;
    label += '_';
    label += block_name;
    return label;
  }
  return func_name;
}
} // namespace
void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n";
os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n";
struct FloatConstant {
std::string label;
float value;
};
std::vector<FloatConstant> float_constants;
for (size_t b = 0; b < function.GetBlocks().size(); ++b) {
const auto& block = function.GetBlocks()[b];
// Print the block label
if (b == 0) {
os << function.GetName() << ":\n";
} else {
os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm: {
PhysReg dst = ops.at(0).GetReg();
if (IsFloatReg(dst)) {
// Load float constant
int bits = ops.at(1).GetImm();
float val;
std::memcpy(&val, &bits, sizeof(float));
std::string flabel = ".LC_" + function.GetName() + "_" + std::to_string(float_constants.size());
float_constants.push_back({flabel, val});
os << " adrp x8, " << flabel << "\n";
os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel << "]\n";
} else {
os << " mov " << PhysRegName(dst) << ", #" << ops.at(1).GetImm() << "\n";
}
break;
}
case Opcode::LoadStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MSubRRRR:
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", "
<< PhysRegName(ops.at(3).GetReg()) << "\n";
break;
case Opcode::FAddRRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::Cset:
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetCondCode() << "\n";
break;
case Opcode::B:
os << " b " << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName()) << "\n";
break;
case Opcode::BCond:
os << " b." << ops.at(0).GetCondCode() << " "
<< GetBlockLabel(function.GetName(), ops.at(1).GetLabelName()) << "\n";
break;
case Opcode::Call:
os << " bl " << ops.at(0).GetGlobalName() << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
case Opcode::MovReg:
if (IsFloatReg(ops.at(0).GetReg()) || IsFloatReg(ops.at(1).GetReg())) {
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
} else {
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
}
break;
case Opcode::Adrp:
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetGlobalName() << "\n";
break;
case Opcode::AddRegImm: {
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", ";
if (ops.at(2).GetKind() == Operand::Kind::FrameIndex) {
const auto& slot = function.GetFrameSlot(ops.at(2).GetFrameIndex());
os << "#" << slot.offset << "\n";
} else if (ops.at(2).GetKind() == Operand::Kind::Global) {
os << ":lo12:" << ops.at(2).GetGlobalName() << "\n";
} else {
os << "#" << ops.at(2).GetImm() << "\n";
}
break;
}
case Opcode::LdrRegReg: {
PhysReg reg = ops.at(0).GetReg();
const char* ldr_cmd = IsFloatReg(reg) ? "ldr" : "ldr";
os << " " << ldr_cmd << " " << PhysRegName(reg) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::StrRegReg: {
PhysReg reg = ops.at(0).GetReg();
const char* str_cmd = IsFloatReg(reg) ? "str" : "str";
os << " " << str_cmd << " " << PhysRegName(reg) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ZExt:
if (ops.at(0).GetReg() >= PhysReg::X0 && ops.at(0).GetReg() <= PhysReg::X28) {
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n";
} else {
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n";
}
break;
}
}
}
os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
// Print read-only data segment if there are float constants
if (!float_constants.empty()) {
os << ".section .rodata\n";
os << ".align 2\n";
for (const auto& fc : float_constants) {
os << fc.label << ":\n";
uint32_t bits;
std::memcpy(&bits, &fc.value, sizeof(float));
os << " .word " << bits << " // float " << fc.value << "\n";
}
}
}
// Returns the storage size in bytes used for `type` when laying out globals.
//
// Int32 and float take 4 bytes, the two pointer kinds take 8, and an array
// takes element-count times the (recursively computed) element size. Any
// other type falls back to 4 bytes.
static uint32_t GetTypeSize(const ir::Type* type) {
  constexpr uint32_t kWordBytes = 4;
  constexpr uint32_t kPointerBytes = 8;

  const bool is_scalar = type->IsInt32() || type->IsFloat();
  if (is_scalar) {
    return kWordBytes;
  }
  const bool is_pointer = type->IsPtrInt32() || type->IsPtrFloat();
  if (is_pointer) {
    return kPointerBytes;
  }
  if (!type->IsArray()) {
    return kWordBytes;
  }
  // GetAsArrayType is called through a non-const pointer, as in the original.
  auto* array_type = const_cast<ir::Type*>(type)->GetAsArrayType().get();
  return array_type->GetNumElements() *
         GetTypeSize(array_type->GetElementType().get());
}
// Writes the data definitions for every global value of `module` to `os`.
//
// Globals with an initializer go to `.data`, the rest to `.bss`. A global
// whose type is pointer-to-int32 or pointer-to-float is laid out as the
// pointed-to scalar; every other type is sized by GetTypeSize.
void PrintGlobals(const ir::Module& module, std::ostream& os) {
  constexpr uint32_t kWordBytes = 4;

  for (const auto& gv : module.GetGlobalValues()) {
    os << ".global " << gv->GetName() << "\n";

    std::shared_ptr<ir::Type> actual_ty = gv->GetType();
    if (actual_ty->IsPtrInt32()) {
      actual_ty = ir::Type::GetInt32Type();
    } else if (actual_ty->IsPtrFloat()) {
      actual_ty = ir::Type::GetFloatType();
    }
    const uint32_t actual_size = GetTypeSize(actual_ty.get());

    const auto* init = gv->GetInitializer();
    if (init) {
      os << ".data\n";
      os << ".align 2\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";
      if (actual_ty->IsFloat()) {
        // Float storage: accept either constant kind, converting ints.
        float val = 0.0f;
        if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(init)) {
          val = cf->GetValue();
        } else if (auto* ci = dynamic_cast<const ir::ConstantInt*>(init)) {
          val = static_cast<float>(ci->GetValue());
        }
        uint32_t bits;
        std::memcpy(&bits, &val, sizeof(float));
        os << " .word " << bits << " // float " << val << "\n";
      } else {
        // Integer storage: accept either constant kind, truncating floats.
        int val = 0;
        if (auto* ci = dynamic_cast<const ir::ConstantInt*>(init)) {
          val = ci->GetValue();
        } else if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(init)) {
          val = static_cast<int>(cf->GetValue());
        }
        os << " .word " << val << "\n";
      }
      // Only one word is written above. Zero-fill the remainder so the
      // reserved storage matches the size announced by `.size`; otherwise a
      // larger object (an array) would overlap the next symbol.
      // TODO: emit per-element values for aggregate initializers.
      if (actual_size > kWordBytes) {
        os << " .zero " << (actual_size - kWordBytes) << "\n";
      }
    } else {
      os << ".bss\n";
      os << ".align 4\n";
      os << ".size " << gv->GetName() << ", " << actual_size << "\n";
      os << gv->GetName() << ":\n";
      os << " .zero " << actual_size << "\n";
    }
    os << "\n";
  }
}
} // namespace mir