Complete Lab2 IR generation and document process

This commit is contained in:
2026-04-16 00:21:35 +08:00
parent 6fc0c89072
commit 979d271ebe
23 changed files with 2583 additions and 471 deletions

View File

@@ -4,7 +4,11 @@
#include "ir/IR.h"
#include <cstdio>
#include <iomanip>
#include <sstream>
#include <ostream>
#include <cstring>
#include <stdexcept>
#include <string>
@@ -12,7 +16,7 @@
namespace ir {
static const char* TypeToString(const Type& ty) {
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
@@ -20,11 +24,22 @@ static const char* TypeToString(const Type& ty) {
return "i32";
case Type::Kind::PtrInt32:
return "i32*";
case Type::Kind::Float:
return "float";
case Type::Kind::PtrFloat:
return "float*";
case Type::Kind::Label:
return "label";
case Type::Kind::Array: {
const auto* arr_ty = static_cast<const ArrayType*>(&ty);
return "[" + std::to_string(arr_ty->GetNumElements()) + " x " +
TypeToString(*arr_ty->GetElementType()) + "]";
}
}
throw std::runtime_error(FormatError("ir", "未知类型"));
return "unknown";
}
static const char* OpcodeToString(Opcode op) {
static std::string OpcodeToString(Opcode op) {
switch (op) {
case Opcode::Add:
return "add";
@@ -32,6 +47,42 @@ static const char* OpcodeToString(Opcode op) {
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Mod:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::ICmpEQ:
return "icmp eq";
case Opcode::ICmpNE:
return "icmp ne";
case Opcode::ICmpLT:
return "icmp slt";
case Opcode::ICmpGT:
return "icmp sgt";
case Opcode::ICmpLE:
return "icmp sle";
case Opcode::ICmpGE:
return "icmp sge";
case Opcode::FCmpEQ:
return "fcmp oeq";
case Opcode::FCmpNE:
return "fcmp une";
case Opcode::FCmpLT:
return "fcmp olt";
case Opcode::FCmpGT:
return "fcmp ogt";
case Opcode::FCmpLE:
return "fcmp ole";
case Opcode::FCmpGE:
return "fcmp oge";
case Opcode::Alloca:
return "alloca";
case Opcode::Load:
@@ -40,21 +91,114 @@ static const char* OpcodeToString(Opcode op) {
return "store";
case Opcode::Ret:
return "ret";
case Opcode::Br:
return "br";
case Opcode::Call:
return "call";
case Opcode::GEP:
return "getelementptr";
case Opcode::ZExt:
return "zext";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
}
return "?";
}
static std::string ValueToString(const Value* v) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
return v ? v->GetName() : "<null>";
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
const double as_double = static_cast<double>(cf->GetValue());
uint64_t bits = 0;
std::memcpy(&bits, &as_double, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
if (v->IsGlobalValue() || v->IsFunction()) {
return "@" + v->GetName();
}
if (v->IsInstruction() || v->IsArgument() || v->GetType()->IsLabel()) {
return "%" + v->GetName();
}
return v->GetName();
}
static bool IsBoolLikeValue(const Value* v) {
if (auto* inst = dynamic_cast<const Instruction*>(v)) {
switch (inst->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
break;
}
}
return false;
}
static std::string PrintedValueType(const Value* v) {
if (IsBoolLikeValue(v)) return "i1";
return TypeToString(*v->GetType());
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
// Print global variables
for (const auto& gv : module.GetGlobalValues()) {
os << "@" << gv->GetName() << " = global ";
if (gv->GetType()->IsPtrInt32()) {
os << "i32";
} else if (gv->GetType()->IsPtrFloat()) {
os << "float";
} else {
os << TypeToString(*gv->GetType());
}
if (gv->GetInitializer()) {
os << " " << ValueToString(gv->GetInitializer());
} else {
os << " zeroinitializer";
}
os << "\n";
}
if (!module.GetGlobalValues().empty()) os << "\n";
for (const auto& func : module.GetFunctions()) {
if (func->GetBlocks().empty()) {
os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "(";
// For declarations, we just need types. But Argument objects might exist.
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType());
if (i + 1 < args.size()) os << ", ";
}
os << ")\n\n";
continue;
}
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n";
<< "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType()) << " %" << args[i]->GetName();
if (i + 1 < args.size()) os << ", ";
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
@@ -65,36 +209,142 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul: {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n";
os << " %" << alloca->GetName() << " = alloca ";
if (alloca->GetType()->IsPtrInt32())
os << "i32";
else if (alloca->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*alloca->GetType());
os << "\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* "
os << " %" << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
os << " store " << TypeToString(*store->GetValue()->GetType())
<< " " << ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break;
}
case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue()) << "\n";
os << " ret ";
if (ret->GetValue()) {
os << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue());
} else {
os << "void";
}
os << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
if (br->IsConditional()) {
os << " br i1 " << ValueToString(br->GetCondition())
<< ", label " << ValueToString(br->GetIfTrue()) << ", label "
<< ValueToString(br->GetIfFalse()) << "\n";
} else {
os << " br label " << ValueToString(br->GetDest()) << "\n";
}
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
auto* func = call->GetFunction();
if (!call->GetType()->IsVoid()) {
os << " %" << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(func) << "(";
for (size_t i = 1; i < call->GetNumOperands(); ++i) {
auto* arg = call->GetOperand(i);
os << PrintedValueType(arg) << " " << ValueToString(arg);
if (i + 1 < call->GetNumOperands()) os << ", ";
}
os << ")\n";
break;
}
case Opcode::GEP: {
auto* gep = static_cast<const GetElementPtrInst*>(inst);
os << " %" << gep->GetName() << " = getelementptr ";
if (gep->GetPtr()->GetType()->IsPtrInt32())
os << "i32";
else if (gep->GetPtr()->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*gep->GetPtr()->GetType());
os << ", ";
if (gep->GetPtr()->GetType()->IsArray()) {
os << TypeToString(*gep->GetPtr()->GetType()) << "* ";
} else {
os << TypeToString(*gep->GetPtr()->GetType()) << " ";
}
os << ValueToString(gep->GetPtr());
for (size_t i = 1; i < gep->GetNumOperands(); ++i) {
os << ", " << TypeToString(*gep->GetOperand(i)->GetType()) << " "
<< ValueToString(gep->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::ZExt:
case Opcode::SIToFP:
case Opcode::FPToSI: {
auto* cast = static_cast<const CastInst*>(inst);
os << " %" << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< PrintedValueType(cast->GetValue()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
}