refactor(ir): ir改为更标准的实现

This commit is contained in:
jing
2026-03-18 01:53:54 +08:00
parent 1b283856b3
commit 7d4d60c546
9 changed files with 397 additions and 172 deletions

View File

@@ -1,6 +1,11 @@
// IR 基本块:
// - 保存指令序列
// - 维护或可计算前驱/后继关系,用于 CFG 分析与优化
// - 为后续 CFG 分析预留前驱/后继接口
//
// 当前仍是最小实现:
// - BasicBlock 已纳入 Value 体系,但类型先用 void 占位;
// - 指令追加与 terminator 约束主要在头文件中的 Append 模板里处理;
// - 前驱/后继容器已经预留,但当前项目里还没有分支指令与自动维护逻辑。
#include "ir/IR.h"
@@ -8,23 +13,27 @@
namespace ir {
BasicBlock::BasicBlock(std::string name) : name_(std::move(name)) {}
const std::string& BasicBlock::GetName() const { return name_; }
// BasicBlock does not have a dedicated label type yet, so the void type is
// passed to the Value base as a placeholder. `name` is moved into the base.
BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {}
// Returns the stored parent function pointer (parent_).
Function* BasicBlock::GetParent() const { return parent_; }
// Records `parent` as this block's parent function; nothing else is updated here.
void BasicBlock::SetParent(Function* parent) { parent_ = parent; }
// Reports whether the block already ends in a terminator: true only when the
// instruction list is non-empty and its last instruction answers
// IsTerminator() with true.
bool BasicBlock::HasTerminator() const {
  if (instructions_.empty()) {
    return false;
  }
  const auto& last_inst = instructions_.back();
  return last_inst->IsTerminator();
}
// Returns the block's instruction sequence in insertion order, as a read-only
// reference to the owning container (ownership stays with the block).
const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
const {
return instructions_;
}
// The predecessor/successor accessors are kept for later CFG extensions.
// The current minimal IR has no branch instruction yet, so these lists are
// usually empty.
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
return predecessors_;
}

View File

@@ -3,6 +3,7 @@ add_library(ir_core STATIC
Module.cpp
Function.cpp
BasicBlock.cpp
GlobalValue.cpp
Type.cpp
Value.cpp
Instruction.cpp

View File

@@ -7,32 +7,11 @@ namespace ir {
// Destructor is defaulted here, out of line in the .cpp file.
Context::~Context() = default;
const std::shared_ptr<Type>& Context::Void() {
if (!void_) {
void_ = std::make_shared<Type>(Type::Kind::Void);
}
return void_;
}
const std::shared_ptr<Type>& Context::Int32() {
if (!int32_) {
int32_ = std::make_shared<Type>(Type::Kind::Int32);
}
return int32_;
}
const std::shared_ptr<Type>& Context::PtrInt32() {
if (!ptr_i32_) {
ptr_i32_ = std::make_shared<Type>(Type::Kind::PtrInt32);
}
return ptr_i32_;
}
// Returns the ConstantInt cached for `v`, creating and caching it on first
// request so repeated calls with the same value yield the same object.
//
// The returned pointer is non-owning; the object is held by const_ints_.
// The constant's type comes from Type::GetInt32Type(); the previous
// Context::Int32() accessor is no longer used, and the second, discarded
// emplace call (which allocated a ConstantInt only to drop it) is removed.
ConstantInt* Context::GetConstInt(int v) {
  if (auto found = const_ints_.find(v); found != const_ints_.end()) {
    return found->second.get();
  }
  auto inserted =
      const_ints_
          .emplace(v, std::make_unique<ConstantInt>(Type::GetInt32Type(), v))
          .first;
  return inserted->second.get();
}

11
src/ir/GlobalValue.cpp Normal file
View File

@@ -0,0 +1,11 @@
// GlobalValue 占位实现:
// - 具体的全局初始化器、打印和链接语义需要自行补全
#include "ir/IR.h"
namespace ir {
// Placeholder constructor: forwards the type and name to the User base and
// sets up no additional state of its own.
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)) {}
} // namespace ir

View File

@@ -46,7 +46,7 @@ AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(ctx_.PtrInt32(), name);
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
}
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@@ -57,7 +57,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
}
return insert_block_->Append<LoadInst>(ctx_.Int32(), ptr, name);
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
}
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@@ -72,7 +72,7 @@ StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateStore 缺少 ptr"));
}
return insert_block_->Append<StoreInst>(ctx_.Void(), val, ptr);
return insert_block_->Append<StoreInst>(Type::GetVoidType(), val, ptr);
}
ReturnInst* IRBuilder::CreateRet(Value* v) {
@@ -83,7 +83,7 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
return insert_block_->Append<ReturnInst>(ctx_.Void(), v);
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
}
} // namespace ir

View File

@@ -5,6 +5,21 @@ namespace ir {
// Constructs a type tagged with kind `k`.
Type::Type(Kind k) : kind_(k) {}
// Returns the shared Kind::Void type. The instance is a function-local static
// created on the first call; every later call returns the same object.
const std::shared_ptr<Type>& Type::GetVoidType() {
  static const auto void_type = std::make_shared<Type>(Kind::Void);
  return void_type;
}
// Returns the shared Kind::Int32 type. The instance is a function-local static
// created on the first call; every later call returns the same object.
const std::shared_ptr<Type>& Type::GetInt32Type() {
  static const auto int32_type = std::make_shared<Type>(Kind::Int32);
  return int32_type;
}
// Returns the shared Kind::PtrInt32 type. The instance is a function-local
// static created on the first call; every later call returns the same object.
const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
  static const auto ptr_int32_type = std::make_shared<Type>(Kind::PtrInt32);
  return ptr_int32_type;
}
// Returns the kind tag stored in kind_.
Type::Kind Type::GetKind() const { return kind_; }
// True when this type's kind is Kind::Void.
bool Type::IsVoid() const { return kind_ == Kind::Void; }

View File

@@ -3,6 +3,8 @@
// - 提供类型信息与使用/被使用关系(按需要实现)
#include "ir/IR.h"
#include <algorithm>
namespace ir {
Value::Value(std::shared_ptr<Type> ty, std::string name)
@@ -14,11 +16,68 @@ const std::string& Value::GetName() const { return name_; }
// Replaces the value's name; `n` is taken by value and moved into name_.
void Value::SetName(std::string n) { name_ = std::move(n); }
void Value::AddUser(Instruction* user) { users_.push_back(user); }
// True when this value has a type and that type reports IsVoid().
// A value without a type (null type_) is never void.
bool Value::IsVoid() const {
  if (!type_) {
    return false;
  }
  return type_->IsVoid();
}
const std::vector<Instruction*>& Value::GetUsers() const { return users_; }
// True when this value has a type and that type reports IsInt32().
// A value without a type (null type_) yields false.
bool Value::IsInt32() const {
  if (!type_) {
    return false;
  }
  return type_->IsInt32();
}
// True when this value has a type and that type reports IsPtrInt32().
// A value without a type (null type_) yields false.
bool Value::IsPtrInt32() const {
  if (!type_) {
    return false;
  }
  return type_->IsPtrInt32();
}
// True when the dynamic type of this value is ConstantValue or derives from it.
bool Value::IsConstant() const {
  const auto* as_constant = dynamic_cast<const ConstantValue*>(this);
  return as_constant != nullptr;
}
// True when the dynamic type of this value is Instruction or derives from it.
bool Value::IsInstruction() const {
  const auto* as_instruction = dynamic_cast<const Instruction*>(this);
  return as_instruction != nullptr;
}
// True when the dynamic type of this value is User or derives from it.
bool Value::IsUser() const {
  const auto* as_user = dynamic_cast<const User*>(this);
  return as_user != nullptr;
}
// True when the dynamic type of this value is Function or derives from it.
bool Value::IsFunction() const {
  const auto* as_function = dynamic_cast<const Function*>(this);
  return as_function != nullptr;
}
// Records that `user` references this value in operand slot `operand_index`.
// A null `user` is ignored. No duplicate check is performed: calling this
// twice with the same arguments appends two entries.
void Value::AddUse(User* user, size_t operand_index) {
  if (user != nullptr) {
    uses_.emplace_back(this, user, operand_index);
  }
}
// Drops every recorded use whose user and operand index both match the
// arguments. Does nothing when no such entry exists.
void Value::RemoveUse(User* user, size_t operand_index) {
  const auto matches_slot = [user, operand_index](const Use& candidate) {
    return candidate.GetUser() == user &&
           candidate.GetOperandIndex() == operand_index;
  };
  const auto first_removed =
      std::remove_if(uses_.begin(), uses_.end(), matches_slot);
  uses_.erase(first_removed, uses_.end());
}
// Returns a read-only reference to the recorded uses of this value (uses_).
const std::vector<Use>& Value::GetUses() const { return uses_; }
// Redirects every recorded use of this value to `new_value`.
//
// Throws std::runtime_error when `new_value` is null. Replacing a value with
// itself is a no-op. For each recorded use, the user's operand at the recorded
// index is rewritten only if it still points at this value.
void Value::ReplaceAllUsesWith(Value* new_value) {
  if (new_value == nullptr) {
    throw std::runtime_error("ReplaceAllUsesWith 缺少 new_value");
  }
  if (new_value == this) {
    return;
  }

  // Walk a copy of the use list rather than uses_ itself.
  // NOTE(review): presumably SetOperand edits this value's use list while we
  // iterate — confirm against User::SetOperand.
  const auto snapshot = uses_;
  for (const auto& entry : snapshot) {
    auto* owner = entry.GetUser();
    if (owner == nullptr) {
      continue;
    }
    const size_t slot = entry.GetOperandIndex();
    if (owner->GetOperand(slot) != this) {
      continue;
    }
    owner->SetOperand(slot, new_value);
  }
}
// Forwards the type and name to the Value base; sets up no state of its own.
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {}
// Builds an integer constant of type `ty` holding `v`.
//
// The constant is created with an empty name. It is initialized through the
// ConstantValue base (not Value directly) so that the
// dynamic_cast<const ConstantValue*> in Value::IsConstant() succeeds for it.
// A constructor may carry only one mem-initializer list, so the stale
// `: Value(...)` list is dropped.
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
    : ConstantValue(std::move(ty), ""), value_(v) {}
} // namespace ir

View File

@@ -1,16 +1,15 @@
#include "sem/Sema.h"
#include <any>
#include <stdexcept>
#include <string>
#include "SysYBaseVisitor.h"
#include "sem/SymbolTable.h"
#include "utils/Log.h"
namespace {
void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table,
SemanticContext& sema);
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
@@ -18,122 +17,184 @@ std::string GetLValueName(SysYParser::LValueContext& lvalue) {
return lvalue.ID()->getText();
}
void CheckVar(SysYParser::VarContext& var, const SymbolTable& table,
SemanticContext& sema) {
if (!var.ID()) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
const std::string name = var.ID()->getText();
auto* decl = table.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
}
sema.BindVarUse(&var, decl);
}
void CheckExpr(SysYParser::ExpContext& exp, const SymbolTable& table,
SemanticContext& sema) {
if (auto* paren = dynamic_cast<SysYParser::ParenExpContext*>(&exp)) {
CheckExpr(*paren->exp(), table, sema);
return;
}
if (auto* var = dynamic_cast<SysYParser::VarExpContext*>(&exp)) {
if (!var->var()) {
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
class SemaVisitor final : public SysYBaseVisitor {
public:
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
}
CheckVar(*var->var(), table, sema);
return;
}
if (dynamic_cast<SysYParser::NumberExpContext*>(&exp)) {
return;
}
if (auto* binary = dynamic_cast<SysYParser::AdditiveExpContext*>(&exp)) {
CheckExpr(*binary->exp(0), table, sema);
CheckExpr(*binary->exp(1), table, sema);
return;
}
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
}
SysYParser::FuncDefContext* FindMainFunc(SysYParser::CompUnitContext& comp_unit) {
auto* func = comp_unit.funcDef();
if (func && func->ID() && func->ID()->getText() == "main") {
return func;
}
return nullptr;
}
} // namespace
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
auto* func = FindMainFunc(comp_unit);
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (!func->funcType() || !func->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
}
SymbolTable table;
SemanticContext sema;
bool seen_return = false;
const auto& items = func->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
}
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
auto* func = ctx->funcDef();
if (!func || !func->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (seen_return) {
if (!func->ID() || func->ID()->getText() != "main") {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
func->accept(this);
if (!seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
FormatError("sema", "main 函数必须包含 return 语句"));
}
if (auto* decl = item->decl()) {
if (!decl->btype() || !decl->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = decl->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
const std::string name = GetLValueName(*var_def->lValue());
if (table.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(
FormatError("sema", "当前不支持聚合初始化"));
}
CheckExpr(*init->exp(), table, sema);
}
table.Add(name, var_def);
continue;
return {};
}
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) {
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
}
if (auto* stmt = item->stmt(); stmt && stmt->returnStmt()) {
auto* ret = stmt->returnStmt();
if (!ret->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
if (!ctx->funcType() || !ctx->funcType()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
}
const auto& items = ctx->blockStmt()->blockItem();
if (items.empty()) {
throw std::runtime_error(
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
}
ctx->blockStmt()->accept(this);
return {};
}
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "缺少语句块"));
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
CheckExpr(*ret->exp(), table, sema);
seen_return = true;
if (i + 1 != items.size()) {
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
continue;
current_item_index_ = i;
total_items_ = items.size();
item->accept(this);
}
return {};
}
// Dispatches a block item to its declaration or statement child. A declaration
// takes precedence when both are present. Throws std::runtime_error when the
// item is null or has neither child.
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
  if (ctx != nullptr) {
    if (auto* declaration = ctx->decl()) {
      declaration->accept(this);
      return {};
    }
    if (auto* statement = ctx->stmt()) {
      statement->accept(this);
      return {};
    }
  }
  throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
if (!seen_return) {
throw std::runtime_error(FormatError("sema", "main 函数必须包含 return 语句"));
// Checks a variable declaration and registers it in the symbol table.
//
// Checks, in order: the node exists; its base type is INT; it has a varDef
// with an lvalue; the name is not already in the table; and, when an
// initializer is present, that it is a plain expression, which is then
// visited. Each failure throws std::runtime_error. The name is added to the
// table only after the initializer has been visited, so the initializer
// cannot refer to the variable being declared.
std::any visitDecl(SysYParser::DeclContext* ctx) override {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量声明"));
  }

  auto* base_type = ctx->btype();
  const bool is_int = base_type && base_type->INT();
  if (!is_int) {
    throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
  }

  auto* definition = ctx->varDef();
  const bool has_lvalue = definition && definition->lValue();
  if (!has_lvalue) {
    throw std::runtime_error(FormatError("sema", "非法变量声明"));
  }

  const std::string name = GetLValueName(*definition->lValue());
  if (table_.Contains(name)) {
    throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
  }

  auto* init = definition->initValue();
  if (init != nullptr) {
    auto* init_exp = init->exp();
    if (init_exp == nullptr) {
      throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
    }
    init_exp->accept(this);
  }

  table_.Add(name, definition);
  return {};
}
return sema;
// Only return statements are accepted here: visits the return child, or throws
// std::runtime_error when the node is null or is some other kind of statement.
std::any visitStmt(SysYParser::StmtContext* ctx) override {
  auto* ret = ctx ? ctx->returnStmt() : nullptr;
  if (ret == nullptr) {
    throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
  }
  ret->accept(this);
  return {};
}
// Checks a return statement: it must carry an expression (which is visited),
// and it must be the last item of the block being walked, judged by the
// current_item_index_ / total_items_ members. Sets seen_return_ before the
// position check. Throws std::runtime_error on either violation.
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
  auto* value = ctx ? ctx->exp() : nullptr;
  if (value == nullptr) {
    throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
  }
  value->accept(this);
  seen_return_ = true;

  const bool is_last_item = (current_item_index_ + 1 == total_items_);
  if (!is_last_item) {
    throw std::runtime_error(
        FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
  }
  return {};
}
// Visits the expression wrapped by the parentheses; throws std::runtime_error
// when the node or its inner expression is missing.
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
  auto* inner = ctx ? ctx->exp() : nullptr;
  if (inner == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法括号表达式"));
  }
  inner->accept(this);
  return {};
}
// Visits the variable node of a variable expression; throws
// std::runtime_error when the node or its variable child is missing.
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
  auto* variable = ctx ? ctx->var() : nullptr;
  if (variable == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量表达式"));
  }
  variable->accept(this);
  return {};
}
// Accepts a number expression only if it has a number child carrying an
// ILITERAL token; anything else throws std::runtime_error. Records nothing.
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
  const bool is_int_literal =
      ctx && ctx->number() && ctx->number()->ILITERAL();
  if (!is_int_literal) {
    throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
  }
  return {};
}
// Visits both operands of an additive expression, left first. Throws
// std::runtime_error when the node or either operand is missing.
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
  auto* lhs = ctx ? ctx->exp(0) : nullptr;
  auto* rhs = ctx ? ctx->exp(1) : nullptr;
  if (lhs == nullptr || rhs == nullptr) {
    throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
  }
  lhs->accept(this);
  rhs->accept(this);
  return {};
}
// Resolves a variable reference: looks the identifier up in the symbol table
// and binds this use to the declaration found. Throws std::runtime_error when
// the node has no identifier or the name is not in the table.
std::any visitVar(SysYParser::VarContext* ctx) override {
  auto* id = ctx ? ctx->ID() : nullptr;
  if (id == nullptr) {
    throw std::runtime_error(FormatError("sema", "非法变量引用"));
  }

  const std::string name = id->getText();
  auto* decl = table_.Lookup(name);
  if (decl == nullptr) {
    throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
  }

  sema_.BindVarUse(ctx, decl);
  return {};
}
// Hands the accumulated SemanticContext to the caller by moving it out of the
// visitor. After this call sema_ has been moved from, so it should not be
// relied on again.
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
SymbolTable table_;
SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
};
} // namespace
// Entry point of semantic analysis: creates a SemaVisitor, has the compilation
// unit accept it, and returns the SemanticContext the visitor built.
// Violations detected by the visitor are thrown as std::runtime_error and
// propagate to the caller.
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit) {
SemaVisitor visitor;
comp_unit.accept(&visitor);
return visitor.TakeSemanticContext();
}