Files
nudt-compiler-cpp/src/ir/IRPrinter.cpp

370 lines
12 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// IR 文本输出:
// - 将 IR 打印为 .ll 风格的文本
// - 支撑调试与测试对比diff
#include "ir/IR.h"
#include <cstdio>
#include <iomanip>
#include <sstream>
#include <ostream>
#include <cstring>
#include <stdexcept>
#include <string>
#include "utils/Log.h"
namespace ir {
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
case Type::Kind::Int32:
return "i32";
case Type::Kind::PtrInt32:
return "i32*";
case Type::Kind::Float:
return "float";
case Type::Kind::PtrFloat:
return "float*";
case Type::Kind::Label:
return "label";
case Type::Kind::Array: {
const auto* arr_ty = static_cast<const ArrayType*>(&ty);
return "[" + std::to_string(arr_ty->GetNumElements()) + " x " +
TypeToString(*arr_ty->GetElementType()) + "]";
}
}
return "unknown";
}
// Map an opcode to the mnemonic printed in .ll text. Div/Mod map to the signed
// "sdiv"/"srem". Compare opcodes include their predicate: signed predicates for
// icmp, ordered predicates for fcmp except "une" for not-equal. An opcode that
// is not listed below prints as "?".
static std::string OpcodeToString(Opcode op) {
  const char* mnemonic = "?";
  switch (op) {
    // Integer arithmetic.
    case Opcode::Add: mnemonic = "add"; break;
    case Opcode::Sub: mnemonic = "sub"; break;
    case Opcode::Mul: mnemonic = "mul"; break;
    case Opcode::Div: mnemonic = "sdiv"; break;
    case Opcode::Mod: mnemonic = "srem"; break;
    // Floating-point arithmetic.
    case Opcode::FAdd: mnemonic = "fadd"; break;
    case Opcode::FSub: mnemonic = "fsub"; break;
    case Opcode::FMul: mnemonic = "fmul"; break;
    case Opcode::FDiv: mnemonic = "fdiv"; break;
    // Integer comparisons.
    case Opcode::ICmpEQ: mnemonic = "icmp eq"; break;
    case Opcode::ICmpNE: mnemonic = "icmp ne"; break;
    case Opcode::ICmpLT: mnemonic = "icmp slt"; break;
    case Opcode::ICmpGT: mnemonic = "icmp sgt"; break;
    case Opcode::ICmpLE: mnemonic = "icmp sle"; break;
    case Opcode::ICmpGE: mnemonic = "icmp sge"; break;
    // Floating-point comparisons.
    case Opcode::FCmpEQ: mnemonic = "fcmp oeq"; break;
    case Opcode::FCmpNE: mnemonic = "fcmp une"; break;
    case Opcode::FCmpLT: mnemonic = "fcmp olt"; break;
    case Opcode::FCmpGT: mnemonic = "fcmp ogt"; break;
    case Opcode::FCmpLE: mnemonic = "fcmp ole"; break;
    case Opcode::FCmpGE: mnemonic = "fcmp oge"; break;
    // Memory.
    case Opcode::Alloca: mnemonic = "alloca"; break;
    case Opcode::Load:   mnemonic = "load"; break;
    case Opcode::Store:  mnemonic = "store"; break;
    case Opcode::GEP:    mnemonic = "getelementptr"; break;
    // Control flow.
    case Opcode::Ret:  mnemonic = "ret"; break;
    case Opcode::Br:   mnemonic = "br"; break;
    case Opcode::Call: mnemonic = "call"; break;
    case Opcode::Phi:  mnemonic = "phi"; break;
    // Conversions.
    case Opcode::ZExt:   mnemonic = "zext"; break;
    case Opcode::SIToFP: mnemonic = "sitofp"; break;
    case Opcode::FPToSI: mnemonic = "fptosi"; break;
  }
  return mnemonic;
}
static std::string ValueToString(const Value* v) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
const double as_double = static_cast<double>(cf->GetValue());
uint64_t bits = 0;
std::memcpy(&bits, &as_double, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
if (v->IsGlobalValue() || v->IsFunction()) {
return "@" + v->GetName();
}
if (v->IsInstruction() || v->IsArgument() || v->GetType()->IsLabel()) {
return "%" + v->GetName();
}
return v->GetName();
}
static bool IsBoolLikeValue(const Value* v) {
if (auto* inst = dynamic_cast<const Instruction*>(v)) {
switch (inst->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
break;
}
}
return false;
}
// Type text printed next to `v` when it appears as an operand. It is "i1" for
// comparison results (per IsBoolLikeValue) and the Value's own type otherwise.
static std::string PrintedValueType(const Value* v) {
  return IsBoolLikeValue(v) ? std::string("i1") : TypeToString(*v->GetType());
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
// Print global variables
for (const auto& gv : module.GetGlobalValues()) {
os << "@" << gv->GetName() << " = global ";
if (gv->GetType()->IsPtrInt32()) {
os << "i32";
} else if (gv->GetType()->IsPtrFloat()) {
os << "float";
} else {
os << TypeToString(*gv->GetType());
}
if (gv->GetInitializer()) {
os << " " << ValueToString(gv->GetInitializer());
} else {
os << " zeroinitializer";
}
os << "\n";
}
if (!module.GetGlobalValues().empty()) os << "\n";
for (const auto& func : module.GetFunctions()) {
if (func->GetBlocks().empty()) {
os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "(";
// For declarations, we just need types. But Argument objects might exist.
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType());
if (i + 1 < args.size()) os << ", ";
}
os << ")\n\n";
continue;
}
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType()) << " %" << args[i]->GetName();
if (i + 1 < args.size()) os << ", ";
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
}
os << bb->GetName() << ":\n";
for (const auto& instPtr : bb->GetInstructions()) {
const auto* inst = instPtr.get();
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " %" << alloca->GetName() << " = alloca ";
if (alloca->GetType()->IsPtrInt32())
os << "i32";
else if (alloca->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*alloca->GetType());
os << "\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " %" << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store " << TypeToString(*store->GetValue()->GetType())
<< " " << ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret ";
if (ret->GetValue()) {
os << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue());
} else {
os << "void";
}
os << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
if (br->IsConditional()) {
os << " br i1 " << ValueToString(br->GetCondition())
<< ", label " << ValueToString(br->GetIfTrue()) << ", label "
<< ValueToString(br->GetIfFalse()) << "\n";
} else {
os << " br label " << ValueToString(br->GetDest()) << "\n";
}
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
auto* func = call->GetFunction();
if (!call->GetType()->IsVoid()) {
os << " %" << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(func) << "(";
for (size_t i = 1; i < call->GetNumOperands(); ++i) {
auto* arg = call->GetOperand(i);
os << PrintedValueType(arg) << " " << ValueToString(arg);
if (i + 1 < call->GetNumOperands()) os << ", ";
}
os << ")\n";
break;
}
case Opcode::GEP: {
auto* gep = static_cast<const GetElementPtrInst*>(inst);
os << " %" << gep->GetName() << " = getelementptr ";
if (gep->GetPtr()->GetType()->IsPtrInt32())
os << "i32";
else if (gep->GetPtr()->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*gep->GetPtr()->GetType());
os << ", ";
if (gep->GetPtr()->GetType()->IsArray()) {
os << TypeToString(*gep->GetPtr()->GetType()) << "* ";
} else {
os << TypeToString(*gep->GetPtr()->GetType()) << " ";
}
os << ValueToString(gep->GetPtr());
for (size_t i = 1; i < gep->GetNumOperands(); ++i) {
os << ", " << TypeToString(*gep->GetOperand(i)->GetType()) << " "
<< ValueToString(gep->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::ZExt:
case Opcode::SIToFP:
case Opcode::FPToSI: {
auto* cast = static_cast<const CastInst*>(inst);
os << " %" << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< PrintedValueType(cast->GetValue()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << phi->GetName() << " = phi " << TypeToString(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i > 0) os << ", ";
os << "[ " << ValueToString(phi->GetIncomingValue(i)) << ", %" << phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break;
}
}
}
}
os << "}\n";
}
}
} // namespace ir