Compare commits
1 Commits
lab2
...
eac1cdf613
| Author | SHA1 | Date | |
|---|---|---|---|
| eac1cdf613 |
313
doc/Lab2-实验记录.md
313
doc/Lab2-实验记录.md
@@ -1,313 +0,0 @@
|
||||
# Lab2 实验记录:中间表示生成
|
||||
|
||||
## 1. 实验目标
|
||||
|
||||
本次 Lab2 的目标是在已有的 SysY 前端基础上,补齐语义检查与 IR 生成流程,使编译器能够把更完整的 SysY 程序翻译为 LLVM 风格 IR,并通过 `llc/clang` 验证生成结果的正确性。
|
||||
|
||||
本次完成工作的重点包括:
|
||||
|
||||
- 扩展 IR 类型系统与指令系统,支持 `float`、数组、分支、函数调用、GEP、类型转换等基础能力。
|
||||
- 扩展 Sema,支持嵌套作用域、左值绑定、函数调用绑定与内建库函数预声明。
|
||||
- 完成表达式、控制流、函数、数组与全局变量的 IR 生成逻辑。
|
||||
- 修复全局初始化常量求值、短路求值、数组寻址、IR 打印格式等会直接阻塞 Lab2 验证的关键问题。
|
||||
|
||||
## 2. 代码改动范围
|
||||
|
||||
本次实验主要修改了以下模块:
|
||||
|
||||
- `src/sem` 与 `include/sem`
|
||||
- `src/ir` 与 `include/ir`
|
||||
- `src/irgen` 与 `include/irgen`
|
||||
- 新增本文档 `doc/Lab2-实验记录.md`
|
||||
|
||||
其中:
|
||||
|
||||
- `sem` 负责名称绑定、作用域和语义信息准备。
|
||||
- `ir` 负责 IR 基础设施、Builder 与 Printer。
|
||||
- `irgen` 负责把 ANTLR 语法树翻译成 IR。
|
||||
|
||||
## 3. 完成过程
|
||||
|
||||
### 3.1 先确认问题边界
|
||||
|
||||
开始时先阅读了实验文档 `doc/Lab2-中间表示生成.md`,然后检查了以下实现:
|
||||
|
||||
- `IRGenDecl.cpp`
|
||||
- `IRGenExp.cpp`
|
||||
- `IRGenStmt.cpp`
|
||||
- `IRGenFunc.cpp`
|
||||
- `IRBuilder.cpp`
|
||||
- `IRPrinter.cpp`
|
||||
- `Sema.cpp`
|
||||
|
||||
在初始状态下,代码已经完成了大部分 Lab2 框架,但仍存在两个会直接导致失败的问题:
|
||||
|
||||
1. 全局常量初始化时,`EvalConstExpr` 实际上仍然调用了运行时的 `EvalExpr`,从而在没有插入点时进入 `builder_.CreateLoad/CreateBinary/...`,最终报错:
|
||||
`IRBuilder 未设置插入点`
|
||||
2. 数组相关的指针/聚合类型处理不一致,局部数组、多维数组与数组参数传递时很容易触发 `LoadInst 不支持的指针类型` 或生成错误的 GEP。
|
||||
|
||||
为了避免只靠静态阅读猜问题,随后先执行了构建与最小样例验证,确认真实失败点。
|
||||
|
||||
### 3.2 建立回归基线
|
||||
|
||||
首先重新构建项目:
|
||||
|
||||
```bash
|
||||
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build -j 4
|
||||
```
|
||||
|
||||
然后针对典型样例做验证:
|
||||
|
||||
- `simple_add.sy`
|
||||
- `05_arr_defn4.sy`
|
||||
- `95_float.sy`
|
||||
|
||||
结果表明:
|
||||
|
||||
- `95_float.sy` 会因为全局常量路径错误触发 `IRBuilder 未设置插入点`
|
||||
- `05_arr_defn4.sy` 会因为数组寻址/存储类型不一致导致崩溃
|
||||
|
||||
这一步的作用是把问题从“感觉哪里有问题”缩小到“常量求值路径”和“数组存储/寻址路径”两条主线。
|
||||
|
||||
## 4. 关键困难与解决办法
|
||||
|
||||
### 4.1 困难一:全局初始化错误地走了运行时 IRBuilder 路径
|
||||
|
||||
#### 现象
|
||||
|
||||
像下面这种代码在全局或常量初始化中会崩溃:
|
||||
|
||||
```c
|
||||
const float PI = 3.1415926;
|
||||
const int A = 1 + 2;
|
||||
```
|
||||
|
||||
原因是原来的 `EvalConstExpr` 虽然名字叫“常量求值”,但内部还是直接调用了 `EvalExpr`。一旦表达式中包含需要访问变量、二元运算、短路逻辑等节点,就会落入 `builder_` 创建指令的逻辑,而此时全局作用域没有任何基本块插入点。
|
||||
|
||||
#### 解决办法
|
||||
|
||||
把编译期常量求值彻底独立出来:
|
||||
|
||||
- 为 `EvalConstExpr` 单独实现一套常量求值 Visitor。
|
||||
- 常量路径只返回 `ConstantInt` / `ConstantFloat`,绝不生成 IR 指令。
|
||||
- 支持:
|
||||
- 整数/浮点字面量
|
||||
- 括号表达式
|
||||
- `+`、`-`、`!`
|
||||
- `* / % + -`
|
||||
- 比较运算
|
||||
- `&& ||`
|
||||
- 标量 `const` 的引用
|
||||
- 在全局和常量初始化中,只允许使用 `EvalConstExpr` 的结果。
|
||||
|
||||
#### 效果
|
||||
|
||||
修复后:
|
||||
|
||||
- 全局初始化不再依赖插入点
|
||||
- `95_float.sy` 中的全局常量能够稳定生成
|
||||
- 短路表达式在常量上下文中只做纯编译期求值,不会试图分配 `alloca`
|
||||
|
||||
### 4.2 困难二:数组变量、数组参数与标量变量的“存储语义”混乱
|
||||
|
||||
#### 现象
|
||||
|
||||
原实现里,`alloca/load/store/GEP` 对类型的理解不统一:
|
||||
|
||||
- 标量变量需要的是“指向标量的槽位”
|
||||
- 局部数组需要的是“聚合对象的基址”
|
||||
- 数组形参在 SysY 中本质上是指针,不应按局部数组同样处理
|
||||
|
||||
如果把这些情况混在一起,就会出现:
|
||||
|
||||
- `load` 试图从数组类型直接取值
|
||||
- GEP 基类型和索引序列不匹配
|
||||
- 局部数组访问、多维数组访问、数组实参传递行为错误
|
||||
|
||||
#### 解决办法
|
||||
|
||||
做了三层拆分:
|
||||
|
||||
1. 标量与数组分离
|
||||
- 标量局部变量使用真正的标量槽位:`i32*` 或 `float*`
|
||||
- 数组局部变量保留聚合基址
|
||||
|
||||
2. 普通数组与数组形参分离
|
||||
- 局部/全局数组通过多级 GEP 沿数组维度寻址
|
||||
- 数组形参按“指针退化”处理,访问时根据剩余维度计算偏移
|
||||
|
||||
3. 左值取址与值求值分离
|
||||
- `GetLValuePtr` 只负责拿地址
|
||||
- `visitLValueExp` 根据左值是否仍是数组来决定是 `load` 还是数组退化传参
|
||||
|
||||
#### 效果
|
||||
|
||||
修复后:
|
||||
|
||||
- `simple_add.sy` 恢复正常
|
||||
- `05_arr_defn4.sy` 可以生成并运行
|
||||
- 多维数组和数组形参的寻址逻辑更加稳定
|
||||
|
||||
### 4.3 困难三:局部数组花括号初始化语义不正确
|
||||
|
||||
#### 现象
|
||||
|
||||
`05_arr_defn4.sy` 虽然在中期已经不再崩溃,但运行结果仍然错误,退出码从预期的 `21` 变成了 `13`。这说明不是寻址崩了,而是初始化布局错了。
|
||||
|
||||
问题根源在于:
|
||||
|
||||
- 一部分初始化按“子数组递进”处理
|
||||
- 一部分初始化又按“标量扁平展开”处理
|
||||
|
||||
两套逻辑混用后,多维数组初始化次序就会乱掉。
|
||||
|
||||
#### 解决办法
|
||||
|
||||
把局部数组初始化统一改成“聚合初始化 + 标量游标”方案:
|
||||
|
||||
- 先统一做零初始化
|
||||
- 再对花括号初始化维护一个标量游标
|
||||
- 标量初始化时按当前扁平偏移定位到实际元素
|
||||
- 子聚合初始化时按当前对齐边界进入对应子数组
|
||||
|
||||
这套逻辑与 SysY/LLVM 前端常见的聚合初始化处理方式更接近。
|
||||
|
||||
#### 效果
|
||||
|
||||
修复后 `05_arr_defn4.sy` 的 IR 可以通过 `verify_ir.sh --run`,输出与预期一致。
|
||||
|
||||
### 4.4 困难四:IR 文本虽然能打印,但 LLVM 后端不一定接受
|
||||
|
||||
#### 现象
|
||||
|
||||
在进入 `verify_ir.sh` 阶段后,又暴露出一批“IR 生成没崩,但 LLVM 不认”的问题:
|
||||
|
||||
- 内建函数被打印成了空定义,而不是声明
|
||||
- 浮点常量打印格式不符合 LLVM 期望
|
||||
- `icmp/fcmp` 的结果在打印和后续使用中对 `i1/i32` 处理不一致
|
||||
- 自动临时名使用纯数字,打乱后会违反 LLVM 的编号要求
|
||||
- 基本块名重复
|
||||
- `getelementptr` 打印时的基类型信息不正确
|
||||
|
||||
#### 解决办法
|
||||
|
||||
对 IR 基础设施做了系统修正:
|
||||
|
||||
- `Function` 不再默认创建入口块,只有真正定义函数时才建 `entry`
|
||||
- `IRPrinter` 对没有基本块的函数输出 `declare`
|
||||
- 自动临时名改成 `t0/t1/...`,避免 LLVM 对纯数字 SSA 名称的严格顺序约束
|
||||
- 比较结果按布尔值打印和消费
|
||||
- `if/while/and/or` 生成的块名追加唯一后缀
|
||||
- 修复 float 常量、GEP、Cast、Call 等打印格式
|
||||
|
||||
#### 效果
|
||||
|
||||
修复后:
|
||||
|
||||
- `simple_add.sy`
|
||||
- `13_sub2.sy`
|
||||
- `29_break.sy`
|
||||
- `36_op_priority2.sy`
|
||||
- `05_arr_defn4.sy`
|
||||
|
||||
都已经可以通过 `verify_ir.sh --run`。
|
||||
|
||||
### 4.5 困难五:`95_float.sy` 的最终运行验证仍受运行库缺失影响
|
||||
|
||||
#### 现象
|
||||
|
||||
在修完 IR 生成与打印问题后,`95_float.sy` 已经可以:
|
||||
|
||||
- 成功生成 IR
|
||||
- 通过 `llc` 生成目标文件
|
||||
|
||||
但在最终链接阶段仍会失败,原因不是 IR 错误,而是仓库中的 `sylib/sylib.c` 当前只是空壳,没有提供:
|
||||
|
||||
- `getfloat`
|
||||
- `putfloat`
|
||||
- `getfarray`
|
||||
- `putfarray`
|
||||
- `putch`
|
||||
- `putint`
|
||||
|
||||
等符号的真实实现。
|
||||
|
||||
#### 解决办法
|
||||
|
||||
本次提交中没有擅自扩展运行库,而是把问题明确定位为“Lab2 IR 生成正确,但运行时依赖未补齐”。这样可以把 Lab2 编译器部分与后续运行库实现清晰分开。
|
||||
|
||||
#### 影响
|
||||
|
||||
`95_float.sy` 当前的状态是:
|
||||
|
||||
- IR 生成正确
|
||||
- LLVM 后端接受
|
||||
- 最终运行依赖运行库补全
|
||||
|
||||
## 5. 本次实现的主要能力
|
||||
|
||||
本次实验结束后,编译器已经具备以下 Lab2 关键能力:
|
||||
|
||||
- 全局变量/常量 IR 生成
|
||||
- 局部变量 IR 生成
|
||||
- `int/float` 常量与表达式生成
|
||||
- 基本算术与比较运算
|
||||
- 类型转换:`sitofp`、`fptosi`、`zext`
|
||||
- `if-else`
|
||||
- `while`
|
||||
- `break/continue`
|
||||
- 函数定义与函数调用
|
||||
- 标量参数与数组参数
|
||||
- 多维数组寻址
|
||||
- 局部数组零初始化与花括号初始化
|
||||
- 短路求值
|
||||
- LLVM 可接受的 IR 文本打印
|
||||
|
||||
## 6. 验证结果
|
||||
|
||||
本次已完成的回归包括:
|
||||
|
||||
```bash
|
||||
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_simple --run
|
||||
./scripts/verify_ir.sh test/test_case/functional/13_sub2.sy /tmp/ir_sub2 --run
|
||||
./scripts/verify_ir.sh test/test_case/functional/29_break.sy /tmp/ir_break --run
|
||||
./scripts/verify_ir.sh test/test_case/functional/36_op_priority2.sy /tmp/ir_op --run
|
||||
./scripts/verify_ir.sh test/test_case/functional/05_arr_defn4.sy /tmp/ir_arr --run
|
||||
```
|
||||
|
||||
这些样例均已通过。
|
||||
|
||||
另外:
|
||||
|
||||
```bash
|
||||
./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy
|
||||
```
|
||||
|
||||
可以成功生成 IR,且 IR 能通过 `llc`,说明浮点常量、浮点表达式、浮点比较、类型转换与数组传参路径已经基本打通。
|
||||
|
||||
## 7. 本次实验中的经验总结
|
||||
|
||||
本次 Lab2 最核心的经验有三点:
|
||||
|
||||
1. 编译期常量求值和运行时 IR 生成必须严格分离。
|
||||
只要两条路径混在一起,全局初始化和常量表达式一定会出错。
|
||||
|
||||
2. 数组不能按“只是更大的标量”处理。
|
||||
数组对象、数组形参、数组元素地址、数组退化指针这几个概念必须明确区分。
|
||||
|
||||
3. “能打印 IR”不等于“LLVM 能接受 IR”。
|
||||
最后一定要走一遍 `llc/clang`,否则很多类型和格式问题会被掩盖。
|
||||
|
||||
## 8. 后续可继续完善的方向
|
||||
|
||||
虽然本次已经完成了 Lab2 的主体工作,但还可以继续完善:
|
||||
|
||||
- 为 `sylib` 补齐实际运行库实现,打通 `95_float` 等 I/O 样例的最终运行
|
||||
- 为全局数组初始化补完整的常量聚合表示,而不是目前以标量初始化为主
|
||||
- 进一步统一 IR 中布尔类型的内部表示,减少 `i1/i32` 的兼容分支
|
||||
- 继续批量回归 `test/test_case` 下更多样例,补齐剩余边界情况
|
||||
|
||||
## 9. 结论
|
||||
|
||||
本次 Lab2 已经从“完成约 90%,但被全局初始化与数组/短路问题卡住”的状态,推进到“核心 IR 生成链路可用、典型功能样例可运行验证”的状态。阻塞实验验收的主问题已经被定位并解决,代码结构也比原来更清晰,后续继续做运行库、优化与更大规模回归时会更稳。
|
||||
185
include/ir/IR.h
185
include/ir/IR.h
@@ -37,7 +37,6 @@
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <cstdint>
|
||||
|
||||
namespace ir {
|
||||
|
||||
@@ -46,7 +45,6 @@ class Value;
|
||||
class User;
|
||||
class ConstantValue;
|
||||
class ConstantInt;
|
||||
class ConstantFloat;
|
||||
class GlobalValue;
|
||||
class Instruction;
|
||||
class BasicBlock;
|
||||
@@ -85,20 +83,17 @@ class Context {
|
||||
~Context();
|
||||
// 去重创建 i32 常量。
|
||||
ConstantInt* GetConstInt(int v);
|
||||
// 去重创建 float 常量。
|
||||
ConstantFloat* GetConstFloat(float v);
|
||||
|
||||
std::string NextTemp();
|
||||
|
||||
private:
|
||||
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
|
||||
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
|
||||
int temp_index_ = -1;
|
||||
};
|
||||
|
||||
class Type : public std::enable_shared_from_this<Type> {
|
||||
class Type {
|
||||
public:
|
||||
enum class Kind { Void, Int32, PtrInt32, Float, PtrFloat, Label, Array };
|
||||
enum class Kind { Void, Int32, PtrInt32 };
|
||||
explicit Type(Kind k);
|
||||
// 使用静态共享对象获取类型。
|
||||
// 同一类型可直接比较返回值是否相等,例如:
|
||||
@@ -106,36 +101,15 @@ class Type : public std::enable_shared_from_this<Type> {
|
||||
static const std::shared_ptr<Type>& GetVoidType();
|
||||
static const std::shared_ptr<Type>& GetInt32Type();
|
||||
static const std::shared_ptr<Type>& GetPtrInt32Type();
|
||||
static const std::shared_ptr<Type>& GetFloatType();
|
||||
static const std::shared_ptr<Type>& GetPtrFloatType();
|
||||
static const std::shared_ptr<Type>& GetLabelType();
|
||||
Kind GetKind() const;
|
||||
bool IsVoid() const;
|
||||
bool IsInt32() const;
|
||||
bool IsPtrInt32() const;
|
||||
bool IsFloat() const;
|
||||
bool IsPtrFloat() const;
|
||||
bool IsLabel() const;
|
||||
bool IsArray() const;
|
||||
std::shared_ptr<class ArrayType> GetAsArrayType();
|
||||
|
||||
private:
|
||||
Kind kind_;
|
||||
};
|
||||
|
||||
class ArrayType : public Type {
|
||||
public:
|
||||
ArrayType(std::shared_ptr<Type> element_type, uint32_t num_elements);
|
||||
static std::shared_ptr<ArrayType> Get(std::shared_ptr<Type> element_type,
|
||||
uint32_t num_elements);
|
||||
std::shared_ptr<Type> GetElementType() const { return element_type_; }
|
||||
uint32_t GetNumElements() const { return num_elements_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<Type> element_type_;
|
||||
uint32_t num_elements_;
|
||||
};
|
||||
|
||||
class Value {
|
||||
public:
|
||||
Value(std::shared_ptr<Type> ty, std::string name);
|
||||
@@ -146,15 +120,10 @@ class Value {
|
||||
bool IsVoid() const;
|
||||
bool IsInt32() const;
|
||||
bool IsPtrInt32() const;
|
||||
bool IsFloat() const;
|
||||
bool IsPtrFloat() const;
|
||||
bool IsLabel() const;
|
||||
bool IsConstant() const;
|
||||
bool IsInstruction() const;
|
||||
bool IsUser() const;
|
||||
bool IsFunction() const;
|
||||
bool IsGlobalValue() const;
|
||||
bool IsArgument() const;
|
||||
void AddUse(User* user, size_t operand_index);
|
||||
void RemoveUse(User* user, size_t operand_index);
|
||||
const std::vector<Use>& GetUses() const;
|
||||
@@ -166,19 +135,6 @@ class Value {
|
||||
std::vector<Use> uses_;
|
||||
};
|
||||
|
||||
// Argument represents a function parameter.
|
||||
class Argument : public Value {
|
||||
public:
|
||||
Argument(std::shared_ptr<Type> ty, std::string name, Function* parent,
|
||||
unsigned arg_no);
|
||||
Function* GetParent() const { return parent_; }
|
||||
unsigned GetArgNo() const { return arg_no_; }
|
||||
|
||||
private:
|
||||
Function* parent_;
|
||||
unsigned arg_no_;
|
||||
};
|
||||
|
||||
// ConstantValue 是常量体系的基类。
|
||||
// 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。
|
||||
class ConstantValue : public Value {
|
||||
@@ -195,49 +151,8 @@ class ConstantInt : public ConstantValue {
|
||||
int value_{};
|
||||
};
|
||||
|
||||
class ConstantFloat : public ConstantValue {
|
||||
public:
|
||||
ConstantFloat(std::shared_ptr<Type> ty, float v);
|
||||
float GetValue() const { return value_; }
|
||||
|
||||
private:
|
||||
float value_{};
|
||||
};
|
||||
|
||||
// 后续还需要扩展更多指令类型。
|
||||
enum class Opcode {
|
||||
Add,
|
||||
Sub,
|
||||
Mul,
|
||||
Div,
|
||||
Mod,
|
||||
FAdd,
|
||||
FSub,
|
||||
FMul,
|
||||
FDiv,
|
||||
ICmpEQ,
|
||||
ICmpNE,
|
||||
ICmpLT,
|
||||
ICmpGT,
|
||||
ICmpLE,
|
||||
ICmpGE,
|
||||
FCmpEQ,
|
||||
FCmpNE,
|
||||
FCmpLT,
|
||||
FCmpGT,
|
||||
FCmpLE,
|
||||
FCmpGE,
|
||||
Alloca,
|
||||
Load,
|
||||
Store,
|
||||
Ret,
|
||||
Br,
|
||||
Call,
|
||||
GEP,
|
||||
ZExt,
|
||||
SIToFP,
|
||||
FPToSI
|
||||
};
|
||||
enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
|
||||
|
||||
// User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。
|
||||
// 当前实现中只有 Instruction 继承自 User。
|
||||
@@ -256,15 +171,11 @@ class User : public Value {
|
||||
std::vector<Value*> operands_;
|
||||
};
|
||||
|
||||
// GlobalValue 是全局值/全局变量体系的类。
|
||||
// GlobalValue 是全局值/全局变量体系的空壳占位类。
|
||||
// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。
|
||||
class GlobalValue : public User {
|
||||
public:
|
||||
GlobalValue(std::shared_ptr<Type> ty, std::string name, ConstantValue* init = nullptr);
|
||||
ConstantValue* GetInitializer() const { return init_; }
|
||||
void SetInitializer(ConstantValue* init) { init_ = init; }
|
||||
|
||||
private:
|
||||
ConstantValue* init_ = nullptr;
|
||||
GlobalValue(std::shared_ptr<Type> ty, std::string name);
|
||||
};
|
||||
|
||||
class Instruction : public User {
|
||||
@@ -285,40 +196,7 @@ class BinaryInst : public Instruction {
|
||||
BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
|
||||
std::string name);
|
||||
Value* GetLhs() const;
|
||||
Value* GetRhs() const;
|
||||
};
|
||||
|
||||
class BranchInst : public Instruction {
|
||||
public:
|
||||
// Unconditional branch
|
||||
explicit BranchInst(BasicBlock* dest);
|
||||
// Conditional branch
|
||||
BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false);
|
||||
|
||||
bool IsConditional() const;
|
||||
Value* GetCondition() const;
|
||||
BasicBlock* GetIfTrue() const;
|
||||
BasicBlock* GetIfFalse() const;
|
||||
BasicBlock* GetDest() const;
|
||||
};
|
||||
|
||||
class CallInst : public Instruction {
|
||||
public:
|
||||
CallInst(Function* func, const std::vector<Value*>& args, std::string name = "");
|
||||
Function* GetFunction() const;
|
||||
};
|
||||
|
||||
class GetElementPtrInst : public Instruction {
|
||||
public:
|
||||
GetElementPtrInst(std::shared_ptr<Type> ptr_ty, Value* ptr,
|
||||
const std::vector<Value*>& indices, std::string name = "");
|
||||
Value* GetPtr() const;
|
||||
};
|
||||
|
||||
class CastInst : public Instruction {
|
||||
public:
|
||||
CastInst(Opcode op, std::shared_ptr<Type> ty, Value* val, std::string name = "");
|
||||
Value* GetValue() const;
|
||||
Value* GetRhs() const;
|
||||
};
|
||||
|
||||
class ReturnInst : public Instruction {
|
||||
@@ -377,41 +255,38 @@ class BasicBlock : public Value {
|
||||
};
|
||||
|
||||
// Function 当前也采用了最小实现。
|
||||
// 需要特别注意:由于项目里还没有单独的 FunctionType,
|
||||
// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”,
|
||||
// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。
|
||||
// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、
|
||||
// 形参和调用,通常需要引入专门的函数类型表示。
|
||||
class Function : public Value {
|
||||
public:
|
||||
Function(std::string name, std::shared_ptr<Type> ret_type,
|
||||
std::vector<std::shared_ptr<Type>> param_types);
|
||||
// 当前构造函数接收的也是返回类型,而不是完整函数类型。
|
||||
Function(std::string name, std::shared_ptr<Type> ret_type);
|
||||
BasicBlock* CreateBlock(const std::string& name);
|
||||
BasicBlock* GetEntry();
|
||||
const BasicBlock* GetEntry() const;
|
||||
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
|
||||
const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
|
||||
|
||||
private:
|
||||
BasicBlock* entry_ = nullptr;
|
||||
std::vector<std::unique_ptr<BasicBlock>> blocks_;
|
||||
std::vector<std::unique_ptr<Argument>> arguments_;
|
||||
};
|
||||
|
||||
|
||||
class Module {
|
||||
public:
|
||||
Module() = default;
|
||||
Context& GetContext();
|
||||
const Context& GetContext() const;
|
||||
// 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。
|
||||
Function* CreateFunction(const std::string& name,
|
||||
std::shared_ptr<Type> ret_type,
|
||||
std::vector<std::shared_ptr<Type>> param_types = {});
|
||||
std::shared_ptr<Type> ret_type);
|
||||
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
|
||||
GlobalValue* CreateGlobalValue(const std::string& name,
|
||||
std::shared_ptr<Type> ty,
|
||||
ConstantValue* init = nullptr);
|
||||
const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalValues() const;
|
||||
|
||||
private:
|
||||
Context context_;
|
||||
std::vector<std::unique_ptr<Function>> functions_;
|
||||
std::vector<std::unique_ptr<GlobalValue>> global_values_;
|
||||
};
|
||||
|
||||
class IRBuilder {
|
||||
@@ -422,41 +297,13 @@ class IRBuilder {
|
||||
|
||||
// 构造常量、二元运算、返回指令的最小集合。
|
||||
ConstantInt* CreateConstInt(int v);
|
||||
ConstantFloat* CreateConstFloat(float v);
|
||||
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name);
|
||||
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
|
||||
BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name);
|
||||
BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name);
|
||||
AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
|
||||
AllocaInst* CreateAllocaI32(const std::string& name);
|
||||
LoadInst* CreateLoad(Value* ptr, const std::string& name);
|
||||
StoreInst* CreateStore(Value* val, Value* ptr);
|
||||
ReturnInst* CreateRet(Value* v);
|
||||
BranchInst* CreateBr(BasicBlock* dest);
|
||||
BranchInst* CreateCondBr(Value* cond, BasicBlock* if_true,
|
||||
BasicBlock* if_false);
|
||||
CallInst* CreateCall(Function* func, const std::vector<Value*>& args,
|
||||
const std::string& name = "");
|
||||
GetElementPtrInst* CreateGEP(std::shared_ptr<Type> ptr_ty, Value* ptr,
|
||||
const std::vector<Value*>& indices,
|
||||
const std::string& name = "");
|
||||
CastInst* CreateZExt(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name = "");
|
||||
CastInst* CreateSIToFP(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name = "");
|
||||
CastInst* CreateFPToSI(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name = "");
|
||||
|
||||
private:
|
||||
Context& ctx_;
|
||||
|
||||
@@ -5,10 +5,8 @@
|
||||
|
||||
#include <any>
|
||||
#include <memory>
|
||||
#include <stack>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "SysYBaseVisitor.h"
|
||||
#include "SysYParser.h"
|
||||
@@ -20,56 +18,24 @@ class Module;
|
||||
class Function;
|
||||
class IRBuilder;
|
||||
class Value;
|
||||
class BasicBlock;
|
||||
}
|
||||
|
||||
class IRGenImpl final : public SysYBaseVisitor {
|
||||
public:
|
||||
IRGenImpl(ir::Module& module, const SemanticContext& sema);
|
||||
|
||||
// Top-level rules
|
||||
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
|
||||
std::any visitDecl(SysYParser::DeclContext* ctx) override;
|
||||
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
|
||||
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
|
||||
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
|
||||
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
|
||||
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
|
||||
std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override;
|
||||
|
||||
// Statement rules
|
||||
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
|
||||
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
|
||||
std::any visitDecl(SysYParser::DeclContext* ctx) override;
|
||||
std::any visitStmt(SysYParser::StmtContext* ctx) override;
|
||||
std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override;
|
||||
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
|
||||
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
|
||||
std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override;
|
||||
std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override;
|
||||
std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override;
|
||||
std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override;
|
||||
std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override;
|
||||
|
||||
// Expression rules
|
||||
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
|
||||
std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override;
|
||||
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
|
||||
std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override;
|
||||
std::any visitNotExp(SysYParser::NotExpContext* ctx) override;
|
||||
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override;
|
||||
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override;
|
||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
|
||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override;
|
||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override;
|
||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
|
||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override;
|
||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override;
|
||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override;
|
||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override;
|
||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override;
|
||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
|
||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override;
|
||||
std::any visitAndExp(SysYParser::AndExpContext* ctx) override;
|
||||
std::any visitOrExp(SysYParser::OrExpContext* ctx) override;
|
||||
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
|
||||
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
|
||||
|
||||
private:
|
||||
enum class BlockFlow {
|
||||
@@ -77,35 +43,15 @@ class IRGenImpl final : public SysYBaseVisitor {
|
||||
Terminated,
|
||||
};
|
||||
|
||||
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
|
||||
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
|
||||
ir::ConstantValue* EvalConstExpr(SysYParser::ExpContext& expr);
|
||||
ir::Value* GetLValuePtr(SysYParser::LValueContext* ctx);
|
||||
ir::Value* DecayArrayPtr(SysYParser::LValueContext* ctx);
|
||||
bool IsArrayLikeDef(antlr4::ParserRuleContext* def) const;
|
||||
size_t GetArrayRank(antlr4::ParserRuleContext* def) const;
|
||||
std::shared_ptr<ir::Type> GetDefType(antlr4::ParserRuleContext* def) const;
|
||||
void ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr<ir::Type> ty);
|
||||
void EmitLocalInitValue(ir::Value* ptr, std::shared_ptr<ir::Type> ty,
|
||||
SysYParser::InitValueContext* init);
|
||||
|
||||
ir::Module& module_;
|
||||
const SemanticContext& sema_;
|
||||
ir::Function* func_;
|
||||
ir::IRBuilder builder_;
|
||||
|
||||
// Maps a definition (VarDef, ConstDef, FuncFParam) to its IR value (Alloca or GlobalValue)
|
||||
std::unordered_map<antlr4::ParserRuleContext*, ir::Value*> storage_map_;
|
||||
|
||||
// For global scope tracking
|
||||
bool is_global_scope_ = true;
|
||||
|
||||
// For loop control
|
||||
std::stack<ir::BasicBlock*> break_stack_;
|
||||
std::stack<ir::BasicBlock*> continue_stack_;
|
||||
|
||||
// Helper to handle short-circuiting and comparison results
|
||||
ir::Value* ToI1(ir::Value* v);
|
||||
ir::Value* ToI32(ir::Value* v);
|
||||
// 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
|
||||
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
|
||||
};
|
||||
|
||||
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
|
||||
|
||||
@@ -1,40 +1,30 @@
|
||||
// 基于语法树的语义检查与名称绑定。
|
||||
#pragma once
|
||||
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "SysYParser.h"
|
||||
|
||||
class SemanticContext {
|
||||
public:
|
||||
void BindLValue(SysYParser::LValueContext* use,
|
||||
antlr4::ParserRuleContext* def) {
|
||||
lvalue_defs_[use] = def;
|
||||
void BindVarUse(SysYParser::VarContext* use,
|
||||
SysYParser::VarDefContext* decl) {
|
||||
var_uses_[use] = decl;
|
||||
}
|
||||
|
||||
void BindFuncCall(SysYParser::FuncCallExpContext* use,
|
||||
SysYParser::FuncDefContext* def) {
|
||||
funccall_defs_[use] = def;
|
||||
}
|
||||
|
||||
antlr4::ParserRuleContext* ResolveLValue(
|
||||
const SysYParser::LValueContext* use) const {
|
||||
auto it = lvalue_defs_.find(const_cast<SysYParser::LValueContext*>(use));
|
||||
return it == lvalue_defs_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
SysYParser::FuncDefContext* ResolveFuncCall(
|
||||
const SysYParser::FuncCallExpContext* use) const {
|
||||
auto it = funccall_defs_.find(const_cast<SysYParser::FuncCallExpContext*>(use));
|
||||
return it == funccall_defs_.end() ? nullptr : it->second;
|
||||
SysYParser::VarDefContext* ResolveVarUse(
|
||||
const SysYParser::VarContext* use) const {
|
||||
auto it = var_uses_.find(use);
|
||||
return it == var_uses_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<SysYParser::LValueContext*, antlr4::ParserRuleContext*>
|
||||
lvalue_defs_;
|
||||
std::unordered_map<SysYParser::FuncCallExpContext*,
|
||||
SysYParser::FuncDefContext*>
|
||||
funccall_defs_;
|
||||
std::unordered_map<const SysYParser::VarContext*,
|
||||
SysYParser::VarDefContext*>
|
||||
var_uses_;
|
||||
};
|
||||
|
||||
// 目前仅检查:
|
||||
// - 变量先声明后使用
|
||||
// - 局部变量不允许重复定义
|
||||
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);
|
||||
|
||||
@@ -1,30 +1,17 @@
|
||||
// 极简符号表:记录局部变量定义点。
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "SysYParser.h"
|
||||
|
||||
struct Symbol {
|
||||
enum class Kind { Variable, Constant, Function, Parameter };
|
||||
Kind kind;
|
||||
antlr4::ParserRuleContext* def_ctx;
|
||||
bool is_const = false;
|
||||
bool is_array = false;
|
||||
// For functions, we can store pointers to their parameter types or just the
|
||||
// FuncDefContext*
|
||||
};
|
||||
|
||||
class SymbolTable {
|
||||
public:
|
||||
SymbolTable();
|
||||
void PushScope();
|
||||
void PopScope();
|
||||
bool Add(const std::string& name, const Symbol& symbol);
|
||||
Symbol* Lookup(const std::string& name);
|
||||
bool IsInCurrentScope(const std::string& name) const;
|
||||
void Add(const std::string& name, SysYParser::VarDefContext* decl);
|
||||
bool Contains(const std::string& name) const;
|
||||
SysYParser::VarDefContext* Lookup(const std::string& name) const;
|
||||
|
||||
private:
|
||||
std::vector<std::unordered_map<std::string, Symbol>> scopes_;
|
||||
std::unordered_map<std::string, SysYParser::VarDefContext*> table_;
|
||||
};
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace ir {
|
||||
|
||||
// 当前 BasicBlock 还没有专门的 label type,因此先用 void 作为占位类型。
|
||||
BasicBlock::BasicBlock(std::string name)
|
||||
: Value(Type::GetLabelType(), std::move(name)) {}
|
||||
: Value(Type::GetVoidType(), std::move(name)) {}
|
||||
|
||||
Function* BasicBlock::GetParent() const { return parent_; }
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ add_library(ir_core STATIC
|
||||
Module.cpp
|
||||
Function.cpp
|
||||
BasicBlock.cpp
|
||||
GlobalValue.cpp
|
||||
Type.cpp
|
||||
Value.cpp
|
||||
Instruction.cpp
|
||||
|
||||
@@ -15,18 +15,10 @@ ConstantInt* Context::GetConstInt(int v) {
|
||||
return inserted->second.get();
|
||||
}
|
||||
|
||||
ConstantFloat* Context::GetConstFloat(float v) {
|
||||
auto it = const_floats_.find(v);
|
||||
if (it != const_floats_.end()) return it->second.get();
|
||||
auto inserted =
|
||||
const_floats_
|
||||
.emplace(v, std::make_unique<ConstantFloat>(Type::GetFloatType(), v))
|
||||
.first;
|
||||
return inserted->second.get();
|
||||
}
|
||||
|
||||
std::string Context::NextTemp() {
|
||||
return "t" + std::to_string(++temp_index_);
|
||||
std::ostringstream oss;
|
||||
oss << "%" << ++temp_index_;
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -5,14 +5,9 @@
|
||||
|
||||
namespace ir {
|
||||
|
||||
Function::Function(std::string name, std::shared_ptr<Type> ret_type,
|
||||
std::vector<std::shared_ptr<Type>> param_types)
|
||||
Function::Function(std::string name, std::shared_ptr<Type> ret_type)
|
||||
: Value(std::move(ret_type), std::move(name)) {
|
||||
for (size_t i = 0; i < param_types.size(); ++i) {
|
||||
arguments_.push_back(std::make_unique<Argument>(
|
||||
param_types[i], "a" + std::to_string(i), this,
|
||||
static_cast<unsigned>(i)));
|
||||
}
|
||||
entry_ = CreateBlock("entry");
|
||||
}
|
||||
|
||||
BasicBlock* Function::CreateBlock(const std::string& name) {
|
||||
@@ -34,8 +29,4 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
|
||||
return blocks_;
|
||||
}
|
||||
|
||||
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
|
||||
return arguments_;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
namespace ir {
|
||||
|
||||
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name, ConstantValue* init)
|
||||
: User(std::move(ty), std::move(name)), init_(init) {}
|
||||
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
|
||||
: User(std::move(ty), std::move(name)) {}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -21,11 +21,6 @@ ConstantInt* IRBuilder::CreateConstInt(int v) {
|
||||
return ctx_.GetConstInt(v);
|
||||
}
|
||||
|
||||
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
|
||||
// 常量不需要挂在基本块里,由 Context 负责去重与生命周期。
|
||||
return ctx_.GetConstFloat(v);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
@@ -47,74 +42,11 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
|
||||
return CreateBinary(Opcode::Add, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::Sub, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::Mul, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::Div, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::Mod, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::FAdd, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::FSub, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::FMul, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
return CreateBinary(Opcode::FDiv, lhs, rhs, name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<BinaryInst>(op, Type::GetInt32Type(), lhs, rhs,
|
||||
name);
|
||||
}
|
||||
|
||||
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<BinaryInst>(op, Type::GetInt32Type(), lhs, rhs,
|
||||
name);
|
||||
}
|
||||
|
||||
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<AllocaInst>(ty, name);
|
||||
}
|
||||
|
||||
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
|
||||
return CreateAlloca(Type::GetPtrInt32Type(), name);
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name);
|
||||
}
|
||||
|
||||
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
|
||||
@@ -125,15 +57,7 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
|
||||
}
|
||||
std::shared_ptr<Type> val_ty;
|
||||
if (ptr->GetType()->IsPtrInt32()) {
|
||||
val_ty = Type::GetInt32Type();
|
||||
} else if (ptr->GetType()->IsPtrFloat()) {
|
||||
val_ty = Type::GetFloatType();
|
||||
} else {
|
||||
throw std::runtime_error(FormatError("ir", "LoadInst 不支持的指针类型"));
|
||||
}
|
||||
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
|
||||
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name);
|
||||
}
|
||||
|
||||
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
|
||||
@@ -155,63 +79,11 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
if (!v) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
|
||||
}
|
||||
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
|
||||
}
|
||||
|
||||
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<BranchInst>(dest);
|
||||
}
|
||||
|
||||
BranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* if_true,
|
||||
BasicBlock* if_false) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<BranchInst>(cond, if_true, if_false);
|
||||
}
|
||||
|
||||
CallInst* IRBuilder::CreateCall(Function* func, const std::vector<Value*>& args,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<CallInst>(func, args, name);
|
||||
}
|
||||
|
||||
GetElementPtrInst* IRBuilder::CreateGEP(std::shared_ptr<Type> ptr_ty, Value* ptr,
|
||||
const std::vector<Value*>& indices,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<GetElementPtrInst>(ptr_ty, ptr, indices, name);
|
||||
}
|
||||
|
||||
CastInst* IRBuilder::CreateZExt(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<CastInst>(Opcode::ZExt, ty, val, name);
|
||||
}
|
||||
|
||||
CastInst* IRBuilder::CreateSIToFP(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<CastInst>(Opcode::SIToFP, ty, val, name);
|
||||
}
|
||||
|
||||
CastInst* IRBuilder::CreateFPToSI(Value* val, std::shared_ptr<Type> ty,
|
||||
const std::string& name) {
|
||||
if (!insert_block_) {
|
||||
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
|
||||
}
|
||||
return insert_block_->Append<CastInst>(Opcode::FPToSI, ty, val, name);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -4,11 +4,7 @@
|
||||
|
||||
#include "ir/IR.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <iomanip>
|
||||
#include <sstream>
|
||||
#include <ostream>
|
||||
#include <cstring>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
|
||||
@@ -16,7 +12,7 @@
|
||||
|
||||
namespace ir {
|
||||
|
||||
static std::string TypeToString(const Type& ty) {
|
||||
static const char* TypeToString(const Type& ty) {
|
||||
switch (ty.GetKind()) {
|
||||
case Type::Kind::Void:
|
||||
return "void";
|
||||
@@ -24,22 +20,11 @@ static std::string TypeToString(const Type& ty) {
|
||||
return "i32";
|
||||
case Type::Kind::PtrInt32:
|
||||
return "i32*";
|
||||
case Type::Kind::Float:
|
||||
return "float";
|
||||
case Type::Kind::PtrFloat:
|
||||
return "float*";
|
||||
case Type::Kind::Label:
|
||||
return "label";
|
||||
case Type::Kind::Array: {
|
||||
const auto* arr_ty = static_cast<const ArrayType*>(&ty);
|
||||
return "[" + std::to_string(arr_ty->GetNumElements()) + " x " +
|
||||
TypeToString(*arr_ty->GetElementType()) + "]";
|
||||
}
|
||||
}
|
||||
return "unknown";
|
||||
throw std::runtime_error(FormatError("ir", "未知类型"));
|
||||
}
|
||||
|
||||
static std::string OpcodeToString(Opcode op) {
|
||||
static const char* OpcodeToString(Opcode op) {
|
||||
switch (op) {
|
||||
case Opcode::Add:
|
||||
return "add";
|
||||
@@ -47,42 +32,6 @@ static std::string OpcodeToString(Opcode op) {
|
||||
return "sub";
|
||||
case Opcode::Mul:
|
||||
return "mul";
|
||||
case Opcode::Div:
|
||||
return "sdiv";
|
||||
case Opcode::Mod:
|
||||
return "srem";
|
||||
case Opcode::FAdd:
|
||||
return "fadd";
|
||||
case Opcode::FSub:
|
||||
return "fsub";
|
||||
case Opcode::FMul:
|
||||
return "fmul";
|
||||
case Opcode::FDiv:
|
||||
return "fdiv";
|
||||
case Opcode::ICmpEQ:
|
||||
return "icmp eq";
|
||||
case Opcode::ICmpNE:
|
||||
return "icmp ne";
|
||||
case Opcode::ICmpLT:
|
||||
return "icmp slt";
|
||||
case Opcode::ICmpGT:
|
||||
return "icmp sgt";
|
||||
case Opcode::ICmpLE:
|
||||
return "icmp sle";
|
||||
case Opcode::ICmpGE:
|
||||
return "icmp sge";
|
||||
case Opcode::FCmpEQ:
|
||||
return "fcmp oeq";
|
||||
case Opcode::FCmpNE:
|
||||
return "fcmp une";
|
||||
case Opcode::FCmpLT:
|
||||
return "fcmp olt";
|
||||
case Opcode::FCmpGT:
|
||||
return "fcmp ogt";
|
||||
case Opcode::FCmpLE:
|
||||
return "fcmp ole";
|
||||
case Opcode::FCmpGE:
|
||||
return "fcmp oge";
|
||||
case Opcode::Alloca:
|
||||
return "alloca";
|
||||
case Opcode::Load:
|
||||
@@ -91,114 +40,21 @@ static std::string OpcodeToString(Opcode op) {
|
||||
return "store";
|
||||
case Opcode::Ret:
|
||||
return "ret";
|
||||
case Opcode::Br:
|
||||
return "br";
|
||||
case Opcode::Call:
|
||||
return "call";
|
||||
case Opcode::GEP:
|
||||
return "getelementptr";
|
||||
case Opcode::ZExt:
|
||||
return "zext";
|
||||
case Opcode::SIToFP:
|
||||
return "sitofp";
|
||||
case Opcode::FPToSI:
|
||||
return "fptosi";
|
||||
}
|
||||
return "?";
|
||||
}
|
||||
|
||||
static std::string ValueToString(const Value* v) {
|
||||
if (!v) return "<null>";
|
||||
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
|
||||
return std::to_string(ci->GetValue());
|
||||
}
|
||||
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
|
||||
const double as_double = static_cast<double>(cf->GetValue());
|
||||
uint64_t bits = 0;
|
||||
std::memcpy(&bits, &as_double, sizeof(bits));
|
||||
std::ostringstream oss;
|
||||
oss << "0x" << std::hex << std::uppercase << std::setw(16)
|
||||
<< std::setfill('0') << bits;
|
||||
return oss.str();
|
||||
}
|
||||
if (v->IsGlobalValue() || v->IsFunction()) {
|
||||
return "@" + v->GetName();
|
||||
}
|
||||
if (v->IsInstruction() || v->IsArgument() || v->GetType()->IsLabel()) {
|
||||
return "%" + v->GetName();
|
||||
}
|
||||
return v->GetName();
|
||||
}
|
||||
|
||||
static bool IsBoolLikeValue(const Value* v) {
|
||||
if (auto* inst = dynamic_cast<const Instruction*>(v)) {
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE:
|
||||
return true;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
static std::string PrintedValueType(const Value* v) {
|
||||
if (IsBoolLikeValue(v)) return "i1";
|
||||
return TypeToString(*v->GetType());
|
||||
return v ? v->GetName() : "<null>";
|
||||
}
|
||||
|
||||
void IRPrinter::Print(const Module& module, std::ostream& os) {
|
||||
// Print global variables
|
||||
for (const auto& gv : module.GetGlobalValues()) {
|
||||
os << "@" << gv->GetName() << " = global ";
|
||||
if (gv->GetType()->IsPtrInt32()) {
|
||||
os << "i32";
|
||||
} else if (gv->GetType()->IsPtrFloat()) {
|
||||
os << "float";
|
||||
} else {
|
||||
os << TypeToString(*gv->GetType());
|
||||
}
|
||||
if (gv->GetInitializer()) {
|
||||
os << " " << ValueToString(gv->GetInitializer());
|
||||
} else {
|
||||
os << " zeroinitializer";
|
||||
}
|
||||
os << "\n";
|
||||
}
|
||||
if (!module.GetGlobalValues().empty()) os << "\n";
|
||||
|
||||
for (const auto& func : module.GetFunctions()) {
|
||||
if (func->GetBlocks().empty()) {
|
||||
os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName()
|
||||
<< "(";
|
||||
// For declarations, we just need types. But Argument objects might exist.
|
||||
const auto& args = func->GetArguments();
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
os << TypeToString(*args[i]->GetType());
|
||||
if (i + 1 < args.size()) os << ", ";
|
||||
}
|
||||
os << ")\n\n";
|
||||
continue;
|
||||
}
|
||||
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
|
||||
<< "(";
|
||||
const auto& args = func->GetArguments();
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
os << TypeToString(*args[i]->GetType()) << " %" << args[i]->GetName();
|
||||
if (i + 1 < args.size()) os << ", ";
|
||||
}
|
||||
os << ") {\n";
|
||||
<< "() {\n";
|
||||
for (const auto& bb : func->GetBlocks()) {
|
||||
if (!bb) {
|
||||
continue;
|
||||
@@ -209,142 +65,36 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
|
||||
switch (inst->GetOpcode()) {
|
||||
case Opcode::Add:
|
||||
case Opcode::Sub:
|
||||
case Opcode::Mul:
|
||||
case Opcode::Div:
|
||||
case Opcode::Mod:
|
||||
case Opcode::FAdd:
|
||||
case Opcode::FSub:
|
||||
case Opcode::FMul:
|
||||
case Opcode::FDiv: {
|
||||
case Opcode::Mul: {
|
||||
auto* bin = static_cast<const BinaryInst*>(inst);
|
||||
os << " %" << bin->GetName() << " = "
|
||||
os << " " << bin->GetName() << " = "
|
||||
<< OpcodeToString(bin->GetOpcode()) << " "
|
||||
<< PrintedValueType(bin->GetLhs()) << " "
|
||||
<< ValueToString(bin->GetLhs()) << ", "
|
||||
<< ValueToString(bin->GetRhs()) << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::ICmpEQ:
|
||||
case Opcode::ICmpNE:
|
||||
case Opcode::ICmpLT:
|
||||
case Opcode::ICmpGT:
|
||||
case Opcode::ICmpLE:
|
||||
case Opcode::ICmpGE:
|
||||
case Opcode::FCmpEQ:
|
||||
case Opcode::FCmpNE:
|
||||
case Opcode::FCmpLT:
|
||||
case Opcode::FCmpGT:
|
||||
case Opcode::FCmpLE:
|
||||
case Opcode::FCmpGE: {
|
||||
auto* bin = static_cast<const BinaryInst*>(inst);
|
||||
os << " %" << bin->GetName() << " = "
|
||||
<< OpcodeToString(bin->GetOpcode()) << " "
|
||||
<< PrintedValueType(bin->GetLhs()) << " "
|
||||
<< TypeToString(*bin->GetLhs()->GetType()) << " "
|
||||
<< ValueToString(bin->GetLhs()) << ", "
|
||||
<< ValueToString(bin->GetRhs()) << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::Alloca: {
|
||||
auto* alloca = static_cast<const AllocaInst*>(inst);
|
||||
os << " %" << alloca->GetName() << " = alloca ";
|
||||
if (alloca->GetType()->IsPtrInt32())
|
||||
os << "i32";
|
||||
else if (alloca->GetType()->IsPtrFloat())
|
||||
os << "float";
|
||||
else
|
||||
os << TypeToString(*alloca->GetType());
|
||||
os << "\n";
|
||||
os << " " << alloca->GetName() << " = alloca i32\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::Load: {
|
||||
auto* load = static_cast<const LoadInst*>(inst);
|
||||
os << " %" << load->GetName() << " = load "
|
||||
<< TypeToString(*load->GetType()) << ", "
|
||||
<< TypeToString(*load->GetPtr()->GetType()) << " "
|
||||
os << " " << load->GetName() << " = load i32, i32* "
|
||||
<< ValueToString(load->GetPtr()) << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::Store: {
|
||||
auto* store = static_cast<const StoreInst*>(inst);
|
||||
os << " store " << TypeToString(*store->GetValue()->GetType())
|
||||
<< " " << ValueToString(store->GetValue()) << ", "
|
||||
<< TypeToString(*store->GetPtr()->GetType()) << " "
|
||||
<< ValueToString(store->GetPtr()) << "\n";
|
||||
os << " store i32 " << ValueToString(store->GetValue())
|
||||
<< ", i32* " << ValueToString(store->GetPtr()) << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::Ret: {
|
||||
auto* ret = static_cast<const ReturnInst*>(inst);
|
||||
os << " ret ";
|
||||
if (ret->GetValue()) {
|
||||
os << TypeToString(*ret->GetValue()->GetType()) << " "
|
||||
<< ValueToString(ret->GetValue());
|
||||
} else {
|
||||
os << "void";
|
||||
}
|
||||
os << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::Br: {
|
||||
auto* br = static_cast<const BranchInst*>(inst);
|
||||
if (br->IsConditional()) {
|
||||
os << " br i1 " << ValueToString(br->GetCondition())
|
||||
<< ", label " << ValueToString(br->GetIfTrue()) << ", label "
|
||||
<< ValueToString(br->GetIfFalse()) << "\n";
|
||||
} else {
|
||||
os << " br label " << ValueToString(br->GetDest()) << "\n";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Opcode::Call: {
|
||||
auto* call = static_cast<const CallInst*>(inst);
|
||||
auto* func = call->GetFunction();
|
||||
if (!call->GetType()->IsVoid()) {
|
||||
os << " %" << call->GetName() << " = ";
|
||||
} else {
|
||||
os << " ";
|
||||
}
|
||||
os << "call " << TypeToString(*call->GetType()) << " "
|
||||
<< ValueToString(func) << "(";
|
||||
for (size_t i = 1; i < call->GetNumOperands(); ++i) {
|
||||
auto* arg = call->GetOperand(i);
|
||||
os << PrintedValueType(arg) << " " << ValueToString(arg);
|
||||
if (i + 1 < call->GetNumOperands()) os << ", ";
|
||||
}
|
||||
os << ")\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::GEP: {
|
||||
auto* gep = static_cast<const GetElementPtrInst*>(inst);
|
||||
os << " %" << gep->GetName() << " = getelementptr ";
|
||||
if (gep->GetPtr()->GetType()->IsPtrInt32())
|
||||
os << "i32";
|
||||
else if (gep->GetPtr()->GetType()->IsPtrFloat())
|
||||
os << "float";
|
||||
else
|
||||
os << TypeToString(*gep->GetPtr()->GetType());
|
||||
os << ", ";
|
||||
if (gep->GetPtr()->GetType()->IsArray()) {
|
||||
os << TypeToString(*gep->GetPtr()->GetType()) << "* ";
|
||||
} else {
|
||||
os << TypeToString(*gep->GetPtr()->GetType()) << " ";
|
||||
}
|
||||
os << ValueToString(gep->GetPtr());
|
||||
for (size_t i = 1; i < gep->GetNumOperands(); ++i) {
|
||||
os << ", " << TypeToString(*gep->GetOperand(i)->GetType()) << " "
|
||||
<< ValueToString(gep->GetOperand(i));
|
||||
}
|
||||
os << "\n";
|
||||
break;
|
||||
}
|
||||
case Opcode::ZExt:
|
||||
case Opcode::SIToFP:
|
||||
case Opcode::FPToSI: {
|
||||
auto* cast = static_cast<const CastInst*>(inst);
|
||||
os << " %" << cast->GetName() << " = "
|
||||
<< OpcodeToString(cast->GetOpcode()) << " "
|
||||
<< PrintedValueType(cast->GetValue()) << " "
|
||||
<< ValueToString(cast->GetValue()) << " to "
|
||||
<< TypeToString(*cast->GetType()) << "\n";
|
||||
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " "
|
||||
<< ValueToString(ret->GetValue()) << "\n";
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -52,9 +52,7 @@ Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
|
||||
|
||||
Opcode Instruction::GetOpcode() const { return opcode_; }
|
||||
|
||||
bool Instruction::IsTerminator() const {
|
||||
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br;
|
||||
}
|
||||
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; }
|
||||
|
||||
BasicBlock* Instruction::GetParent() const { return parent_; }
|
||||
|
||||
@@ -63,9 +61,22 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
|
||||
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
|
||||
Value* rhs, std::string name)
|
||||
: Instruction(op, std::move(ty), std::move(name)) {
|
||||
if (op != Opcode::Add) {
|
||||
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
|
||||
}
|
||||
if (!lhs || !rhs) {
|
||||
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
|
||||
}
|
||||
if (!type_ || !lhs->GetType() || !rhs->GetType()) {
|
||||
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息"));
|
||||
}
|
||||
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() ||
|
||||
type_->GetKind() != lhs->GetType()->GetKind()) {
|
||||
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
|
||||
}
|
||||
if (!type_->IsInt32()) {
|
||||
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
|
||||
}
|
||||
AddOperand(lhs);
|
||||
AddOperand(rhs);
|
||||
}
|
||||
@@ -74,85 +85,38 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
|
||||
|
||||
Value* BinaryInst::GetRhs() const { return GetOperand(1); }
|
||||
|
||||
BranchInst::BranchInst(BasicBlock* dest)
|
||||
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
|
||||
AddOperand(dest);
|
||||
}
|
||||
|
||||
BranchInst::BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false)
|
||||
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
|
||||
AddOperand(cond);
|
||||
AddOperand(if_true);
|
||||
AddOperand(if_false);
|
||||
}
|
||||
|
||||
bool BranchInst::IsConditional() const { return GetNumOperands() == 3; }
|
||||
|
||||
Value* BranchInst::GetCondition() const {
|
||||
return IsConditional() ? GetOperand(0) : nullptr;
|
||||
}
|
||||
|
||||
BasicBlock* BranchInst::GetIfTrue() const {
|
||||
return IsConditional() ? static_cast<BasicBlock*>(GetOperand(1)) : nullptr;
|
||||
}
|
||||
|
||||
BasicBlock* BranchInst::GetIfFalse() const {
|
||||
return IsConditional() ? static_cast<BasicBlock*>(GetOperand(2)) : nullptr;
|
||||
}
|
||||
|
||||
BasicBlock* BranchInst::GetDest() const {
|
||||
return !IsConditional() ? static_cast<BasicBlock*>(GetOperand(0)) : nullptr;
|
||||
}
|
||||
|
||||
CallInst::CallInst(Function* func, const std::vector<Value*>& args,
|
||||
std::string name)
|
||||
: Instruction(Opcode::Call, func->GetType(), std::move(name)) {
|
||||
AddOperand(func);
|
||||
for (auto* arg : args) {
|
||||
AddOperand(arg);
|
||||
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
|
||||
: Instruction(Opcode::Ret, std::move(void_ty), "") {
|
||||
if (!val) {
|
||||
throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值"));
|
||||
}
|
||||
}
|
||||
|
||||
Function* CallInst::GetFunction() const {
|
||||
return static_cast<Function*>(GetOperand(0));
|
||||
}
|
||||
|
||||
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> ptr_ty, Value* ptr,
|
||||
const std::vector<Value*>& indices,
|
||||
std::string name)
|
||||
: Instruction(Opcode::GEP, std::move(ptr_ty), std::move(name)) {
|
||||
AddOperand(ptr);
|
||||
for (auto* idx : indices) {
|
||||
AddOperand(idx);
|
||||
if (!type_ || !type_->IsVoid()) {
|
||||
throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void"));
|
||||
}
|
||||
}
|
||||
|
||||
Value* GetElementPtrInst::GetPtr() const { return GetOperand(0); }
|
||||
|
||||
CastInst::CastInst(Opcode op, std::shared_ptr<Type> ty, Value* val,
|
||||
std::string name)
|
||||
: Instruction(op, std::move(ty), std::move(name)) {
|
||||
AddOperand(val);
|
||||
}
|
||||
|
||||
Value* CastInst::GetValue() const { return GetOperand(0); }
|
||||
Value* ReturnInst::GetValue() const { return GetOperand(0); }
|
||||
|
||||
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
|
||||
: Instruction(Opcode::Ret, std::move(void_ty), "") {
|
||||
if (val) {
|
||||
AddOperand(val);
|
||||
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
|
||||
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {
|
||||
if (!type_ || !type_->IsPtrInt32()) {
|
||||
throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*"));
|
||||
}
|
||||
}
|
||||
|
||||
Value* ReturnInst::GetValue() const {
|
||||
return GetNumOperands() > 0 ? GetOperand(0) : nullptr;
|
||||
}
|
||||
|
||||
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
|
||||
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {}
|
||||
|
||||
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
|
||||
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
|
||||
if (!ptr) {
|
||||
throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr"));
|
||||
}
|
||||
if (!type_ || !type_->IsInt32()) {
|
||||
throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32"));
|
||||
}
|
||||
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "LoadInst 当前只支持从 i32* 加载"));
|
||||
}
|
||||
AddOperand(ptr);
|
||||
}
|
||||
|
||||
@@ -160,6 +124,22 @@ Value* LoadInst::GetPtr() const { return GetOperand(0); }
|
||||
|
||||
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
|
||||
: Instruction(Opcode::Store, std::move(void_ty), "") {
|
||||
if (!val) {
|
||||
throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value"));
|
||||
}
|
||||
if (!ptr) {
|
||||
throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr"));
|
||||
}
|
||||
if (!type_ || !type_->IsVoid()) {
|
||||
throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void"));
|
||||
}
|
||||
if (!val->GetType() || !val->GetType()->IsInt32()) {
|
||||
throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32"));
|
||||
}
|
||||
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("ir", "StoreInst 当前只支持写入 i32*"));
|
||||
}
|
||||
AddOperand(val);
|
||||
AddOperand(ptr);
|
||||
}
|
||||
|
||||
@@ -9,10 +9,8 @@ Context& Module::GetContext() { return context_; }
|
||||
const Context& Module::GetContext() const { return context_; }
|
||||
|
||||
Function* Module::CreateFunction(const std::string& name,
|
||||
std::shared_ptr<Type> ret_type,
|
||||
std::vector<std::shared_ptr<Type>> param_types) {
|
||||
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type),
|
||||
std::move(param_types)));
|
||||
std::shared_ptr<Type> ret_type) {
|
||||
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type)));
|
||||
return functions_.back().get();
|
||||
}
|
||||
|
||||
@@ -20,15 +18,4 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
|
||||
return functions_;
|
||||
}
|
||||
|
||||
GlobalValue* Module::CreateGlobalValue(const std::string& name,
|
||||
std::shared_ptr<Type> ty,
|
||||
ConstantValue* init) {
|
||||
global_values_.push_back(std::make_unique<GlobalValue>(std::move(ty), name, init));
|
||||
return global_values_.back().get();
|
||||
}
|
||||
|
||||
const std::vector<std::unique_ptr<GlobalValue>>& Module::GetGlobalValues() const {
|
||||
return global_values_;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -20,21 +20,6 @@ const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetFloatType() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetPtrFloatType() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrFloat);
|
||||
return type;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Type>& Type::GetLabelType() {
|
||||
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
|
||||
return type;
|
||||
}
|
||||
|
||||
Type::Kind Type::GetKind() const { return kind_; }
|
||||
|
||||
bool Type::IsVoid() const { return kind_ == Kind::Void; }
|
||||
@@ -43,29 +28,4 @@ bool Type::IsInt32() const { return kind_ == Kind::Int32; }
|
||||
|
||||
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
|
||||
|
||||
bool Type::IsFloat() const { return kind_ == Kind::Float; }
|
||||
|
||||
bool Type::IsPtrFloat() const { return kind_ == Kind::PtrFloat; }
|
||||
|
||||
bool Type::IsLabel() const { return kind_ == Kind::Label; }
|
||||
|
||||
bool Type::IsArray() const { return kind_ == Kind::Array; }
|
||||
|
||||
std::shared_ptr<class ArrayType> Type::GetAsArrayType() {
|
||||
if (IsArray()) {
|
||||
return std::static_pointer_cast<ArrayType>(shared_from_this());
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
ArrayType::ArrayType(std::shared_ptr<Type> element_type, uint32_t num_elements)
|
||||
: Type(Kind::Array),
|
||||
element_type_(std::move(element_type)),
|
||||
num_elements_(num_elements) {}
|
||||
|
||||
std::shared_ptr<ArrayType> ArrayType::Get(std::shared_ptr<Type> element_type,
|
||||
uint32_t num_elements) {
|
||||
return std::make_shared<ArrayType>(std::move(element_type), num_elements);
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -22,12 +22,6 @@ bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
|
||||
|
||||
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
|
||||
|
||||
bool Value::IsFloat() const { return type_ && type_->IsFloat(); }
|
||||
|
||||
bool Value::IsPtrFloat() const { return type_ && type_->IsPtrFloat(); }
|
||||
|
||||
bool Value::IsLabel() const { return type_ && type_->IsLabel(); }
|
||||
|
||||
bool Value::IsConstant() const {
|
||||
return dynamic_cast<const ConstantValue*>(this) != nullptr;
|
||||
}
|
||||
@@ -44,14 +38,6 @@ bool Value::IsFunction() const {
|
||||
return dynamic_cast<const Function*>(this) != nullptr;
|
||||
}
|
||||
|
||||
bool Value::IsGlobalValue() const {
|
||||
return dynamic_cast<const GlobalValue*>(this) != nullptr;
|
||||
}
|
||||
|
||||
bool Value::IsArgument() const {
|
||||
return dynamic_cast<const Argument*>(this) != nullptr;
|
||||
}
|
||||
|
||||
void Value::AddUse(User* user, size_t operand_index) {
|
||||
if (!user) return;
|
||||
uses_.push_back(Use(this, user, operand_index));
|
||||
@@ -88,27 +74,10 @@ void Value::ReplaceAllUsesWith(Value* new_value) {
|
||||
}
|
||||
}
|
||||
|
||||
Argument::Argument(std::shared_ptr<Type> ty, std::string name, Function* parent,
|
||||
unsigned arg_no)
|
||||
: Value(std::move(ty), std::move(name)),
|
||||
parent_(parent),
|
||||
arg_no_(arg_no) {}
|
||||
|
||||
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
|
||||
: Value(std::move(ty), std::move(name)) {}
|
||||
|
||||
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
|
||||
: ConstantValue(std::move(ty), ""), value_(v) {}
|
||||
|
||||
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
|
||||
: ConstantValue(std::move(ty), ""), value_(v) {}
|
||||
|
||||
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name,
|
||||
ConstantValue* init)
|
||||
: User(std::move(ty), std::move(name)), init_(init) {
|
||||
if (init_) {
|
||||
AddOperand(init_);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
#include "irgen/IRGen.h"
|
||||
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
#include "SysYParser.h"
|
||||
#include "ir/IR.h"
|
||||
@@ -9,209 +8,100 @@
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<ir::Type> BaseTypeFromDecl(SysYParser::BtypeContext* btype) {
|
||||
return (btype && btype->FLOAT()) ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
|
||||
}
|
||||
|
||||
std::shared_ptr<ir::Type> StorageType(std::shared_ptr<ir::Type> ty) {
|
||||
if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
|
||||
if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
|
||||
return ty;
|
||||
}
|
||||
|
||||
size_t CountScalars(const std::shared_ptr<ir::Type>& ty) {
|
||||
if (!ty->IsArray()) return 1;
|
||||
auto arr_ty = ty->GetAsArrayType();
|
||||
return arr_ty->GetNumElements() * CountScalars(arr_ty->GetElementType());
|
||||
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
|
||||
if (!lvalue.ID()) {
|
||||
throw std::runtime_error(FormatError("irgen", "非法左值"));
|
||||
}
|
||||
return lvalue.ID()->getText();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void IRGenImpl::ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr<ir::Type> ty) {
|
||||
if (ty->IsArray()) {
|
||||
auto arr_ty = ty->GetAsArrayType();
|
||||
for (uint32_t i = 0; i < arr_ty->GetNumElements(); ++i) {
|
||||
auto* elem_ptr = builder_.CreateGEP(StorageType(arr_ty->GetElementType()), ptr,
|
||||
{builder_.CreateConstInt(0),
|
||||
builder_.CreateConstInt(static_cast<int>(i))},
|
||||
module_.GetContext().NextTemp());
|
||||
ZeroInitializeLocal(elem_ptr, arr_ty->GetElementType());
|
||||
}
|
||||
return;
|
||||
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
|
||||
}
|
||||
|
||||
ir::Value* zero = ty->IsFloat() ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
|
||||
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
|
||||
builder_.CreateStore(zero, ptr);
|
||||
}
|
||||
|
||||
void IRGenImpl::EmitLocalInitValue(ir::Value* ptr, std::shared_ptr<ir::Type> ty,
|
||||
SysYParser::InitValueContext* init) {
|
||||
if (!init) return;
|
||||
|
||||
auto build_flat_scalar_ptr = [&](ir::Value* base_ptr,
|
||||
const std::shared_ptr<ir::Type>& base_ty,
|
||||
size_t flat_index) -> ir::Value* {
|
||||
if (!base_ty->IsArray()) return base_ptr;
|
||||
|
||||
std::vector<ir::Value*> indices;
|
||||
indices.push_back(builder_.CreateConstInt(0));
|
||||
auto cur_ty = base_ty;
|
||||
size_t offset = flat_index;
|
||||
while (cur_ty->IsArray()) {
|
||||
auto arr_ty = cur_ty->GetAsArrayType();
|
||||
const size_t step = CountScalars(arr_ty->GetElementType());
|
||||
const size_t idx = step == 0 ? 0 : offset / step;
|
||||
offset = step == 0 ? 0 : offset % step;
|
||||
indices.push_back(builder_.CreateConstInt(static_cast<int>(idx)));
|
||||
cur_ty = arr_ty->GetElementType();
|
||||
}
|
||||
return builder_.CreateGEP(StorageType(cur_ty), base_ptr, indices,
|
||||
module_.GetContext().NextTemp());
|
||||
};
|
||||
|
||||
if (ty->IsArray()) {
|
||||
auto arr_ty = ty->GetAsArrayType();
|
||||
const auto elem_ty = arr_ty->GetElementType();
|
||||
const size_t elem_step = CountScalars(elem_ty);
|
||||
|
||||
if (init->exp()) {
|
||||
auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, 0);
|
||||
EmitLocalInitValue(elem_ptr, elem_ty, init);
|
||||
return;
|
||||
}
|
||||
|
||||
const auto& children = init->initValue();
|
||||
size_t scalar_cursor = 0;
|
||||
for (auto* child : children) {
|
||||
if (scalar_cursor >= CountScalars(ty)) break;
|
||||
if (child->exp()) {
|
||||
auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, scalar_cursor);
|
||||
auto scalar_ty = elem_ty;
|
||||
while (scalar_ty->IsArray()) {
|
||||
scalar_ty = scalar_ty->GetAsArrayType()->GetElementType();
|
||||
}
|
||||
EmitLocalInitValue(elem_ptr, scalar_ty, child);
|
||||
++scalar_cursor;
|
||||
continue;
|
||||
for (auto* item : ctx->blockItem()) {
|
||||
if (item) {
|
||||
if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
|
||||
// 当前语法要求 return 为块内最后一条语句;命中后可停止生成。
|
||||
break;
|
||||
}
|
||||
|
||||
const size_t elem_index = elem_step == 0 ? 0 : scalar_cursor / elem_step;
|
||||
if (elem_index >= arr_ty->GetNumElements()) break;
|
||||
auto* elem_ptr = builder_.CreateGEP(StorageType(elem_ty), ptr,
|
||||
{builder_.CreateConstInt(0),
|
||||
builder_.CreateConstInt(static_cast<int>(elem_index))},
|
||||
module_.GetContext().NextTemp());
|
||||
EmitLocalInitValue(elem_ptr, elem_ty, child);
|
||||
scalar_cursor = (elem_index + 1) * elem_step;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (!init->exp()) {
|
||||
return;
|
||||
}
|
||||
ir::Value* value = EvalExpr(*init->exp());
|
||||
if (ty->IsFloat() && value->GetType()->IsInt32()) {
|
||||
value = builder_.CreateSIToFP(value, ty, module_.GetContext().NextTemp());
|
||||
} else if (ty->IsInt32() && value->GetType()->IsFloat()) {
|
||||
value = builder_.CreateFPToSI(value, ty, module_.GetContext().NextTemp());
|
||||
}
|
||||
builder_.CreateStore(value, ptr);
|
||||
return {};
|
||||
}
|
||||
|
||||
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
|
||||
SysYParser::BlockItemContext& item) {
|
||||
return std::any_cast<BlockFlow>(item.accept(this));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少块内项"));
|
||||
}
|
||||
if (ctx->decl()) {
|
||||
ctx->decl()->accept(this);
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
if (ctx->stmt()) {
|
||||
return ctx->stmt()->accept(this);
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明"));
|
||||
}
|
||||
|
||||
// 变量声明的 IR 生成目前也是最小实现:
|
||||
// - 先检查声明的基础类型,当前仅支持局部 int;
|
||||
// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。
|
||||
//
|
||||
// 和更完整的版本相比,这里还没有:
|
||||
// - 一个 Decl 中多个变量定义的顺序处理;
|
||||
// - const、数组、全局变量等不同声明形态;
|
||||
// - 更丰富的类型系统。
|
||||
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
|
||||
if (ctx->constDecl()) return ctx->constDecl()->accept(this);
|
||||
if (ctx->varDecl()) return ctx->varDecl()->accept(this);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
|
||||
}
|
||||
if (!ctx->btype() || !ctx->btype()->INT()) {
|
||||
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
|
||||
}
|
||||
auto* var_def = ctx->varDef();
|
||||
if (!var_def) {
|
||||
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
|
||||
}
|
||||
var_def->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
|
||||
for (auto* def : ctx->constDef()) {
|
||||
def->accept(this);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) {
|
||||
for (auto* def : ctx->varDef()) {
|
||||
def->accept(this);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
auto ty = BaseTypeFromDecl(
|
||||
dynamic_cast<SysYParser::ConstDeclContext*>(ctx->parent)->btype());
|
||||
const auto dims = ctx->exp();
|
||||
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
|
||||
auto* dim = EvalConstExpr(**it);
|
||||
if (auto* ci = dynamic_cast<ir::ConstantInt*>(dim)) {
|
||||
ty = ir::ArrayType::Get(ty, ci->GetValue());
|
||||
continue;
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "数组维度必须是整型常量"));
|
||||
}
|
||||
ir::Value* slot = nullptr;
|
||||
|
||||
if (is_global_scope_) {
|
||||
ir::ConstantValue* init = nullptr;
|
||||
if (ctx->initValue() && ctx->initValue()->exp()) {
|
||||
init = EvalConstExpr(*ctx->initValue()->exp());
|
||||
if (ty->IsInt32() && init->GetType()->IsFloat()) {
|
||||
init = module_.GetContext().GetConstInt(
|
||||
static_cast<int>(static_cast<ir::ConstantFloat*>(init)->GetValue()));
|
||||
} else if (ty->IsFloat() && init->GetType()->IsInt32()) {
|
||||
init = module_.GetContext().GetConstFloat(
|
||||
static_cast<float>(static_cast<ir::ConstantInt*>(init)->GetValue()));
|
||||
}
|
||||
}
|
||||
slot = module_.CreateGlobalValue(name, StorageType(ty), init);
|
||||
} else {
|
||||
slot = builder_.CreateAlloca(StorageType(ty), name);
|
||||
ZeroInitializeLocal(slot, ty);
|
||||
EmitLocalInitValue(slot, ty, ctx->initValue());
|
||||
}
|
||||
|
||||
storage_map_[ctx] = slot;
|
||||
return {};
|
||||
}
|
||||
|
||||
// 当前仍是教学用的最小版本,因此这里只支持:
|
||||
// - 局部 int 变量;
|
||||
// - 标量初始化;
|
||||
// - 一个 VarDef 对应一个槽位。
|
||||
std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
auto ty = BaseTypeFromDecl(
|
||||
dynamic_cast<SysYParser::VarDeclContext*>(ctx->parent)->btype());
|
||||
const auto dims = ctx->exp();
|
||||
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
|
||||
auto* dim = EvalConstExpr(**it);
|
||||
if (auto* ci = dynamic_cast<ir::ConstantInt*>(dim)) {
|
||||
ty = ir::ArrayType::Get(ty, ci->GetValue());
|
||||
continue;
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "数组维度必须是整型常量"));
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
|
||||
}
|
||||
ir::Value* slot = nullptr;
|
||||
|
||||
if (is_global_scope_) {
|
||||
ir::ConstantValue* init = nullptr;
|
||||
if (ctx->initValue() && ctx->initValue()->exp()) {
|
||||
init = EvalConstExpr(*ctx->initValue()->exp());
|
||||
if (ty->IsInt32() && init->GetType()->IsFloat()) {
|
||||
init = module_.GetContext().GetConstInt(
|
||||
static_cast<int>(static_cast<ir::ConstantFloat*>(init)->GetValue()));
|
||||
} else if (ty->IsFloat() && init->GetType()->IsInt32()) {
|
||||
init = module_.GetContext().GetConstFloat(
|
||||
static_cast<float>(static_cast<ir::ConstantInt*>(init)->GetValue()));
|
||||
}
|
||||
}
|
||||
slot = module_.CreateGlobalValue(name, StorageType(ty), init);
|
||||
} else {
|
||||
slot = builder_.CreateAlloca(StorageType(ty), name);
|
||||
ZeroInitializeLocal(slot, ty);
|
||||
if (ctx->initValue()) EmitLocalInitValue(slot, ty, ctx->initValue());
|
||||
if (!ctx->lValue()) {
|
||||
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
|
||||
}
|
||||
|
||||
GetLValueName(*ctx->lValue());
|
||||
if (storage_map_.find(ctx) != storage_map_.end()) {
|
||||
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
|
||||
}
|
||||
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
|
||||
storage_map_[ctx] = slot;
|
||||
|
||||
ir::Value* init = nullptr;
|
||||
if (auto* init_value = ctx->initValue()) {
|
||||
if (!init_value->exp()) {
|
||||
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
|
||||
}
|
||||
init = EvalExpr(*init_value->exp());
|
||||
} else {
|
||||
init = builder_.CreateConstInt(0);
|
||||
}
|
||||
builder_.CreateStore(init, slot);
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -6,27 +6,9 @@
|
||||
#include "ir/IR.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
static void PredeclareLibraryFunctions(ir::Module& module) {
|
||||
module.CreateFunction("getint", ir::Type::GetInt32Type(), {});
|
||||
module.CreateFunction("getch", ir::Type::GetInt32Type(), {});
|
||||
module.CreateFunction("getfloat", ir::Type::GetFloatType(), {});
|
||||
module.CreateFunction("getarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrInt32Type()});
|
||||
module.CreateFunction("getfarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrFloatType()});
|
||||
module.CreateFunction("putint", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()});
|
||||
module.CreateFunction("putch", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()});
|
||||
module.CreateFunction("putfloat", ir::Type::GetVoidType(), {ir::Type::GetFloatType()});
|
||||
module.CreateFunction("putarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type()});
|
||||
module.CreateFunction("putfarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()});
|
||||
module.CreateFunction("starttime", ir::Type::GetVoidType(), {});
|
||||
module.CreateFunction("stoptime", ir::Type::GetVoidType(), {});
|
||||
// putf is special, but for now we might not support it fully or just declare it simply
|
||||
// module.CreateFunction("putf", ...);
|
||||
}
|
||||
|
||||
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
|
||||
const SemanticContext& sema) {
|
||||
auto module = std::make_unique<ir::Module>();
|
||||
PredeclareLibraryFunctions(*module);
|
||||
IRGenImpl gen(*module, sema);
|
||||
tree.accept(&gen);
|
||||
return module;
|
||||
|
||||
@@ -1,724 +1,80 @@
|
||||
#include "irgen/IRGen.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <stdexcept>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "SysYParser.h"
|
||||
#include "ir/IR.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
namespace {
|
||||
|
||||
bool IsZero(const ir::ConstantValue* value) {
|
||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
|
||||
return ci->GetValue() == 0;
|
||||
}
|
||||
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
|
||||
return cf->GetValue() == 0.0f;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsTruthy(const ir::ConstantValue* value) {
|
||||
return !IsZero(value);
|
||||
}
|
||||
|
||||
int AsInt(const ir::ConstantValue* value) {
|
||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
|
||||
return ci->GetValue();
|
||||
}
|
||||
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
|
||||
return static_cast<int>(cf->GetValue());
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "无法将常量转换为 int"));
|
||||
}
|
||||
|
||||
float AsFloat(const ir::ConstantValue* value) {
|
||||
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
|
||||
return cf->GetValue();
|
||||
}
|
||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
|
||||
return static_cast<float>(ci->GetValue());
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "无法将常量转换为 float"));
|
||||
}
|
||||
|
||||
std::shared_ptr<ir::Type> ScalarPointerType(std::shared_ptr<ir::Type> ty) {
|
||||
if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
|
||||
if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
|
||||
return ty;
|
||||
}
|
||||
|
||||
std::shared_ptr<ir::Type> CommonArithType(ir::Value* lhs, ir::Value* rhs) {
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
return ir::Type::GetFloatType();
|
||||
}
|
||||
return ir::Type::GetInt32Type();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// 表达式生成当前也只实现了很小的一个子集。
|
||||
// 目前支持:
|
||||
// - 整数字面量
|
||||
// - 普通局部变量读取
|
||||
// - 括号表达式
|
||||
// - 二元加法
|
||||
//
|
||||
// 还未支持:
|
||||
// - 减乘除与一元运算
|
||||
// - 赋值表达式
|
||||
// - 函数调用
|
||||
// - 数组、指针、下标访问
|
||||
// - 条件与比较表达式
|
||||
// - ...
|
||||
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
|
||||
return std::any_cast<ir::Value*>(expr.accept(this));
|
||||
}
|
||||
|
||||
ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) {
|
||||
class ConstExprVisitor final : public SysYBaseVisitor {
|
||||
public:
|
||||
ConstExprVisitor(ir::Module& module, const SemanticContext& sema)
|
||||
: module_(module), sema_(sema) {}
|
||||
|
||||
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
|
||||
return Eval(*ctx->exp());
|
||||
}
|
||||
|
||||
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
|
||||
if (ctx->number()->ILITERAL()) {
|
||||
const std::string text = ctx->number()->ILITERAL()->getText();
|
||||
int value = 0;
|
||||
if (text.size() > 2 && (text[1] == 'x' || text[1] == 'X')) {
|
||||
value = std::stoi(text, nullptr, 16);
|
||||
} else if (text.size() > 1 && text[0] == '0') {
|
||||
value = std::stoi(text, nullptr, 8);
|
||||
} else {
|
||||
value = std::stoi(text, nullptr, 10);
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(value));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstFloat(std::stof(ctx->number()->FLITERAL()->getText())));
|
||||
}
|
||||
|
||||
std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override {
|
||||
auto* def = sema_.ResolveLValue(ctx->lValue());
|
||||
if (!def) {
|
||||
throw std::runtime_error(FormatError("irgen", "常量表达式引用了未绑定左值"));
|
||||
}
|
||||
if (!ctx->lValue()->exp().empty()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen", "暂不支持在常量表达式中访问数组元素"));
|
||||
}
|
||||
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
|
||||
if (!const_def->initValue() || !const_def->initValue()->exp()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen", "常量缺少标量初始化表达式"));
|
||||
}
|
||||
return Eval(*const_def->initValue()->exp());
|
||||
}
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen", "全局/常量表达式必须是编译期常量"));
|
||||
}
|
||||
|
||||
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override {
|
||||
return Eval(*ctx->exp());
|
||||
}
|
||||
|
||||
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override {
|
||||
auto* value = Eval(*ctx->exp());
|
||||
if (value->GetType()->IsFloat()) {
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstFloat(-AsFloat(value)));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(-AsInt(value)));
|
||||
}
|
||||
|
||||
std::any visitNotExp(SysYParser::NotExpContext* ctx) override {
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
|
||||
}
|
||||
|
||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
auto* rhs = Eval(*ctx->exp(1));
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs)));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs)));
|
||||
}
|
||||
|
||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
auto* rhs = Eval(*ctx->exp(1));
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs)));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs)));
|
||||
}
|
||||
|
||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
auto* rhs = Eval(*ctx->exp(1));
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs)));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs)));
|
||||
}
|
||||
|
||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
auto* rhs = Eval(*ctx->exp(1));
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
const float rv = AsFloat(rhs);
|
||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(
|
||||
rv == 0.0f ? 0.0f : AsFloat(lhs) / rv));
|
||||
}
|
||||
const int rv = AsInt(rhs);
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv));
|
||||
}
|
||||
|
||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
auto* rhs = Eval(*ctx->exp(1));
|
||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
|
||||
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
|
||||
}
|
||||
|
||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT);
|
||||
}
|
||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE);
|
||||
}
|
||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT);
|
||||
}
|
||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE);
|
||||
}
|
||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ);
|
||||
}
|
||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
|
||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE);
|
||||
}
|
||||
|
||||
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
if (!IsTruthy(lhs)) {
|
||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(0));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0));
|
||||
}
|
||||
|
||||
std::any visitOrExp(SysYParser::OrExpContext* ctx) override {
|
||||
auto* lhs = Eval(*ctx->exp(0));
|
||||
if (IsTruthy(lhs)) {
|
||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(1));
|
||||
}
|
||||
return static_cast<ir::ConstantValue*>(
|
||||
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0));
|
||||
}
|
||||
|
||||
ir::ConstantValue* Eval(SysYParser::ExpContext& ctx) {
|
||||
return std::any_cast<ir::ConstantValue*>(ctx.accept(this));
|
||||
}
|
||||
|
||||
private:
|
||||
ir::ConstantValue* EvalCmpImpl(SysYParser::ExpContext& lhs_ctx,
|
||||
SysYParser::ExpContext& rhs_ctx,
|
||||
ir::Opcode op) {
|
||||
auto* lhs = Eval(lhs_ctx);
|
||||
auto* rhs = Eval(rhs_ctx);
|
||||
bool result = false;
|
||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||
const float a = AsFloat(lhs);
|
||||
const float b = AsFloat(rhs);
|
||||
switch (op) {
|
||||
case ir::Opcode::ICmpLT: result = a < b; break;
|
||||
case ir::Opcode::ICmpLE: result = a <= b; break;
|
||||
case ir::Opcode::ICmpGT: result = a > b; break;
|
||||
case ir::Opcode::ICmpGE: result = a >= b; break;
|
||||
case ir::Opcode::ICmpEQ: result = a == b; break;
|
||||
case ir::Opcode::ICmpNE: result = a != b; break;
|
||||
default: break;
|
||||
}
|
||||
} else {
|
||||
const int a = AsInt(lhs);
|
||||
const int b = AsInt(rhs);
|
||||
switch (op) {
|
||||
case ir::Opcode::ICmpLT: result = a < b; break;
|
||||
case ir::Opcode::ICmpLE: result = a <= b; break;
|
||||
case ir::Opcode::ICmpGT: result = a > b; break;
|
||||
case ir::Opcode::ICmpGE: result = a >= b; break;
|
||||
case ir::Opcode::ICmpEQ: result = a == b; break;
|
||||
case ir::Opcode::ICmpNE: result = a != b; break;
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
return module_.GetContext().GetConstInt(result ? 1 : 0);
|
||||
}
|
||||
|
||||
ir::Module& module_;
|
||||
const SemanticContext& sema_;
|
||||
};
|
||||
|
||||
ConstExprVisitor visitor(module_, sema_);
|
||||
return visitor.Eval(expr);
|
||||
}
|
||||
|
||||
static ir::Value* CastValue(IRGenImpl& gen, ir::IRBuilder& builder, ir::Module& module,
|
||||
ir::Value* value, std::shared_ptr<ir::Type> target_ty) {
|
||||
if (value->GetType() == target_ty) return value;
|
||||
if (target_ty->IsFloat() && value->GetType()->IsInt32()) {
|
||||
if (auto* ci = dynamic_cast<ir::ConstantInt*>(value)) {
|
||||
return module.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
|
||||
}
|
||||
return builder.CreateSIToFP(value, target_ty, module.GetContext().NextTemp());
|
||||
}
|
||||
if (target_ty->IsInt32() && value->GetType()->IsFloat()) {
|
||||
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(value)) {
|
||||
return module.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
|
||||
}
|
||||
return builder.CreateFPToSI(value, target_ty, module.GetContext().NextTemp());
|
||||
}
|
||||
return value;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
|
||||
if (!ctx || !ctx->exp()) {
|
||||
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
|
||||
}
|
||||
return EvalExpr(*ctx->exp());
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitLValueExp(SysYParser::LValueExpContext* ctx) {
|
||||
auto* def = sema_.ResolveLValue(ctx->lValue());
|
||||
if (def && IsArrayLikeDef(def) && ctx->lValue()->exp().size() < GetArrayRank(def)) {
|
||||
return DecayArrayPtr(ctx->lValue());
|
||||
}
|
||||
ir::Value* ptr = GetLValuePtr(ctx->lValue());
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
|
||||
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) {
|
||||
ir::Function* target_func = nullptr;
|
||||
if (auto* def = sema_.ResolveFuncCall(ctx)) {
|
||||
const std::string func_name = def->ID()->getText();
|
||||
for (const auto& f : module_.GetFunctions()) {
|
||||
if (f->GetName() == func_name) {
|
||||
target_func = f.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const std::string func_name = ctx->ID()->getText();
|
||||
for (const auto& f : module_.GetFunctions()) {
|
||||
if (f->GetName() == func_name) {
|
||||
target_func = f.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!target_func) {
|
||||
throw std::runtime_error(FormatError("irgen", "找不到函数: " + ctx->ID()->getText()));
|
||||
}
|
||||
|
||||
std::vector<ir::Value*> args;
|
||||
const auto& arg_types = target_func->GetArguments();
|
||||
if (ctx->funcRParams()) {
|
||||
const auto& exps = ctx->funcRParams()->exp();
|
||||
for (size_t i = 0; i < exps.size(); ++i) {
|
||||
ir::Value* arg = EvalExpr(*exps[i]);
|
||||
if (i < arg_types.size()) {
|
||||
arg = CastValue(*this, builder_, module_, arg, arg_types[i]->GetType());
|
||||
}
|
||||
args.push_back(arg);
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateCall(target_func, args,
|
||||
target_func->GetType()->IsVoid()
|
||||
? ""
|
||||
: module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitNotExp(SysYParser::NotExpContext* ctx) {
|
||||
ir::Value* val = EvalExpr(*ctx->exp());
|
||||
if (auto* cv = dynamic_cast<ir::ConstantValue*>(val)) {
|
||||
return static_cast<ir::Value*>(
|
||||
module_.GetContext().GetConstInt(IsTruthy(cv) ? 0 : 1));
|
||||
}
|
||||
ir::Value* cond = ToI1(val);
|
||||
ir::Value* res =
|
||||
builder_.CreateICmp(ir::Opcode::ICmpEQ, cond, builder_.CreateConstInt(0),
|
||||
module_.GetContext().NextTemp());
|
||||
return ToI32(res);
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) {
|
||||
return EvalExpr(*ctx->exp());
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) {
|
||||
ir::Value* val = EvalExpr(*ctx->exp());
|
||||
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) {
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(-ci->GetValue()));
|
||||
}
|
||||
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) {
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(-cf->GetValue()));
|
||||
}
|
||||
if (val->GetType()->IsFloat()) {
|
||||
return static_cast<ir::Value*>(builder_.CreateFSub(
|
||||
builder_.CreateConstFloat(0.0f), val, module_.GetContext().NextTemp()));
|
||||
}
|
||||
return static_cast<ir::Value*>(builder_.CreateSub(
|
||||
builder_.CreateConstInt(0), val, module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
#define DEFINE_ARITH_VISITOR(name, int_opcode, float_opcode) \
|
||||
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
|
||||
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
|
||||
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
|
||||
const auto common_ty = CommonArithType(lhs, rhs); \
|
||||
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
|
||||
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
|
||||
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
|
||||
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
|
||||
if (common_ty->IsFloat()) { \
|
||||
const float lv = AsFloat(lconst); \
|
||||
const float rv = AsFloat(rconst); \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv + rv)); \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv - rv)); \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv * rv)); \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv)); \
|
||||
} \
|
||||
const int lv = AsInt(lconst); \
|
||||
const int rv = AsInt(rconst); \
|
||||
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Add) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv + rv)); \
|
||||
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Sub) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv - rv)); \
|
||||
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Mul) \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv * rv)); \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv)); \
|
||||
} \
|
||||
} \
|
||||
if (common_ty->IsFloat()) { \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \
|
||||
return static_cast<ir::Value*>(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \
|
||||
return static_cast<ir::Value*>(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \
|
||||
return static_cast<ir::Value*>(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
return static_cast<ir::Value*>(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
} \
|
||||
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
}
|
||||
|
||||
DEFINE_ARITH_VISITOR(Add, Add, FAdd)
|
||||
DEFINE_ARITH_VISITOR(Sub, Sub, FSub)
|
||||
DEFINE_ARITH_VISITOR(Mul, Mul, FMul)
|
||||
DEFINE_ARITH_VISITOR(Div, Div, FDiv)
|
||||
|
||||
std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) {
|
||||
ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
|
||||
ir::Type::GetInt32Type());
|
||||
ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)),
|
||||
ir::Type::GetInt32Type());
|
||||
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||
const int rv = AsInt(rconst);
|
||||
return static_cast<ir::Value*>(
|
||||
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
|
||||
}
|
||||
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
|
||||
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
|
||||
}
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
|
||||
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
|
||||
}
|
||||
|
||||
#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \
|
||||
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
|
||||
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
|
||||
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
|
||||
const auto common_ty = CommonArithType(lhs, rhs); \
|
||||
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
|
||||
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
|
||||
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
|
||||
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
|
||||
const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \
|
||||
: (AsInt(lconst) cmp_op AsInt(rconst)); \
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0)); \
|
||||
} \
|
||||
} \
|
||||
if (common_ty->IsFloat()) { \
|
||||
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
} \
|
||||
return static_cast<ir::Value*>(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
||||
// 变量使用的处理流程:
|
||||
// 1. 先通过语义分析结果把变量使用绑定回声明;
|
||||
// 2. 再通过 storage_map_ 找到该声明对应的栈槽位;
|
||||
// 3. 最后生成 load,把内存中的值读出来。
|
||||
//
|
||||
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
|
||||
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
|
||||
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
|
||||
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
|
||||
}
|
||||
|
||||
DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <)
|
||||
DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=)
|
||||
DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >)
|
||||
DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=)
|
||||
DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==)
|
||||
DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=)
|
||||
|
||||
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
|
||||
if (!builder_.GetInsertBlock()) {
|
||||
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
|
||||
auto* decl = sema_.ResolveVarUse(ctx->var());
|
||||
if (!decl) {
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen",
|
||||
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
|
||||
}
|
||||
|
||||
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||
if (auto* c = dynamic_cast<ir::ConstantValue*>(lhs); c && !IsTruthy(c)) {
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(0));
|
||||
}
|
||||
|
||||
const std::string suffix = module_.GetContext().NextTemp();
|
||||
ir::BasicBlock* rhs_bb = func_->CreateBlock("and.rhs." + suffix);
|
||||
ir::BasicBlock* merge_bb = func_->CreateBlock("and.merge." + suffix);
|
||||
auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
|
||||
|
||||
builder_.CreateStore(ToI32(ToI1(lhs)), res_ptr);
|
||||
builder_.CreateCondBr(ToI1(lhs), rhs_bb, merge_bb);
|
||||
|
||||
builder_.SetInsertPoint(rhs_bb);
|
||||
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||
builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr);
|
||||
builder_.CreateBr(merge_bb);
|
||||
|
||||
builder_.SetInsertPoint(merge_bb);
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitOrExp(SysYParser::OrExpContext* ctx) {
|
||||
if (!builder_.GetInsertBlock()) {
|
||||
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
|
||||
}
|
||||
|
||||
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||
if (auto* c = dynamic_cast<ir::ConstantValue*>(lhs); c && IsTruthy(c)) {
|
||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(1));
|
||||
}
|
||||
|
||||
const std::string suffix = module_.GetContext().NextTemp();
|
||||
ir::BasicBlock* rhs_bb = func_->CreateBlock("or.rhs." + suffix);
|
||||
ir::BasicBlock* merge_bb = func_->CreateBlock("or.merge." + suffix);
|
||||
auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
|
||||
|
||||
builder_.CreateStore(ToI32(ToI1(lhs)), res_ptr);
|
||||
builder_.CreateCondBr(ToI1(lhs), merge_bb, rhs_bb);
|
||||
|
||||
builder_.SetInsertPoint(rhs_bb);
|
||||
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||
builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr);
|
||||
builder_.CreateBr(merge_bb);
|
||||
|
||||
builder_.SetInsertPoint(merge_bb);
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
bool IRGenImpl::IsArrayLikeDef(antlr4::ParserRuleContext* def) const {
|
||||
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
|
||||
return !const_def->exp().empty();
|
||||
}
|
||||
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
|
||||
return !var_def->exp().empty();
|
||||
}
|
||||
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||
return !param->LBRACK().empty();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
size_t IRGenImpl::GetArrayRank(antlr4::ParserRuleContext* def) const {
|
||||
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
|
||||
return const_def->exp().size();
|
||||
}
|
||||
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
|
||||
return var_def->exp().size();
|
||||
}
|
||||
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||
return param->LBRACK().size();
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::shared_ptr<ir::Type> IRGenImpl::GetDefType(antlr4::ParserRuleContext* def) const {
|
||||
std::shared_ptr<ir::Type> ty = ir::Type::GetInt32Type();
|
||||
auto apply_dims = [&](const auto& dims) {
|
||||
auto result = ty;
|
||||
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
|
||||
if constexpr (std::is_pointer_v<typename std::decay_t<decltype(*it)>>) {
|
||||
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(**it);
|
||||
result = ir::ArrayType::Get(result, AsInt(dim_val));
|
||||
} else {
|
||||
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(*it);
|
||||
result = ir::ArrayType::Get(result, AsInt(dim_val));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
|
||||
auto* decl = dynamic_cast<SysYParser::ConstDeclContext*>(const_def->parent);
|
||||
ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType()
|
||||
: ir::Type::GetInt32Type();
|
||||
return apply_dims(const_def->exp());
|
||||
}
|
||||
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
|
||||
auto* decl = dynamic_cast<SysYParser::VarDeclContext*>(var_def->parent);
|
||||
ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType()
|
||||
: ir::Type::GetInt32Type();
|
||||
return apply_dims(var_def->exp());
|
||||
}
|
||||
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||
ty = param->btype()->FLOAT() ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
|
||||
if (param->LBRACK().empty()) return ty;
|
||||
for (int i = static_cast<int>(param->exp().size()) - 1; i >= 0; --i) {
|
||||
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(*param->exp(i));
|
||||
ty = ir::ArrayType::Get(ty, AsInt(dim_val));
|
||||
}
|
||||
return ty;
|
||||
}
|
||||
return ty;
|
||||
}
|
||||
|
||||
ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
|
||||
auto* def = sema_.ResolveLValue(ctx);
|
||||
if (!def) {
|
||||
throw std::runtime_error(FormatError("irgen", "数组退化失败: 未绑定定义"));
|
||||
}
|
||||
ir::Value* base_ptr = storage_map_.at(def);
|
||||
const auto base_ty = GetDefType(def);
|
||||
|
||||
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||
if (ctx->exp().empty()) return base_ptr;
|
||||
|
||||
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
|
||||
ir::Type::GetInt32Type());
|
||||
auto cur_ty = base_ty;
|
||||
for (size_t i = 1; i < ctx->exp().size(); ++i) {
|
||||
if (!cur_ty->IsArray()) break;
|
||||
auto arr_ty = cur_ty->GetAsArrayType();
|
||||
ir::Value* stride = builder_.CreateConstInt(static_cast<int>(arr_ty->GetNumElements()));
|
||||
offset = builder_.CreateMul(offset, stride, module_.GetContext().NextTemp());
|
||||
offset = builder_.CreateAdd(
|
||||
offset,
|
||||
CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(i)),
|
||||
ir::Type::GetInt32Type()),
|
||||
module_.GetContext().NextTemp());
|
||||
cur_ty = arr_ty->GetElementType();
|
||||
}
|
||||
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset},
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
|
||||
std::vector<ir::Value*> indices;
|
||||
indices.push_back(builder_.CreateConstInt(0));
|
||||
for (auto* exp : ctx->exp()) {
|
||||
indices.push_back(
|
||||
CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type()));
|
||||
}
|
||||
|
||||
auto cur_ty = base_ty;
|
||||
while (cur_ty->IsArray()) {
|
||||
cur_ty = cur_ty->GetAsArrayType()->GetElementType();
|
||||
}
|
||||
if (!ctx->exp().empty()) {
|
||||
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
indices.push_back(builder_.CreateConstInt(0));
|
||||
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
|
||||
ir::Value* IRGenImpl::GetLValuePtr(SysYParser::LValueContext* ctx) {
|
||||
auto* def = sema_.ResolveLValue(ctx);
|
||||
if (!def) {
|
||||
throw std::runtime_error(FormatError("irgen", "未定义的左值: " + ctx->ID()->getText()));
|
||||
}
|
||||
auto it = storage_map_.find(def);
|
||||
auto it = storage_map_.find(decl);
|
||||
if (it == storage_map_.end()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen", "左值缺少存储槽位: " + ctx->ID()->getText()));
|
||||
FormatError("irgen",
|
||||
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
|
||||
}
|
||||
|
||||
ir::Value* base_ptr = it->second;
|
||||
if (ctx->exp().empty()) return base_ptr;
|
||||
|
||||
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||
return DecayArrayPtr(ctx);
|
||||
}
|
||||
|
||||
const auto base_ty = GetDefType(def);
|
||||
std::vector<ir::Value*> indices;
|
||||
indices.push_back(builder_.CreateConstInt(0));
|
||||
for (auto* exp : ctx->exp()) {
|
||||
indices.push_back(
|
||||
CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type()));
|
||||
}
|
||||
|
||||
auto cur_ty = base_ty;
|
||||
for (size_t i = 0; i < ctx->exp().size(); ++i) {
|
||||
if (!cur_ty->IsArray()) {
|
||||
throw std::runtime_error(FormatError("irgen", "数组下标层数超出定义"));
|
||||
}
|
||||
cur_ty = cur_ty->GetAsArrayType()->GetElementType();
|
||||
}
|
||||
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
|
||||
module_.GetContext().NextTemp());
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
|
||||
if (auto* cv = dynamic_cast<ir::ConstantValue*>(v)) {
|
||||
return module_.GetContext().GetConstInt(IsTruthy(cv) ? 1 : 0);
|
||||
}
|
||||
if (auto* inst = dynamic_cast<ir::Instruction*>(v)) {
|
||||
switch (inst->GetOpcode()) {
|
||||
case ir::Opcode::ICmpEQ:
|
||||
case ir::Opcode::ICmpNE:
|
||||
case ir::Opcode::ICmpLT:
|
||||
case ir::Opcode::ICmpGT:
|
||||
case ir::Opcode::ICmpLE:
|
||||
case ir::Opcode::ICmpGE:
|
||||
case ir::Opcode::FCmpEQ:
|
||||
case ir::Opcode::FCmpNE:
|
||||
case ir::Opcode::FCmpLT:
|
||||
case ir::Opcode::FCmpGT:
|
||||
case ir::Opcode::FCmpLE:
|
||||
case ir::Opcode::FCmpGE:
|
||||
return v;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (v->GetType()->IsFloat()) {
|
||||
return builder_.CreateFCmp(ir::Opcode::FCmpNE, v, builder_.CreateConstFloat(0.0f),
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
if (v->GetType()->IsInt32()) {
|
||||
return builder_.CreateICmp(ir::Opcode::ICmpNE, v, builder_.CreateConstInt(0),
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
return v;
|
||||
}
|
||||
|
||||
ir::Value* IRGenImpl::ToI32(ir::Value* v) {
|
||||
if (v->GetType()->IsInt32()) return v;
|
||||
if (v->GetType()->IsFloat()) {
|
||||
return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
|
||||
module_.GetContext().NextTemp());
|
||||
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
|
||||
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
|
||||
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
|
||||
}
|
||||
return builder_.CreateZExt(v, ir::Type::GetInt32Type(),
|
||||
module_.GetContext().NextTemp());
|
||||
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||
return static_cast<ir::Value*>(
|
||||
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
|
||||
module_.GetContext().NextTemp()));
|
||||
}
|
||||
|
||||
@@ -8,18 +8,10 @@
|
||||
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<ir::Type> StorageType(std::shared_ptr<ir::Type> ty) {
|
||||
if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
|
||||
if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
|
||||
return ty;
|
||||
}
|
||||
|
||||
void VerifyFunctionStructure(const ir::Function& func) {
|
||||
// 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。
|
||||
for (const auto& bb : func.GetBlocks()) {
|
||||
if (!bb || !bb->HasTerminator()) {
|
||||
// If a block doesn't have a terminator, it might be an empty function or
|
||||
// missing a return in a path. For SysY, we should at least have a default return for void.
|
||||
// But IRGen should have handled it.
|
||||
throw std::runtime_error(
|
||||
FormatError("irgen", "基本块未正确终结: " +
|
||||
(bb ? bb->GetName() : std::string("<null>"))));
|
||||
@@ -33,83 +25,63 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
|
||||
: module_(module),
|
||||
sema_(sema),
|
||||
func_(nullptr),
|
||||
builder_(module.GetContext(), nullptr),
|
||||
is_global_scope_(true) {}
|
||||
builder_(module.GetContext(), nullptr) {}
|
||||
|
||||
// 编译单元的 IR 生成当前只实现了最小功能:
|
||||
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
|
||||
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR;
|
||||
//
|
||||
// 当前还没有实现:
|
||||
// - 多个函数定义的遍历与生成;
|
||||
// - 全局变量、全局常量的 IR 生成。
|
||||
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
|
||||
if (!ctx) return {};
|
||||
|
||||
is_global_scope_ = true;
|
||||
for (auto* decl : ctx->decl()) {
|
||||
decl->accept(this);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
|
||||
}
|
||||
for (auto* funcDef : ctx->funcDef()) {
|
||||
funcDef->accept(this);
|
||||
auto* func = ctx->funcDef();
|
||||
if (!func) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
|
||||
}
|
||||
func->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
// 函数 IR 生成当前实现了:
|
||||
// 1. 获取函数名;
|
||||
// 2. 检查函数返回类型;
|
||||
// 3. 在 Module 中创建 Function;
|
||||
// 4. 将 builder 插入点设置到入口基本块;
|
||||
// 5. 继续生成函数体。
|
||||
//
|
||||
// 当前还没有实现:
|
||||
// - 通用函数返回类型处理;
|
||||
// - 形参列表遍历与参数类型收集;
|
||||
// - FunctionType 这样的函数类型对象;
|
||||
// - Argument/形式参数 IR 对象;
|
||||
// - 入口块中的参数初始化逻辑。
|
||||
// ...
|
||||
|
||||
// 因此这里目前只支持最小的“无参 int 函数”生成。
|
||||
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
|
||||
is_global_scope_ = false;
|
||||
|
||||
std::shared_ptr<ir::Type> ret_ty;
|
||||
if (ctx->funcType()->INT()) {
|
||||
ret_ty = ir::Type::GetInt32Type();
|
||||
} else if (ctx->funcType()->FLOAT()) {
|
||||
ret_ty = ir::Type::GetFloatType();
|
||||
} else {
|
||||
ret_ty = ir::Type::GetVoidType();
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
|
||||
}
|
||||
if (!ctx->blockStmt()) {
|
||||
throw std::runtime_error(FormatError("irgen", "函数体为空"));
|
||||
}
|
||||
if (!ctx->ID()) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少函数名"));
|
||||
}
|
||||
if (!ctx->funcType() || !ctx->funcType()->INT()) {
|
||||
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
|
||||
}
|
||||
|
||||
std::string func_name = ctx->ID()->getText();
|
||||
|
||||
std::vector<std::shared_ptr<ir::Type>> param_types;
|
||||
if (ctx->funcFParams()) {
|
||||
for (auto* fparam : ctx->funcFParams()->funcFParam()) {
|
||||
if (fparam->LBRACK().empty()) {
|
||||
param_types.push_back(fparam->btype()->INT() ? ir::Type::GetInt32Type() : ir::Type::GetFloatType());
|
||||
} else {
|
||||
// Array param is a pointer
|
||||
param_types.push_back(fparam->btype()->INT() ? ir::Type::GetPtrInt32Type() : ir::Type::GetPtrFloatType());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func_ = module_.CreateFunction(func_name, ret_ty, param_types);
|
||||
builder_.SetInsertPoint(func_->CreateBlock("entry"));
|
||||
|
||||
// Handle parameters: alloca and store
|
||||
if (ctx->funcFParams()) {
|
||||
const auto& fparams = ctx->funcFParams()->funcFParam();
|
||||
const auto& args = func_->GetArguments();
|
||||
for (size_t i = 0; i < fparams.size(); ++i) {
|
||||
auto* fparam = fparams[i];
|
||||
auto* arg = args[i].get();
|
||||
auto* slot = builder_.CreateAlloca(StorageType(arg->GetType()), fparam->ID()->getText());
|
||||
builder_.CreateStore(arg, slot);
|
||||
storage_map_[fparam] = slot;
|
||||
}
|
||||
}
|
||||
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
|
||||
builder_.SetInsertPoint(func_->GetEntry());
|
||||
storage_map_.clear();
|
||||
|
||||
ctx->blockStmt()->accept(this);
|
||||
|
||||
// Default return for void functions if not terminated
|
||||
if (!builder_.GetInsertBlock()->HasTerminator()) {
|
||||
if (ret_ty->IsVoid()) {
|
||||
builder_.CreateRet(nullptr);
|
||||
} else if (ret_ty->IsInt32()) {
|
||||
builder_.CreateRet(builder_.CreateConstInt(0));
|
||||
} else if (ret_ty->IsFloat()) {
|
||||
builder_.CreateRet(builder_.CreateConstFloat(0.0f));
|
||||
}
|
||||
}
|
||||
|
||||
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
|
||||
VerifyFunctionStructure(*func_);
|
||||
is_global_scope_ = true;
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitFuncFParam(SysYParser::FuncFParamContext* ctx) {
|
||||
// We handle fparams in visitFuncDef directly.
|
||||
return {};
|
||||
}
|
||||
|
||||
@@ -6,146 +6,34 @@
|
||||
#include "ir/IR.h"
|
||||
#include "utils/Log.h"
|
||||
|
||||
// 语句生成
|
||||
// 语句生成当前只实现了最小子集。
|
||||
// 目前支持:
|
||||
// - return <exp>;
|
||||
//
|
||||
// 还未支持:
|
||||
// - 赋值语句
|
||||
// - if / while 等控制流
|
||||
// - 空语句、块语句嵌套分发之外的更多语句形态
|
||||
|
||||
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
|
||||
if (!ctx) return BlockFlow::Continue;
|
||||
|
||||
if (ctx->assignStmt()) return ctx->assignStmt()->accept(this);
|
||||
if (ctx->returnStmt()) return ctx->returnStmt()->accept(this);
|
||||
if (ctx->blockStmt()) return ctx->blockStmt()->accept(this);
|
||||
if (ctx->ifStmt()) return ctx->ifStmt()->accept(this);
|
||||
if (ctx->whileStmt()) return ctx->whileStmt()->accept(this);
|
||||
if (ctx->breakStmt()) return ctx->breakStmt()->accept(this);
|
||||
if (ctx->continueStmt()) return ctx->continueStmt()->accept(this);
|
||||
if (ctx->expStmt()) return ctx->expStmt()->accept(this);
|
||||
|
||||
return BlockFlow::Continue;
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少语句"));
|
||||
}
|
||||
if (ctx->returnStmt()) {
|
||||
return ctx->returnStmt()->accept(this);
|
||||
}
|
||||
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
|
||||
for (auto* item : ctx->blockItem()) {
|
||||
if (std::any_cast<BlockFlow>(item->accept(this)) == BlockFlow::Terminated) {
|
||||
return BlockFlow::Terminated;
|
||||
}
|
||||
}
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
|
||||
if (ctx->decl()) {
|
||||
ctx->decl()->accept(this);
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
if (ctx->stmt()) {
|
||||
return ctx->stmt()->accept(this);
|
||||
}
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitAssignStmt(SysYParser::AssignStmtContext* ctx) {
|
||||
ir::Value* ptr = GetLValuePtr(ctx->lValue());
|
||||
ir::Value* val = EvalExpr(*ctx->exp());
|
||||
if (ptr->GetType()->IsPtrFloat() && val->GetType()->IsInt32()) {
|
||||
val = builder_.CreateSIToFP(val, ir::Type::GetFloatType(),
|
||||
module_.GetContext().NextTemp());
|
||||
} else if (ptr->GetType()->IsPtrInt32() && val->GetType()->IsFloat()) {
|
||||
val = builder_.CreateFPToSI(val, ir::Type::GetInt32Type(),
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
builder_.CreateStore(val, ptr);
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
|
||||
if (ctx->exp()) {
|
||||
ir::Value* v = EvalExpr(*ctx->exp());
|
||||
if (func_->GetType()->IsFloat() && v->GetType()->IsInt32()) {
|
||||
v = builder_.CreateSIToFP(v, ir::Type::GetFloatType(),
|
||||
module_.GetContext().NextTemp());
|
||||
} else if (func_->GetType()->IsInt32() && v->GetType()->IsFloat()) {
|
||||
v = builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
|
||||
module_.GetContext().NextTemp());
|
||||
}
|
||||
builder_.CreateRet(v);
|
||||
} else {
|
||||
builder_.CreateRet(nullptr);
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
|
||||
}
|
||||
if (!ctx->exp()) {
|
||||
throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
|
||||
}
|
||||
ir::Value* v = EvalExpr(*ctx->exp());
|
||||
builder_.CreateRet(v);
|
||||
return BlockFlow::Terminated;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitIfStmt(SysYParser::IfStmtContext* ctx) {
|
||||
const std::string suffix = module_.GetContext().NextTemp();
|
||||
ir::BasicBlock* true_bb = func_->CreateBlock("if.true." + suffix);
|
||||
ir::BasicBlock* false_bb =
|
||||
ctx->ELSE() ? func_->CreateBlock("if.false." + suffix) : nullptr;
|
||||
ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge." + suffix);
|
||||
|
||||
ir::Value* cond = EvalExpr(*ctx->exp());
|
||||
builder_.CreateCondBr(ToI1(cond), true_bb, false_bb ? false_bb : merge_bb);
|
||||
|
||||
// True block
|
||||
builder_.SetInsertPoint(true_bb);
|
||||
if (std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this)) == BlockFlow::Continue) {
|
||||
builder_.CreateBr(merge_bb);
|
||||
}
|
||||
|
||||
// False block
|
||||
if (false_bb) {
|
||||
builder_.SetInsertPoint(false_bb);
|
||||
if (std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this)) == BlockFlow::Continue) {
|
||||
builder_.CreateBr(merge_bb);
|
||||
}
|
||||
}
|
||||
|
||||
builder_.SetInsertPoint(merge_bb);
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitWhileStmt(SysYParser::WhileStmtContext* ctx) {
|
||||
const std::string suffix = module_.GetContext().NextTemp();
|
||||
ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond." + suffix);
|
||||
ir::BasicBlock* body_bb = func_->CreateBlock("while.body." + suffix);
|
||||
ir::BasicBlock* end_bb = func_->CreateBlock("while.end." + suffix);
|
||||
|
||||
builder_.CreateBr(cond_bb);
|
||||
builder_.SetInsertPoint(cond_bb);
|
||||
ir::Value* cond = EvalExpr(*ctx->exp());
|
||||
builder_.CreateCondBr(ToI1(cond), body_bb, end_bb);
|
||||
|
||||
break_stack_.push(end_bb);
|
||||
continue_stack_.push(cond_bb);
|
||||
|
||||
builder_.SetInsertPoint(body_bb);
|
||||
if (std::any_cast<BlockFlow>(ctx->stmt()->accept(this)) == BlockFlow::Continue) {
|
||||
builder_.CreateBr(cond_bb);
|
||||
}
|
||||
|
||||
break_stack_.pop();
|
||||
continue_stack_.pop();
|
||||
|
||||
builder_.SetInsertPoint(end_bb);
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitBreakStmt(SysYParser::BreakStmtContext* ctx) {
|
||||
if (break_stack_.empty()) {
|
||||
throw std::runtime_error(FormatError("irgen", "break 语句不在循环内"));
|
||||
}
|
||||
builder_.CreateBr(break_stack_.top());
|
||||
return BlockFlow::Terminated;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) {
|
||||
if (continue_stack_.empty()) {
|
||||
throw std::runtime_error(FormatError("irgen", "continue 语句不在循环内"));
|
||||
}
|
||||
builder_.CreateBr(continue_stack_.top());
|
||||
return BlockFlow::Terminated;
|
||||
}
|
||||
|
||||
std::any IRGenImpl::visitExpStmt(SysYParser::ExpStmtContext* ctx) {
|
||||
if (ctx->exp()) {
|
||||
EvalExpr(*ctx->exp());
|
||||
}
|
||||
return BlockFlow::Continue;
|
||||
}
|
||||
|
||||
382
src/sem/Sema.cpp
382
src/sem/Sema.cpp
@@ -10,321 +10,185 @@
|
||||
|
||||
namespace {
|
||||
|
||||
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
|
||||
if (!lvalue.ID()) {
|
||||
throw std::runtime_error(FormatError("sema", "非法左值"));
|
||||
}
|
||||
return lvalue.ID()->getText();
|
||||
}
|
||||
|
||||
class SemaVisitor final : public SysYBaseVisitor {
|
||||
public:
|
||||
explicit SemaVisitor() {
|
||||
// 预填标准库函数
|
||||
AddBuiltin("getint", Symbol::Kind::Function);
|
||||
AddBuiltin("getch", Symbol::Kind::Function);
|
||||
AddBuiltin("getfloat", Symbol::Kind::Function);
|
||||
AddBuiltin("getarray", Symbol::Kind::Function);
|
||||
AddBuiltin("getfarray", Symbol::Kind::Function);
|
||||
AddBuiltin("putint", Symbol::Kind::Function);
|
||||
AddBuiltin("putch", Symbol::Kind::Function);
|
||||
AddBuiltin("putfloat", Symbol::Kind::Function);
|
||||
AddBuiltin("putarray", Symbol::Kind::Function);
|
||||
AddBuiltin("putfarray", Symbol::Kind::Function);
|
||||
AddBuiltin("starttime", Symbol::Kind::Function);
|
||||
AddBuiltin("stoptime", Symbol::Kind::Function);
|
||||
AddBuiltin("putf", Symbol::Kind::Function);
|
||||
}
|
||||
|
||||
private:
|
||||
void AddBuiltin(const std::string& name, Symbol::Kind kind) {
|
||||
Symbol sym;
|
||||
sym.kind = kind;
|
||||
sym.def_ctx = nullptr; // 内建函数没有语法树节点
|
||||
table_.Add(name, sym);
|
||||
}
|
||||
|
||||
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
|
||||
if (!ctx) return {};
|
||||
for (auto* child : ctx->children) {
|
||||
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
|
||||
decl->accept(this);
|
||||
} else if (auto* funcDef = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
|
||||
funcDef->accept(this);
|
||||
}
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("sema", "缺少编译单元"));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitDecl(SysYParser::DeclContext* ctx) override {
|
||||
if (ctx->constDecl()) return ctx->constDecl()->accept(this);
|
||||
if (ctx->varDecl()) return ctx->varDecl()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
|
||||
for (auto* def : ctx->constDef()) {
|
||||
const std::string name = def->ID()->getText();
|
||||
if (table_.IsInCurrentScope(name)) {
|
||||
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
|
||||
}
|
||||
Symbol sym;
|
||||
sym.kind = Symbol::Kind::Constant;
|
||||
sym.def_ctx = def;
|
||||
sym.is_const = true;
|
||||
sym.is_array = !def->exp().empty();
|
||||
table_.Add(name, sym);
|
||||
|
||||
for (auto* exp : def->exp()) exp->accept(this);
|
||||
def->initValue()->accept(this);
|
||||
auto* func = ctx->funcDef();
|
||||
if (!func || !func->blockStmt()) {
|
||||
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
|
||||
for (auto* def : ctx->varDef()) {
|
||||
const std::string name = def->ID()->getText();
|
||||
if (table_.IsInCurrentScope(name)) {
|
||||
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
|
||||
}
|
||||
Symbol sym;
|
||||
sym.kind = Symbol::Kind::Variable;
|
||||
sym.def_ctx = def;
|
||||
sym.is_const = false;
|
||||
sym.is_array = !def->exp().empty();
|
||||
table_.Add(name, sym);
|
||||
|
||||
for (auto* exp : def->exp()) exp->accept(this);
|
||||
if (def->initValue()) def->initValue()->accept(this);
|
||||
if (!func->ID() || func->ID()->getText() != "main") {
|
||||
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
|
||||
}
|
||||
func->accept(this);
|
||||
if (!seen_return_) {
|
||||
throw std::runtime_error(
|
||||
FormatError("sema", "main 函数必须包含 return 语句"));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
if (table_.IsInCurrentScope(name)) {
|
||||
throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
|
||||
if (!ctx || !ctx->blockStmt()) {
|
||||
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
|
||||
}
|
||||
Symbol sym;
|
||||
sym.kind = Symbol::Kind::Function;
|
||||
sym.def_ctx = ctx;
|
||||
table_.Add(name, sym);
|
||||
|
||||
table_.PushScope();
|
||||
if (ctx->funcFParams()) {
|
||||
ctx->funcFParams()->accept(this);
|
||||
if (!ctx->funcType() || !ctx->funcType()->INT()) {
|
||||
throw std::runtime_error(FormatError("sema", "当前仅支持 int main"));
|
||||
}
|
||||
if (ctx->blockStmt()) {
|
||||
// Visit block items without pushing another scope to keep params in same scope
|
||||
for (auto* item : ctx->blockStmt()->blockItem()) {
|
||||
item->accept(this);
|
||||
}
|
||||
const auto& items = ctx->blockStmt()->blockItem();
|
||||
if (items.empty()) {
|
||||
throw std::runtime_error(
|
||||
FormatError("sema", "main 函数不能为空,且必须以 return 结束"));
|
||||
}
|
||||
table_.PopScope();
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
if (table_.IsInCurrentScope(name)) {
|
||||
throw std::runtime_error(FormatError("sema", "函数参数名冲突: " + name));
|
||||
}
|
||||
Symbol sym;
|
||||
sym.kind = Symbol::Kind::Parameter;
|
||||
sym.def_ctx = ctx;
|
||||
sym.is_array = !ctx->LBRACK().empty();
|
||||
table_.Add(name, sym);
|
||||
|
||||
for (auto* exp : ctx->exp()) exp->accept(this);
|
||||
ctx->blockStmt()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
|
||||
table_.PushScope();
|
||||
for (auto* item : ctx->blockItem()) {
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("sema", "缺少语句块"));
|
||||
}
|
||||
const auto& items = ctx->blockItem();
|
||||
for (size_t i = 0; i < items.size(); ++i) {
|
||||
auto* item = items[i];
|
||||
if (!item) {
|
||||
continue;
|
||||
}
|
||||
if (seen_return_) {
|
||||
throw std::runtime_error(
|
||||
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
|
||||
}
|
||||
current_item_index_ = i;
|
||||
total_items_ = items.size();
|
||||
item->accept(this);
|
||||
}
|
||||
table_.PopScope();
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override {
|
||||
ctx->lValue()->accept(this);
|
||||
const std::string name = ctx->lValue()->ID()->getText();
|
||||
Symbol* sym = table_.Lookup(name);
|
||||
if (sym && sym->is_const) {
|
||||
throw std::runtime_error(FormatError("sema", "试图给常量赋值: " + name));
|
||||
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override {
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
|
||||
}
|
||||
ctx->exp()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitLValue(SysYParser::LValueContext* ctx) override {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
Symbol* sym = table_.Lookup(name);
|
||||
if (!sym) {
|
||||
throw std::runtime_error(FormatError("sema", "使用了未定义的标识符: " + name));
|
||||
if (ctx->decl()) {
|
||||
ctx->decl()->accept(this);
|
||||
return {};
|
||||
}
|
||||
if (sym->kind == Symbol::Kind::Function) {
|
||||
throw std::runtime_error(FormatError("sema", "函数名不能作为左值: " + name));
|
||||
if (ctx->stmt()) {
|
||||
ctx->stmt()->accept(this);
|
||||
return {};
|
||||
}
|
||||
sema_.BindLValue(ctx, sym->def_ctx);
|
||||
for (auto* exp : ctx->exp()) exp->accept(this);
|
||||
return {};
|
||||
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
|
||||
}
|
||||
|
||||
std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override {
|
||||
const std::string name = ctx->ID()->getText();
|
||||
Symbol* sym = table_.Lookup(name);
|
||||
if (!sym) {
|
||||
throw std::runtime_error(FormatError("sema", "调用未定义的函数: " + name));
|
||||
std::any visitDecl(SysYParser::DeclContext* ctx) override {
|
||||
if (!ctx) {
|
||||
throw std::runtime_error(FormatError("sema", "非法变量声明"));
|
||||
}
|
||||
if (sym->kind != Symbol::Kind::Function) {
|
||||
throw std::runtime_error(FormatError("sema", "标识符不是函数: " + name));
|
||||
if (!ctx->btype() || !ctx->btype()->INT()) {
|
||||
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
|
||||
}
|
||||
sema_.BindFuncCall(ctx, dynamic_cast<SysYParser::FuncDefContext*>(sym->def_ctx));
|
||||
if (ctx->funcRParams()) {
|
||||
ctx->funcRParams()->accept(this);
|
||||
auto* var_def = ctx->varDef();
|
||||
if (!var_def || !var_def->lValue()) {
|
||||
throw std::runtime_error(FormatError("sema", "非法变量声明"));
|
||||
}
|
||||
const std::string name = GetLValueName(*var_def->lValue());
|
||||
if (table_.Contains(name)) {
|
||||
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
|
||||
}
|
||||
if (auto* init = var_def->initValue()) {
|
||||
if (!init->exp()) {
|
||||
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
|
||||
}
|
||||
init->exp()->accept(this);
|
||||
}
|
||||
table_.Add(name, var_def);
|
||||
return {};
|
||||
}
|
||||
|
||||
// Visit expressions to ensure all sub-expressions are checked (e.g. for variable uses)
|
||||
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
|
||||
return ctx->exp()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override {
|
||||
return ctx->lValue()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
|
||||
return ctx->number()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitNumber(SysYParser::NumberContext* ctx) override {
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitNotExp(SysYParser::NotExpContext* ctx) override {
|
||||
return ctx->exp()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override {
|
||||
return ctx->exp()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override {
|
||||
return ctx->exp()->accept(this);
|
||||
}
|
||||
|
||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitOrExp(SysYParser::OrExpContext* ctx) override {
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
std::any visitStmt(SysYParser::StmtContext* ctx) override {
|
||||
if (!ctx || !ctx->returnStmt()) {
|
||||
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
|
||||
}
|
||||
ctx->returnStmt()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
|
||||
if (ctx->exp()) ctx->exp()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override {
|
||||
if (!ctx || !ctx->exp()) {
|
||||
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
|
||||
}
|
||||
ctx->exp()->accept(this);
|
||||
ctx->stmt(0)->accept(this);
|
||||
if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
|
||||
seen_return_ = true;
|
||||
if (current_item_index_ + 1 != total_items_) {
|
||||
throw std::runtime_error(
|
||||
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override {
|
||||
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
|
||||
if (!ctx || !ctx->exp()) {
|
||||
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
|
||||
}
|
||||
ctx->exp()->accept(this);
|
||||
ctx->stmt()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override {
|
||||
std::any visitVarExp(SysYParser::VarExpContext* ctx) override {
|
||||
if (!ctx || !ctx->var()) {
|
||||
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
|
||||
}
|
||||
ctx->var()->accept(this);
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override {
|
||||
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
|
||||
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
|
||||
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量"));
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override {
|
||||
if (ctx->exp()) ctx->exp()->accept(this);
|
||||
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override {
|
||||
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
|
||||
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式"));
|
||||
}
|
||||
ctx->exp(0)->accept(this);
|
||||
ctx->exp(1)->accept(this);
|
||||
return {};
|
||||
}
|
||||
public:
|
||||
SemanticContext TakeSemanticContext() { return std::move(sema_); }
|
||||
|
||||
private:
|
||||
std::any visitVar(SysYParser::VarContext* ctx) override {
|
||||
if (!ctx || !ctx->ID()) {
|
||||
throw std::runtime_error(FormatError("sema", "非法变量引用"));
|
||||
}
|
||||
const std::string name = ctx->ID()->getText();
|
||||
auto* decl = table_.Lookup(name);
|
||||
if (!decl) {
|
||||
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
|
||||
}
|
||||
sema_.BindVarUse(ctx, decl);
|
||||
return {};
|
||||
}
|
||||
|
||||
SemanticContext TakeSemanticContext() { return std::move(sema_); }
|
||||
|
||||
private:
|
||||
SymbolTable table_;
|
||||
SemanticContext sema_;
|
||||
bool seen_return_ = false;
|
||||
size_t current_item_index_ = 0;
|
||||
size_t total_items_ = 0;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,40 +1,17 @@
|
||||
// 维护局部变量声明的注册与查找。
|
||||
|
||||
#include "sem/SymbolTable.h"
|
||||
|
||||
SymbolTable::SymbolTable() {
|
||||
// Push global scope
|
||||
PushScope();
|
||||
void SymbolTable::Add(const std::string& name,
|
||||
SysYParser::VarDefContext* decl) {
|
||||
table_[name] = decl;
|
||||
}
|
||||
|
||||
void SymbolTable::PushScope() {
|
||||
scopes_.emplace_back();
|
||||
bool SymbolTable::Contains(const std::string& name) const {
|
||||
return table_.find(name) != table_.end();
|
||||
}
|
||||
|
||||
void SymbolTable::PopScope() {
|
||||
if (scopes_.size() > 1) {
|
||||
scopes_.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
bool SymbolTable::Add(const std::string& name, const Symbol& symbol) {
|
||||
auto& current_scope = scopes_.back();
|
||||
if (current_scope.find(name) != current_scope.end()) {
|
||||
return false;
|
||||
}
|
||||
current_scope[name] = symbol;
|
||||
return true;
|
||||
}
|
||||
|
||||
Symbol* SymbolTable::Lookup(const std::string& name) {
|
||||
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
|
||||
auto search = it->find(name);
|
||||
if (search != it->end()) {
|
||||
return &search->second;
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool SymbolTable::IsInCurrentScope(const std::string& name) const {
|
||||
const auto& current_scope = scopes_.back();
|
||||
return current_scope.find(name) != current_scope.end();
|
||||
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const {
|
||||
auto it = table_.find(name);
|
||||
return it == table_.end() ? nullptr : it->second;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user