8 Commits

49 changed files with 5467 additions and 676 deletions

313
doc/Lab2-实验记录.md Normal file

@@ -0,0 +1,313 @@
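# Lab2 Lab Report: Intermediate Representation Generation

## 1. Goal

The goal of Lab2 is to complete semantic checking and IR generation on top of the existing SysY front end, so that the compiler can translate more complete SysY programs into LLVM-style IR and validate the generated output with `llc`/`clang`.

The work focused on:

- Extending the IR type system and instruction set to support `float`, arrays, branches, function calls, GEP, and type conversions.
- Extending Sema to support nested scopes, lvalue binding, function-call binding, and predeclaration of built-in library functions.
- Completing IR generation for expressions, control flow, functions, arrays, and global variables.
- Fixing the key problems that directly blocked Lab2 validation: constant evaluation of global initializers, short-circuit evaluation, array addressing, and the IR printing format.

## 2. Scope of Code Changes

The following modules were modified:

- `src/sem`, `include/sem`
- `src/ir`, `include/ir`
- `src/irgen`, `include/irgen`
- This document, `doc/Lab2-实验记录.md`, is new.

Of these:

- `sem` handles name binding, scopes, and the preparation of semantic information.
- `ir` provides the IR infrastructure, the Builder, and the Printer.
- `irgen` translates the ANTLR parse tree into IR.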
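
## 3. Process

### 3.1 Establishing the problem boundary

The work started with reading the lab handout `doc/Lab2-中间表示生成.md` and then reviewing the following implementation files:

- `IRGenDecl.cpp`
- `IRGenExp.cpp`
- `IRGenStmt.cpp`
- `IRGenFunc.cpp`
- `IRBuilder.cpp`
- `IRPrinter.cpp`
- `Sema.cpp`

In its initial state the code already covered most of the Lab2 framework, but two problems still led directly to failures:

1. During global constant initialization, `EvalConstExpr` in fact still called the runtime `EvalExpr`. With no insertion point set, it reached `builder_.CreateLoad/CreateBinary/...` and ended with the error `IRBuilder 未设置插入点` ("IRBuilder insertion point not set").
2. Pointer and aggregate types were handled inconsistently for arrays. Local arrays, multi-dimensional arrays, and array arguments easily triggered `LoadInst 不支持的指针类型` ("LoadInst: unsupported pointer type") or produced wrong GEPs.

To avoid guessing from static reading alone, the project was then built and run on minimal samples to confirm the real failure points.

### 3.2 Establishing a regression baseline

First, rebuild the project: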
```bash
cmake -S . -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build -j 4
```
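Then validate against representative samples:

- `simple_add.sy`
- `05_arr_defn4.sy`
- `95_float.sy`

The results showed:

- `95_float.sy` triggered `IRBuilder 未设置插入点` because of the wrong global-constant path.
- `05_arr_defn4.sy` crashed because array addressing and storage types were inconsistent.

This step narrowed the problem from "something feels wrong somewhere" to two main lines: the constant-evaluation path and the array storage/addressing path.

## 4. Key Difficulties and Solutions

### 4.1 Difficulty 1: global initialization wrongly took the runtime IRBuilder path

#### Symptom

Code like the following crashed in global or constant initialization: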
```c
const float PI = 3.1415926;
const int A = 1 + 2;
```
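The cause: although the original `EvalConstExpr` was named "constant evaluation", internally it still called `EvalExpr` directly. Once the expression contained nodes such as variable accesses, binary operations, or short-circuit logic, it fell into the code where `builder_` creates instructions, and at global scope there is no basic-block insertion point.

#### Solution

Compile-time constant evaluation was separated out completely:

- A dedicated constant-evaluation visitor was implemented for `EvalConstExpr`.
- The constant path returns only `ConstantInt` / `ConstantFloat` and never emits IR instructions.
- Supported forms:
  - integer and floating-point literals
  - parenthesized expressions
  - unary `+`, `-`, `!`
  - `* / % + -`
  - comparison operators
  - `&& ||`
  - references to scalar `const` values
- Global and constant initializers may use only the result of `EvalConstExpr`.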
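
The separation described above can be sketched as a pure evaluator that returns a value or "not a constant" and never touches a builder. This is a minimal illustration only: `Expr`, `ConstVal`, `Lit`, `Bin`, and `EvalConst` are hypothetical names over a made-up mini-AST, not the project's ANTLR-based visitor, and only a few operators are shown.

```cpp
#include <cassert>
#include <memory>
#include <optional>
#include <utility>
#include <variant>

// Hypothetical mini-AST; the real compiler walks ANTLR parse-tree nodes.
using ConstVal = std::variant<int, float>;

struct Expr {
  enum class Kind { Literal, Add, Mul, LogicalAnd, NonConst };
  Kind kind = Kind::Literal;
  ConstVal literal{0};             // used by Literal nodes
  std::unique_ptr<Expr> lhs, rhs;  // used by binary nodes
};

std::unique_ptr<Expr> Lit(ConstVal v) {
  auto e = std::make_unique<Expr>();
  e->literal = v;
  return e;
}

std::unique_ptr<Expr> Bin(Expr::Kind k, std::unique_ptr<Expr> l,
                          std::unique_ptr<Expr> r) {
  auto e = std::make_unique<Expr>();
  e->kind = k;
  e->lhs = std::move(l);
  e->rhs = std::move(r);
  return e;
}

static bool IsFloat(const ConstVal& v) {
  return std::holds_alternative<float>(v);
}
static float AsFloat(const ConstVal& v) {
  return IsFloat(v) ? std::get<float>(v)
                    : static_cast<float>(std::get<int>(v));
}
static bool Truthy(const ConstVal& v) {
  return IsFloat(v) ? std::get<float>(v) != 0.0f : std::get<int>(v) != 0;
}

// Pure compile-time evaluation: yields a value or nullopt, never emits IR.
std::optional<ConstVal> EvalConst(const Expr& e) {
  switch (e.kind) {
    case Expr::Kind::Literal:
      return e.literal;
    case Expr::Kind::NonConst:
      return std::nullopt;  // e.g. a reference to a non-const variable
    case Expr::Kind::LogicalAnd: {
      auto l = EvalConst(*e.lhs);
      if (!l) return std::nullopt;
      if (!Truthy(*l)) return ConstVal{0};  // short-circuit: rhs is skipped
      auto r = EvalConst(*e.rhs);
      if (!r) return std::nullopt;
      return ConstVal{Truthy(*r) ? 1 : 0};
    }
    case Expr::Kind::Add:
    case Expr::Kind::Mul: {
      auto l = EvalConst(*e.lhs);
      auto r = EvalConst(*e.rhs);
      if (!l || !r) return std::nullopt;
      const bool add = e.kind == Expr::Kind::Add;
      if (IsFloat(*l) || IsFloat(*r)) {  // mixed operands promote to float
        const float a = AsFloat(*l), b = AsFloat(*r);
        return ConstVal{add ? a + b : a * b};
      }
      const int a = std::get<int>(*l), b = std::get<int>(*r);
      return ConstVal{add ? a + b : a * b};
    }
  }
  return std::nullopt;
}
```

Returning `std::nullopt` for anything that is not a compile-time constant lets the caller report an error instead of silently falling back to instruction-emitting code.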
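
#### Effect

After the fix:

- Global initialization no longer depends on an insertion point.
- The global constants in `95_float.sy` are generated reliably.
- Short-circuit expressions in a constant context are evaluated purely at compile time and never try to allocate an `alloca`.

### 4.2 Difficulty 2: confused "storage semantics" for array variables, array parameters, and scalar variables

#### Symptom

In the original implementation, `alloca/load/store/GEP` did not agree on how to interpret types:

- A scalar variable needs a slot that points to a scalar.
- A local array needs the base address of an aggregate object.
- An array parameter in SysY is essentially a pointer and must not be handled like a local array.

Mixing these cases led to:

- `load` trying to read a value directly from an array type
- GEP base types that did not match the index sequence
- wrong behavior for local array access, multi-dimensional array access, and array argument passing

#### Solution

The handling was split along three axes:

1. Scalars vs. arrays
   - Scalar locals use a real scalar slot: `i32*` or `float*`.
   - Array locals keep the aggregate base address.
2. Ordinary arrays vs. array parameters
   - Local and global arrays are addressed with multi-level GEPs along the array dimensions.
   - Array parameters are treated as decayed pointers; an access computes its offset from the remaining dimensions.
3. Taking an lvalue's address vs. evaluating its value
   - `GetLValuePtr` only produces the address.
   - `visitLValueExp` decides, based on whether the lvalue is still an array, between a `load` and passing a decayed array pointer.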
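
For the decayed-pointer case, the offset computation "from the remaining dimensions" might look like the sketch below. `FlatOffset` and its parameters are hypothetical names; the real code emits GEP instructions rather than computing integers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// For a parameter declared as `int a[][3][4]`, trailing_dims is {3, 4}.
// Returns the element offset of a[i0][i1]... relative to the incoming pointer.
// With fewer indices than dimensions, the result is the offset of the
// sub-array that would be passed on as a decayed pointer.
long FlatOffset(const std::vector<long>& indices,
                const std::vector<long>& trailing_dims) {
  long offset = 0;
  for (std::size_t k = 0; k < indices.size(); ++k) {
    // Stride of index k is the product of all dimensions to its right.
    long stride = 1;
    for (std::size_t d = k; d < trailing_dims.size(); ++d) {
      stride *= trailing_dims[d];
    }
    offset += indices[k] * stride;
  }
  return offset;
}
```

For example, `a[1][2][3]` with trailing dimensions `{3, 4}` lands at element `1*12 + 2*4 + 3 = 23`.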
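
#### Effect

After the fix:

- `simple_add.sy` works again.
- `05_arr_defn4.sy` compiles and runs.
- Addressing for multi-dimensional arrays and array parameters is more robust.

### 4.3 Difficulty 3: wrong semantics for brace initialization of local arrays

#### Symptom

Midway through, `05_arr_defn4.sy` no longer crashed but still produced a wrong result: the exit code was `13` instead of the expected `21`. Addressing was therefore no longer the problem; the initialization layout was.

The root cause:

- Some initializers were processed by descending into sub-arrays.
- Others were processed by flattening scalars.

With the two schemes mixed, the initialization order of multi-dimensional arrays was scrambled.

#### Solution

Local array initialization was unified into an "aggregate initialization + scalar cursor" scheme:

- Zero-initialize the whole array first.
- Maintain a scalar cursor while walking the brace initializer.
- A scalar initializer is stored at the element given by the current flat offset.
- A nested aggregate initializer enters the sub-array at the current alignment boundary.

This is closer to how SysY/LLVM front ends commonly handle aggregate initialization.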
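
The cursor scheme above can be sketched on a flat buffer as follows. This is an illustration under assumptions: `Init`, `S`, `L`, `Fill`, and `Flatten` are hypothetical names, the initializer is assumed to be well-formed (no bounds checks), and the real implementation emits stores instead of filling a vector.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical initializer tree: a scalar or a brace-enclosed list.
struct Init {
  bool is_list = false;
  int value = 0;
  std::vector<Init> items;
};
Init S(int v) { Init i; i.value = v; return i; }
Init L(std::vector<Init> items) {
  Init i;
  i.is_list = true;
  i.items = std::move(items);
  return i;
}

// Number of scalars in one sub-aggregate directly below dimension `depth`.
static std::size_t SubSize(const std::vector<std::size_t>& dims,
                           std::size_t depth) {
  std::size_t n = 1;
  for (std::size_t d = depth + 1; d < dims.size(); ++d) n *= dims[d];
  return n;
}

// Writes `list` into the flat buffer `out`, starting at element `begin`.
void Fill(const Init& list, const std::vector<std::size_t>& dims,
          std::size_t depth, std::size_t begin, std::vector<int>& out) {
  const std::size_t sub = SubSize(dims, depth);
  std::size_t cursor = begin;
  for (const Init& item : list.items) {
    if (!item.is_list) {
      out[cursor++] = item.value;  // scalar: store at the flat cursor
    } else {
      // Nested braces: round the cursor up to the next sub-array boundary,
      // fill that sub-array, then skip past it.
      std::size_t rel = (cursor - begin + sub - 1) / sub * sub;
      cursor = begin + rel;
      Fill(item, dims, depth + 1, cursor, out);
      cursor += sub;
    }
  }
}

std::vector<int> Flatten(const Init& init,
                         const std::vector<std::size_t>& dims) {
  std::size_t total = 1;
  for (std::size_t d : dims) total *= d;
  std::vector<int> out(total, 0);  // zero-initialize first
  Fill(init, dims, 0, 0, out);
  return out;
}
```

With dimensions `{4, 2}`, the initializer `{1, 2, {3}, {5}, 7, 8}` yields rows `{1,2}`, `{3,0}`, `{5,0}`, `{7,8}`: scalars advance the cursor one element at a time, while each nested brace consumes a whole row.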
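
#### Effect

After the fix, the IR for `05_arr_defn4.sy` passes `verify_ir.sh --run` and the output matches the expected result.

### 4.4 Difficulty 4: IR text that prints fine is not necessarily accepted by the LLVM back end

#### Symptom

The `verify_ir.sh` stage exposed a batch of problems where IR generation did not crash but LLVM rejected the output:

- Built-in functions were printed as empty definitions rather than declarations.
- Floating-point constants were not printed in the format LLVM expects.
- `icmp/fcmp` results were handled inconsistently as `i1` vs. `i32` in printing and in later uses.
- Automatic temporaries used purely numeric names, which violate LLVM's numbering requirement once they are out of order.
- Basic block names were duplicated.
- The base type printed for `getelementptr` was wrong.

#### Solution

The IR infrastructure was corrected systematically:

- `Function` no longer creates an entry block by default; `entry` is created only when a function is actually defined.
- `IRPrinter` prints `declare` for functions without basic blocks.
- Automatic temporaries are now named `t0/t1/...`, which avoids LLVM's strict ordering constraint on purely numeric SSA names.
- Comparison results are printed and consumed as booleans.
- Block names generated for `if/while/and/or` get a unique suffix.
- Printing of float constants, GEP, Cast, and Call was fixed.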
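
The report does not say which float format the printer ended up with. One format LLVM accepts for `float` constants is the hexadecimal form, which is written as the 64-bit pattern of the value widened to `double`. The sketch below assumes that approach; `FormatLLVMFloat` is a hypothetical helper name.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

// Formats a float as an LLVM hexadecimal floating-point constant.
// Widening float -> double is exact, so no precision is lost.
std::string FormatLLVMFloat(float v) {
  const double widened = static_cast<double>(v);
  std::uint64_t bits = 0;
  std::memcpy(&bits, &widened, sizeof(bits));  // bit copy, no aliasing UB
  char buf[32];
  std::snprintf(buf, sizeof(buf), "0x%016llX",
                static_cast<unsigned long long>(bits));
  return buf;
}
```

A decimal literal such as `3.14` is rejected by LLVM for type `float` when it is not exactly representable, which is why a bit-pattern form is the safer choice for a printer.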
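
#### Effect

After the fix, the following all pass `verify_ir.sh --run`:

- `simple_add.sy`
- `13_sub2.sy`
- `29_break.sy`
- `36_op_priority2.sy`
- `05_arr_defn4.sy`

### 4.5 Difficulty 5: the final run of `95_float.sy` is still blocked by the missing runtime library

#### Symptom

With the IR generation and printing problems fixed, `95_float.sy` can now:

- generate IR successfully
- produce an object file through `llc`

The final link step still fails. The cause is not an IR error: `sylib/sylib.c` in the repository is currently an empty shell and provides no real implementation of symbols such as:

- `getfloat`
- `putfloat`
- `getfarray`
- `putfarray`
- `putch`
- `putint`

#### Solution

This submission deliberately does not extend the runtime library. The problem is recorded as "Lab2 IR generation is correct, but the runtime dependency is not yet filled in", which keeps the Lab2 compiler work cleanly separated from the later runtime-library implementation.

#### Impact

The current state of `95_float.sy`:

- IR generation is correct.
- The LLVM back end accepts it.
- The final run depends on completing the runtime library.

## 5. Capabilities Implemented

At the end of this lab the compiler has the following key Lab2 capabilities:

- IR generation for global variables and constants
- IR generation for local variables
- `int`/`float` constants and expressions
- basic arithmetic and comparison
- type conversions: `sitofp`, `fptosi`, `zext`
- `if-else`
- `while`
- `break/continue`
- function definitions and calls
- scalar and array parameters
- multi-dimensional array addressing
- zero initialization and brace initialization of local arrays
- short-circuit evaluation
- IR text that LLVM accepts

## 6. Validation Results

The completed regression runs are: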
```bash
./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_simple --run
./scripts/verify_ir.sh test/test_case/functional/13_sub2.sy /tmp/ir_sub2 --run
./scripts/verify_ir.sh test/test_case/functional/29_break.sy /tmp/ir_break --run
./scripts/verify_ir.sh test/test_case/functional/36_op_priority2.sy /tmp/ir_op --run
./scripts/verify_ir.sh test/test_case/functional/05_arr_defn4.sy /tmp/ir_arr --run
```
All of these samples pass.

In addition:
```bash
./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy
```
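This command generates IR successfully, and the IR passes `llc`. The paths for floating-point constants, expressions, comparisons, type conversions, and array argument passing therefore basically work.

## 7. Lessons Learned

The three core lessons from Lab2:

1. Compile-time constant evaluation and runtime IR generation must be strictly separated.
   Whenever the two paths are mixed, global initializers and constant expressions go wrong.
2. Arrays cannot be treated as "just bigger scalars".
   Array objects, array parameters, array element addresses, and decayed array pointers must be clearly distinguished.
3. "Prints IR" does not mean "LLVM accepts the IR".
   Always finish with a pass through `llc`/`clang`; otherwise many type and format problems stay hidden.

## 8. Future Work

Although the main Lab2 work is done, several things can still be improved:

- Fill in a real runtime library in `sylib` so that I/O samples such as `95_float` can run end to end.
- Add a full constant-aggregate representation for global array initializers instead of the current, mostly scalar initialization.
- Further unify the internal representation of booleans in the IR to reduce the `i1/i32` compatibility branches.
- Keep running batch regressions over more samples under `test/test_case` to cover the remaining edge cases.

## 9. Conclusion

Lab2 moved from "about 90% done but blocked by global initialization and array/short-circuit problems" to "the core IR generation pipeline works and representative functional samples can be run and verified". The main problems blocking acceptance were located and fixed, the code structure is clearer than before, and later work on the runtime library, optimization, and larger regressions starts from a more stable base.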

119
doc/Lab3-实验记录.md Normal file

@@ -0,0 +1,119 @@
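# Lab3 Lab Report: Instruction Selection and Assembly Generation

## 1. Goal

The goal of Lab3 is to build on the existing SysY front end and IR generation and complete the AArch64 back end: instruction selection, control-flow translation, global variables, and the runtime library interface. The compiler should then translate SysY IR into assembly that runs on AArch64 (ARM64), with correctness validated under the QEMU emulator.

The work focused on:

- Extending the physical registers, operand kinds, and machine instruction set in MIR to cover a core AArch64 subset.
- Extending instruction selection (`Lowering.cpp`) to support multiple functions, multiple basic blocks, function calls, floating point, and address computation for multi-dimensional array GEPs.
- Handling parameter passing under the AArch64 calling convention (ABI), where the first 8 integer and the first 8 floating-point arguments travel in registers, together with the stack-frame details.
- Adding a fallback that moves the offset through a spare physical register when a stack-slot offset exceeds the `ldur`/`stur` range.
- Filling in all I/O, timing, and hexadecimal floating-point input/output functions in the SysY runtime library (`sylib/sylib.c`).

## 2. Scope of Code Changes

The following files were modified or added:

- `include/mir/MIR.h`, `src/mir/MIRFunction.cpp`, `src/mir/MIRInstr.cpp`, `src/mir/Register.cpp`, `src/mir/RegAlloc.cpp`, `src/mir/FrameLowering.cpp`
- `src/mir/Lowering.cpp` (core instruction selection)
- `src/mir/AsmPrinter.cpp` (core assembly text printing)
- `sylib/sylib.c` (SysY runtime library)
- `scripts/verify_asm.sh` (automated build-and-link script)
- `src/main.cpp` (back-end adaptation to multi-function assembly output)
- `src/irgen/IRGenExp.cpp` (fix for a front-end constant type-conversion defect)
- This document, `doc/Lab3-实验记录.md`, is new.

## 3. Process

### 3.1 Reviewing the back-end structure and locating its limits

Reading the lab handout `doc/Lab3-指令选择与汇编生成.md` against the code showed that the existing back end was a minimal demo:

- It supported only a single function `main` with a single basic block.
- It supported only five instructions: `alloca`, `load`, `store`, `add`, `ret`.
- Stack-frame offsets and addressing were hard-coded to `ldur`/`stur`, with no consideration for multi-dimensional arrays, floating point, or offsets outside the `[-256, 255]` addressing range, which make assembly fail.

### 3.2 Fixing a prerequisite type-conversion bug

While regression-testing `95_float.sy`, we found that the front end did not truncate the initializer of a `const int` constant when that initializer was a `float`. As a result, the compile-time evaluation of `const int FIVE = TWO + THREE` (where `TWO = 2.9, THREE = 3.2`) computed `2.9 + 3.2 = 6.1` and then converted down to `6`. The correct behavior is to convert `TWO` to `2` and `THREE` to `3` first, so the sum is `5`.

We added the truncation in `ConstExprVisitor::visitLValueExp` in `IRGenExp.cpp`, which fixes the wrong constant values caused by this implicit conversion.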
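
The difference between the two evaluation orders can be reproduced in a few lines. `DefineConstInt` is a hypothetical stand-in for "convert the initializer to the declared type when the constant is defined"; it is not the project's function.

```cpp
#include <cassert>

// Hypothetical helper: a `const int` keeps the truncated initializer,
// so later references see an int, not the original float.
int DefineConstInt(float init) { return static_cast<int>(init); }

// Correct order: truncate each constant at its definition, then add.
int SumOfTruncated(float a, float b) {
  return DefineConstInt(a) + DefineConstInt(b);
}

// Buggy order: add the float initializers first, truncate only at the end.
int TruncatedSum(float a, float b) { return static_cast<int>(a + b); }
```

With `2.9f` and `3.2f`, the first function returns `5` and the second returns `6`, which is exactly the discrepancy observed for `FIVE`.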
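
### 3.3 Extending AArch64 instructions and building the stack-slot model

We kept and refined the back end's conservative "stack-slot model":

1. Every `Value` produced in the IR (temporary virtual registers and instructions alike) is assigned its own 64-bit (or 32-bit) stack slot (`FrameIndex`) in `LowerToMIR`.
2. When lowering an instruction, its operands are first loaded from their stack slots into AArch64 scratch registers (`w8`/`w9`, `s8`/`s9`, and so on); after the operation the result is written back to its stack slot.
3. The model introduces redundant memory accesses (which Lab5 register allocation and peephole optimization can remove), but at this stage it **keeps every value available throughout its live range** and rules out register conflicts.

---

## 4. Key Difficulties and Solutions

### 4.1 Difficulty 1: pointer invalidation (segfault caused by reallocation of the BasicBlock vector)

#### Symptom

When compiling cases with complex control flow (such as `29_break.sy`), the back end frequently hit a segmentation fault.

We traced this to `LowerToMIR`: basic blocks are added dynamically to `std::vector<MachineBasicBlock> blocks_` via `machine_func->CreateBlock(bbPtr->GetName())`. As the vector grows, its storage is reallocated, so every `MachineBasicBlock` pointer previously recorded in `std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map` becomes a dangling pointer, and the next use crashes.

#### Solution

Before the block-creation loop, call `machine_func->GetBlocks().reserve(func.GetBlocks().size())` so that the vector already has enough capacity. No reallocation happens during the loop, so the recorded pointers stay valid.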
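
A reduced model of the fix is sketched below. `MachineBlock`, `MachineFunc`, and `MapBlocks` are simplified, hypothetical stand-ins for the project's classes; the point is only that `reserve` is called before any element address is recorded.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct MachineBlock {
  std::string name;
};

class MachineFunc {
 public:
  void Reserve(std::size_t n) { blocks_.reserve(n); }
  MachineBlock* CreateBlock(const std::string& name) {
    blocks_.push_back(MachineBlock{name});
    return &blocks_.back();  // only stable while no reallocation occurs
  }
  const std::vector<MachineBlock>& Blocks() const { return blocks_; }

 private:
  std::vector<MachineBlock> blocks_;
};

// Creates one machine block per name and records its address.
std::unordered_map<std::string, MachineBlock*> MapBlocks(
    MachineFunc& mf, const std::vector<std::string>& names) {
  mf.Reserve(names.size());  // capacity is fixed before pointers are taken
  std::unordered_map<std::string, MachineBlock*> map;
  for (const std::string& n : names) map[n] = mf.CreateBlock(n);
  return map;
}
```

The guarantee only covers the reserved capacity: a later `CreateBlock` beyond it would invalidate the map again. Storing blocks as `std::vector<std::unique_ptr<MachineBlock>>` or in a `std::deque` would keep element addresses stable without that restriction.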
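
### 4.2 Difficulty 2: stack-slot offsets beyond the AArch64 immediate range

#### Symptom

In `25_scope3.sy` and `95_float.sy`, functions have many temporaries and the stack frame easily exceeds 256 bytes. The unscaled 9-bit signed offset of AArch64 `ldur`/`stur` is limited to `[-256, 255]`. Once a computed frame offset falls outside that range, for example `-268`, the assembler (`as`) rejects the file with `immediate offset out of range`.

#### Solution

`PrintStackAccess` in `AsmPrinter.cpp` now checks the offset range when emitting the access:

- If the offset lies within `[-256, 255]`, emit the lightweight `ldur`/`stur` as before.
- Otherwise, first emit `mov x10, #offset` to load the offset into the spare 64-bit register `x10`, then use the register-offset form `ldr reg, [x29, x10]` or `str reg, [x29, x10]`, which avoids the immediate range limit.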
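
The selection logic can be sketched as a function that returns the assembly text for a load. `EmitStackLoad` is a hypothetical name; the register choices `x29` and `x10` follow the description above.

```cpp
#include <cassert>
#include <string>

// Returns the assembly for loading `reg` from [x29 + offset].
std::string EmitStackLoad(const std::string& reg, int offset) {
  if (offset >= -256 && offset <= 255) {
    // Fits the unscaled signed 9-bit immediate of ldur.
    return "ldur " + reg + ", [x29, #" + std::to_string(offset) + "]";
  }
  // Out of range: materialize the offset in x10, then use register offset.
  return "mov x10, #" + std::to_string(offset) + "\n" +
         "ldr " + reg + ", [x29, x10]";
}
```

One caveat with this scheme: a single `mov` with an immediate can only encode values that fit a shifted 16-bit, inverted 16-bit, or bitmask immediate. Frame offsets of a few kilobytes are fine, but arbitrarily large offsets would need a `movz`/`movk` sequence.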
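
### 4.3 Difficulty 3: precision loss when printing floating-point constants and globals

#### Symptom

`95_float.sy` compares floating-point values for exact equality. If a global float is printed as `.float 3.14159`, the default 6-digit precision of C++ `ostream` drops low-order bits, and the hexadecimal floating-point I/O checks fail.

#### Solution

All global and local floating-point constants are now emitted as bit-exact literals. For a float `val`, its 32-bit pattern is obtained with `memcpy` and written to the assembly unchanged as a `.word <bits>` directive. The value therefore stays **bit-identical** through compilation, assembly, and execution.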
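
A minimal sketch of the `memcpy` approach, assuming a 32-bit IEEE-754 `float`; `FloatBits` and `EmitFloatWord` are hypothetical helper names.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>

// Returns the raw 32-bit pattern of a float without aliasing violations.
std::uint32_t FloatBits(float v) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "32-bit float expected");
  std::uint32_t bits = 0;
  std::memcpy(&bits, &v, sizeof(bits));
  return bits;
}

// Emits the constant as a 4-byte data directive instead of a decimal literal.
std::string EmitFloatWord(float v) {
  char buf[32];
  std::snprintf(buf, sizeof(buf), ".word 0x%08x",
                static_cast<unsigned>(FloatBits(v)));
  return buf;
}
```

Because the directive carries the bit pattern itself, no decimal round trip is involved and the assembler cannot change the value.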
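
### 4.4 Difficulty 4: missing SysY library functions and hexadecimal float support

#### Symptom

Because `sylib/sylib.c` in the original repository was an empty shell, test cases that call the I/O runtime failed to link. The evaluation also requires floating-point output in hexadecimal format (`%a`).

#### Solution

1. Rewrote `sylib/sylib.c` in C, providing implementations of `getint`, `getch`, `getfloat`, `getarray`, `getfarray`, `putint`, `putch`, `putfloat`, `putarray`, `putfarray`, `starttime`, and `stoptime`.
2. Adapted `putfloat` and `putfarray` to the `%a` hexadecimal floating-point format, and read values at `double` precision to avoid mantissa rounding differences between single and double precision.
3. Modified `verify_asm.sh` to compile and link `sylib/sylib.c` automatically when building the executable.

---

## 5. Capabilities Implemented

After this stage the back end supports:

- **AArch64 instruction coverage**: arithmetic (`add`, `sub`, `mul`, `sdiv`, `msub`), comparison (`cmp`, `fcmp`), conditional set (`cset`), branches (`b`, `b.cond`), calls (`bl`), memory access (`ldr`, `str`, `ldur`, `stur`), and floating-point conversion (`scvtf`, `fcvtzs`).
- **ABI calling convention**: the first 8 integer/pointer arguments and the first 8 floating-point arguments are passed in registers; results are returned in `w0`/`x0`/`s0`.
- **Multiple functions and blocks**: lowering of control-flow graphs (CFGs) with any number of defined functions and basic blocks.
- **Exact floating point**: bit-exact generation of float constants and bit-exact initialization of global variables.
- **Large stack frames**: addressing beyond the AArch64 immediate-offset range, so functions with large frames compile.

## 6. Validation Results

We ran the assemble-and-execute regression on every case under `test/test_case/functional`. All cases generated AArch64 assembly, linked against the runtime library, and produced output and exit codes that match the expected files (`.out`). In the log below, `输出匹配` means "output matches".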
```bash
=== Running test/test_case/functional/05_arr_defn4.sy ===
输出匹配: test/test_case/functional/05_arr_defn4.out
=== Running test/test_case/functional/09_func_defn.sy ===
输出匹配: test/test_case/functional/09_func_defn.out
=== Running test/test_case/functional/11_add2.sy ===
输出匹配: test/test_case/functional/11_add2.out
=== Running test/test_case/functional/13_sub2.sy ===
输出匹配: test/test_case/functional/13_sub2.out
=== Running test/test_case/functional/15_graph_coloring.sy ===
输出匹配: test/test_case/functional/15_graph_coloring.out
=== Running test/test_case/functional/22_matrix_multiply.sy ===
输出匹配: test/test_case/functional/22_matrix_multiply.out
=== Running test/test_case/functional/25_scope3.sy ===
输出匹配: test/test_case/functional/25_scope3.out
=== Running test/test_case/functional/29_break.sy ===
输出匹配: test/test_case/functional/29_break.out
=== Running test/test_case/functional/36_op_priority2.sy ===
输出匹配: test/test_case/functional/36_op_priority2.out
=== Running test/test_case/functional/95_float.sy ===
输出匹配: test/test_case/functional/95_float.out
=== Running test/test_case/functional/simple_add.sy ===
输出匹配: test/test_case/functional/simple_add.out
```
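## 7. Conclusion

Lab3 turned a toy back end into an AArch64 code generator that supports multiple functions, multiple basic blocks, multi-dimensional arrays, and floating-point arithmetic. The out-of-range offset and precision problems that blocked the pipeline are resolved, which gives Labs 4-6 (scalar optimization, register allocation, and loop analysis) a solid back-end foundation.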

150
doc/Lab4-实验记录.md Normal file

@@ -0,0 +1,150 @@
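# Lab4 Lab Report: Basic Scalar Optimizations

## 1. Goal

The goal of Lab4 is to build IR-level scalar optimization passes on top of the Lab3 assembly generation. The generated intermediate representation (SysY IR) must be converted to static single assignment (SSA) form by promoting memory variables to SSA registers (Mem2Reg). A series of classic scalar optimizations then runs on the SSA form, and the back end must correctly lower the SSA-form IR (Phi nodes in particular) to AArch64 assembly.

The work focused on:

- **Dominator tree analysis** (`DominatorTree.cpp`): the iterative Cooper-Harvey-Kennedy algorithm, producing immediate dominators (IDom) and dominance frontiers.
- **Mem2Reg promotion** (`Mem2Reg.cpp`): promotes local scalar allocas, inserts Phi nodes at join points, and renames variables, converting non-SSA IR to SSA form.
- **Constant folding and propagation** (`ConstFold.cpp` & `ConstProp.cpp`): folding and algebraic simplification of arithmetic, comparison, logical, and type-conversion instructions.
- **Common subexpression elimination** (`CSE.cpp`): local, per-block CSE.
- **Dead code elimination** (`DCE.cpp`): a mark-and-sweep algorithm based on liveness propagation that removes instructions with no side effects and no uses.
- **CFG simplification** (`CFGSimplify.cpp`): iteratively merges single-predecessor/single-successor blocks and removes unreachable code.
- **SSA back-end support and Phi lowering** (`Lowering.cpp`): handles Phi nodes in the stack-slot back end by emitting copy-stores at the end of the predecessor blocks and preallocating Phi slots at function entry, so the lowering to AArch64 stays correct.
- **Fixes for pointer truncation, out-of-bounds GEP on parameters, and stale Phi entries after branch folding**, all subtle back-end defects, after which every test case passes.

---

## 2. Scope of Code Changes

The following files were modified or added:

- `include/ir/IR.h` & `src/ir/Instruction.cpp` & `src/ir/IRBuilder.cpp` (support for `Opcode::Phi`)
- `src/ir/IRPrinter.cpp` (printing of Phi nodes)
- `include/ir/PassManager.h` & `src/ir/passes/PassManager.cpp` (central configuration and management of optimization passes)
- `src/ir/analysis/DominatorTree.cpp` (new: dominator tree analysis)
- `src/ir/passes/Mem2Reg.cpp` (new: Mem2Reg scalar promotion)
- `src/ir/passes/ConstFold.cpp` (new: constant folding)
- `src/ir/passes/ConstProp.cpp` (new: constant propagation and conditional-branch simplification)
- `src/ir/passes/CSE.cpp` (new: common subexpression elimination)
- `src/ir/passes/DCE.cpp` (new: dead code elimination)
- `src/ir/passes/CFGSimplify.cpp` (new: CFG simplification)
- `src/mir/Lowering.cpp` (Phi lowering, pointer-typed loads, parameter GEP fix, Phi stack-slot allocation)
- `src/main.cpp` (hooks the IR optimization driver into the compiler entry point)
- This document, `doc/Lab4-实验记录.md`, is new.

---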
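
## 3. Key Difficulties and Solutions

### 3.1 Difficulty 1: pointer-size truncation (broken local pointer loads and segfaults)

#### Symptom

After promoting the IR to SSA, GEP and Load/Store addressing broke. When loading variables of pointer type (`PtrInt32`, `PtrFloat`), the back end only checked whether the value was a float and treated everything else as a 32-bit integer, loading it through `W8`. A 64-bit pointer value was therefore truncated to 32 bits, losing its high bits, and accesses through it hit invalid memory and segfaulted.

#### Solution

In `Lowering.cpp` we corrected register selection for Load and Store: when the value loaded or stored satisfies `IsPtrInt32()` or `IsPtrFloat()`, the 64-bit physical register `X8` is used instead of the 32-bit `W8`. The high address bits are preserved and the truncation is gone.

### 3.2 Difficulty 2: parameter pointers treated as local arrays in GEP

#### Symptom

In `15_graph_coloring.sy`, a function receives the array `int color[]` as a parameter and uses `color[i]` in its body. In the IR this is a GEP on a parameter pointer. The old back end treated every AllocaInst as a local array and used `EmitAddressToReg`, which yields the address of the stack slot holding the pointer (a pointer to the pointer) rather than the pointer value itself.

#### Solution

In `case ir::Opcode::GEP` in `Lowering.cpp`, an AllocaInst is now classified by type:

- If the AllocaInst has an array type (`IsArray()`), it is a local array, and `EmitAddressToReg` still provides the base address.
- If the AllocaInst has a scalar pointer type (such as `PtrInt32`), the slot stores a pointer passed in as a function argument, and `EmitValueToReg` loads that pointer value from the stack slot.

With this change, pointers passed across functions and GEP-based memory accesses resolve to the correct addresses.

### 3.3 Difficulty 3: inconsistent Phi nodes after branch folding (ConstProp)

#### Symptom

For the statement `if (0 || 0.3) ok();` in the `95_float.sy` regression, lowering the logical OR produces a Phi node that merges the values from its predecessors. When constant propagation (`ConstProp`) folded the conditional branch `br i1 0` into an unconditional jump to the target opposite `%dead_target`, it did not clean up the corresponding incoming edge in the Phi nodes of `%dead_target`.

The Phi node thus kept a stale entry for a removed predecessor. When later CFG simplification merged blocks, it mistook the leftover `0` for the only incoming value and substituted it, so the logical OR evaluated incorrectly and one `ok` was missing from the output.

#### Solution

When `ConstProp.cpp` folds a conditional branch, it identifies the block `dead_target` that is no longer reached from the current block. It walks the instructions of `dead_target` and, for each Phi node (`Opcode::Phi`), explicitly calls `phi->RemoveIncomingBlock(bb)` to drop the reference to the current block, which keeps the SSA form consistent.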
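
### 3.4 Difficulty 4: overflow of a 4-byte stack slot allocated for a parameter

#### Symptom

On AArch64 pointers are 64 bits wide. However, the alloca that the front end generates for a parameter such as `int color[]` has type `PtrInt32`, because there is no pointer-to-pointer type. When computing slot sizes, `GetAllocaSize` saw the type `PtrInt32` and returned 4 bytes, as for a 32-bit scalar.

On function entry, however, the back end saves the register argument by writing an 8-byte pointer through the 64-bit `X8`. The write overran the slot and corrupted the neighboring stack slot, which in the recursive graph coloring of `15_graph_coloring.sy` led to wild pointer dereferences and segfaults.

#### Solution

`GetAllocaSize` in `Lowering.cpp` now performs a static scan of stores: if the AllocaInst has type `PtrInt32` or `PtrFloat`, all Store instructions in its function are examined. If any Store writes a pointer-typed value (`IsPtrInt32() || IsPtrFloat()`) into that AllocaInst, its slot size is raised to 8 bytes. This gives 64-bit pointer parameters enough room in alloca variables that would otherwise be sized as 32-bit.

---

## 4. Optimization Pass Implementation Details

### 4.1 Dominator Tree & Mem2Reg

- **Iterative IDom computation**: uses the `Intersect` algorithm of Cooper et al., repeatedly updating immediate dominators in reverse postorder over the CFG until convergence, then computes dominance frontiers.
- **Phi insertion**: for each variable, the dominance-frontier blocks of the blocks that define it are added to the Phi-insertion worklist, deduplicated with `std::unordered_set`.
- **Variable renaming**: a DFS over the dominator tree keeps the current SSA version of each variable on a stack, fills in the matching operands of Phi nodes in successor blocks, and rolls the stack back when leaving a subtree.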
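
The iterative IDom computation can be sketched on a CFG whose blocks are identified by their postorder numbers. This is an illustration under assumptions, not the project's `DominatorTree` class: `ComputeIdom` and `kUndef` are hypothetical names, every block is assumed reachable from the entry, and the entry is assumed to carry the largest postorder number.

```cpp
#include <cassert>
#include <vector>

constexpr int kUndef = -1;

// Walks both candidates up the current dominator tree until they meet.
// A larger postorder number means closer to the entry.
static int Intersect(const std::vector<int>& idom, int a, int b) {
  while (a != b) {
    while (a < b) a = idom[a];
    while (b < a) b = idom[b];
  }
  return a;
}

// preds[b] lists the predecessors of block b (all in postorder numbers).
// Returns idom[b] for every block; the entry is its own idom.
std::vector<int> ComputeIdom(const std::vector<std::vector<int>>& preds) {
  const int n = static_cast<int>(preds.size());
  if (n == 0) return {};
  const int entry = n - 1;
  std::vector<int> idom(n, kUndef);
  idom[entry] = entry;
  bool changed = true;
  while (changed) {
    changed = false;
    for (int b = entry - 1; b >= 0; --b) {  // reverse postorder, entry skipped
      int new_idom = kUndef;
      for (int p : preds[b]) {
        if (idom[p] == kUndef) continue;  // predecessor not processed yet
        new_idom = (new_idom == kUndef) ? p : Intersect(idom, p, new_idom);
      }
      if (new_idom != idom[b]) {
        idom[b] = new_idom;
        changed = true;
      }
    }
  }
  return idom;
}
```

For a diamond CFG (entry 3, branches 2 and 1, join 0), the join's two predecessors intersect at the entry, so every block ends up with the entry as its immediate dominator.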
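
### 4.2 Constant Folding & Propagation

- Statically evaluates type-conversion instructions such as `ZExt`, `SIToFP`, and `FPToSI` on constants.
- Folds integer and floating-point binary operations as well as comparisons.
- Simplifies conditional branches: when the condition of a `br i1` is proven to be the constant `0` or `1`, it is replaced by an unconditional `br`.

### 4.3 CSE, DCE & CFGSimplify

- **CSE**: a local scan within each block compares instructions structurally (same opcode and operands) and replaces a repeated computation with the result of its first occurrence.
- **DCE**: a mark-and-sweep strategy starts from instructions with side effects (such as `Ret`, `Br`, `Store`, `Call`), propagates liveness backward through operands, and removes every instruction not marked live.
- **CFGSimplify**: merges single-predecessor/single-successor blocks by appending the successor's instructions to the predecessor, replaces uses of each Phi node with its single incoming value, and removes dead blocks.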
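
The mark phase of DCE can be sketched with a worklist over a flat instruction list. `Inst` and `MarkLive` are hypothetical names, and operands are represented as indices of their defining instructions rather than `Value*` pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical flat instruction model.
struct Inst {
  bool has_side_effect;       // e.g. Ret, Br, Store, Call
  std::vector<int> operands;  // indices of the defining instructions
};

// Mark phase: roots are the side-effecting instructions; liveness then
// flows backward to everything they (transitively) use.
std::vector<bool> MarkLive(const std::vector<Inst>& insts) {
  std::vector<bool> live(insts.size(), false);
  std::vector<int> worklist;
  for (std::size_t i = 0; i < insts.size(); ++i) {
    if (insts[i].has_side_effect) {
      live[i] = true;
      worklist.push_back(static_cast<int>(i));
    }
  }
  while (!worklist.empty()) {
    const int cur = worklist.back();
    worklist.pop_back();
    for (int op : insts[cur].operands) {
      if (!live[op]) {
        live[op] = true;
        worklist.push_back(op);
      }
    }
  }
  return live;  // the sweep phase erases every instruction left unmarked
}
```

Compared with repeatedly deleting instructions that have no uses, the worklist form also removes dead cycles, such as Phi nodes in a loop that only feed each other.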
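
---

## 5. Validation Results

We ran the assemble-and-execute regression on every case under `test/test_case/functional` **with optimization enabled**. Every case produced SSA-optimized IR and assembly, linked against the runtime library, and matched the expected files (`.out`) in both output and exit code. In the log below, `退出码` means "exit code" and `输出匹配` means "output matches".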
```bash
=== test/test_case/functional/05_arr_defn4.sy ===
退出码: 21
输出匹配: test/test_case/functional/05_arr_defn4.out
=== test/test_case/functional/09_func_defn.sy ===
退出码: 9
输出匹配: test/test_case/functional/09_func_defn.out
=== test/test_case/functional/11_add2.sy ===
退出码: 9
输出匹配: test/test_case/functional/11_add2.out
=== test/test_case/functional/13_sub2.sy ===
退出码: 248
输出匹配: test/test_case/functional/13_sub2.out
=== test/test_case/functional/15_graph_coloring.sy ===
1 2 3 2
退出码: 0
输出匹配: test/test_case/functional/15_graph_coloring.out
=== test/test_case/functional/22_matrix_multiply.sy ===
110 70 30
278 174 70
446 278 110
614 382 150
退出码: 0
输出匹配: test/test_case/functional/22_matrix_multiply.out
=== test/test_case/functional/25_scope3.sy ===
a
退出码: 46
输出匹配: test/test_case/functional/25_scope3.out
=== test/test_case/functional/29_break.sy ===
退出码: 201
输出匹配: test/test_case/functional/29_break.out
=== test/test_case/functional/36_op_priority2.sy ===
退出码: 24
输出匹配: test/test_case/functional/36_op_priority2.out
=== test/test_case/functional/95_float.sy ===
ok
... (全部ok)
退出码: 0
输出匹配: test/test_case/functional/95_float.out
=== test/test_case/functional/simple_add.sy ===
退出码: 3
输出匹配: test/test_case/functional/simple_add.out
```
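## 6. Conclusion

Lab4 built the SSA-based middle-end optimization core of the compiler. Mem2Reg, ConstProp, ConstFold, CSE, DCE, and CFGSimplify together take the IR from memory variables to optimized scalar SSA values. Along the way, the fixes for parameter typing in GEP, pointer truncation, Phi cleanup after branch folding, and stack-slot overflow keep the lowering from front-end IR to AArch64 correct on all functional test cases and prepare the ground for register allocation in Lab5.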

91
doc/Lab5-实验记录.md Normal file

@@ -0,0 +1,91 @@
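# Lab5 Lab Report: Register Allocation and Back-End Peephole Optimization

## 1. Goal

The core goal of Lab5 is to implement register allocation and back-end optimization on top of the existing IR generation and assembly generation framework.

The work focused on:

- Understanding and adapting, within the AArch64 code generation framework, the allocation of virtual registers to physical registers (linear scan or basic graph coloring).
- Implementing a back-end peephole optimization that removes redundant register moves (such as `mov w8, w8`) and redundant stack loads/stores (such as a load right after a store to the same slot).
- Handling AArch64 register aliasing (W and X registers) and the boundary between floating-point and general-purpose registers, including the side effect of loading floating-point constants.
- Validating with the full functional test suite (`verify_asm.sh`) that the generated assembly runs correctly under QEMU.

## 2. Scope of Code Changes

The following modules were touched:

- `include/mir/MIR.h`: declares the `RunPeephole` optimization entry point.
- `src/mir/passes/Peephole.cpp`: implements the peephole optimizer, including register size matching, register alias normalization, and removal of redundant stack reads and writes.
- `src/main.cpp`: inserts `RunPeephole` into the assembly generation pipeline.
- New document: `doc/Lab5-实验记录.md`.

## 3. Process

### 3.1 Locating the problem and its pain points

Before the peephole work, the compiler already emitted valid AArch64 assembly. Because register allocation and stack-slot management are conservative, the generated code was full of:

1. Redundant self-moves of a register onto itself (such as `mov w9, w9` or `mov x8, x8`).
2. In spill and reload situations, many `StoreStack` instructions immediately followed by a `LoadStack` of the same slot into the same physical register.
3. Floating-point constant loads, which on the AArch64 back end usually go through the constant pool (`adrp` + `ldr`) and temporarily occupy a general-purpose register (such as `x8`/`w8`).

If the peephole pass does not model AArch64 general-purpose register aliasing (Wn is the low 32 bits of Xn) and implicit register writes precisely, it applies wrong optimizations: floating-point comparisons get wrong assembly, which under QEMU shows up as a segmentation fault or mismatched results.

### 3.2 Design and implementation of the peephole optimization

For both performance and correctness, `src/mir/passes/Peephole.cpp` implements a single-block peephole scan that tracks data-flow context:

1. **Normalizing aliased physical registers (NormalizeReg)**
   On AArch64, `W0`-`W28` and `X0`-`X28` overlap one to one. When tracking and removing redundant load-after-store pairs, 64-bit registers must first be converted to their 32-bit aliases; otherwise alias tracking breaks whenever instructions use different widths (W vs. X).
2. **Matching register sizes (MatchRegSize)**
   When replacing a `LoadStack` with a `MovReg`, if the source register is 64-bit (such as X9) and the destination is 32-bit (such as W0), `mov w0, x9` cannot be emitted directly. `MatchRegSize` adjusts the operands to the same width, giving `mov w0, w9`, so that the GNU assembler accepts the instruction.
3. **Tracking implicit register writes**
   Instructions that implicitly read or write the temporaries `x8`/`w8` (for example a floating-point `MovImm`) are recognized. When the scan reaches one, the tracking state of the clobbered register is invalidated, which fixes the resulting register corruption.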
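
The first two helpers can be sketched on a small register model. Only the names `NormalizeReg` and `MatchRegSize` come from the report; the `Reg` struct, `Name`, and `IsSelfMove` are hypothetical, and the project presumably uses a physical-register enum instead.

```cpp
#include <cassert>
#include <string>

// Hypothetical register model: class 'w' (32-bit GPR), 'x' (64-bit GPR),
// or 's' (single-precision float), plus the register number.
struct Reg {
  char cls;
  int index;
};

static bool IsGpr(Reg r) { return r.cls == 'w' || r.cls == 'x'; }

// Maps Xn to its alias Wn so that tracking compares physical registers,
// not names. Float registers are returned unchanged.
Reg NormalizeReg(Reg r) {
  if (r.cls == 'x') r.cls = 'w';
  return r;
}

// Re-expresses `src` at the width of `dst` so that a register-to-register
// mov uses operands of the same size.
Reg MatchRegSize(Reg src, Reg dst) {
  if (IsGpr(src) && IsGpr(dst)) src.cls = dst.cls;
  return src;
}

std::string Name(Reg r) { return std::string(1, r.cls) + std::to_string(r.index); }

// A mov is a self-move candidate when both operands name the same
// physical register after normalization.
bool IsSelfMove(Reg dst, Reg src) {
  const Reg a = NormalizeReg(dst);
  const Reg b = NormalizeReg(src);
  return a.cls == b.cls && a.index == b.index;
}
```

One architectural detail worth keeping in mind: writing a W register zeroes the upper 32 bits of the corresponding X register, so `mov w8, w8` is only a true no-op when the upper half of `x8` is not needed afterwards.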
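
## 4. Key Difficulties and Solutions

### 4.1 Difficulty 1: the side effect of implicit register writes when loading floating-point constants

#### Symptom

When compiling the floating-point test `95_float.sy`, some floating-point comparisons gave wrong results. Tracing showed that a floating-point `MovImm` is eventually translated into a PC-relative load (`adrp` + `ldr`) from `rodata`. That sequence implicitly uses the general-purpose register `x8`/`w8`, which destroys the `x8`/`w8` value being tracked.

#### Solution

The write-invalidation logic in `Peephole.cpp` explicitly checks the destination register class of `MovImm`. If the destination is a floating-point register (`S0` - `S15`), all `x8`/`w8` entries in the `slot_to_reg` tracking map are erased.

#### Effect

Invalidating implicitly written registers removes the corruption caused by constant-pool loads, and the floating-point computations and comparisons in the test cases now behave correctly.

### 4.2 Difficulty 2: misjudged aliasing between W and X registers

#### Symptom

The generated assembly may refer to the same physical register by its 32-bit and its 64-bit name in turn, for example `str w8, [sp]` followed by `ldr x8, [sp]`. A plain string comparison or a comparison of physical register enum values treats these as two unrelated registers.

#### Solution

`NormalizeReg` maps every 64-bit general-purpose register `X0`-`X28` to its 32-bit alias `W0`-`W28`. All alias conflict checks and self-move elimination operate on normalized registers.

## 5. Validation Results

With the `lab5` optimization pipeline added, run: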
```bash
./scripts/verify_asm.sh test/test_case/functional/95_float.sy --run
```
Exit code: `0`; the output matches the expected output exactly.

In addition, run the regression over all functional samples:
```bash
for f in test/test_case/functional/*.sy; do
  ./scripts/verify_asm.sh "$f" --run
done
```
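The results show that **with peephole optimization enabled, every functional sample compiles to assembly, links, and runs, with exit codes and standard output matching the expected values.**

## 6. Summary and Future Work

The back-end peephole optimization substantially reduces redundant stack loads/stores and self-moves in the generated assembly, making the code more compact and more efficient.

Building on this work, Lab6 can go on to more advanced loop optimizations such as loop-invariant code motion (LICM).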


@@ -37,6 +37,7 @@
 #include <unordered_map>
 #include <utility>
 #include <vector>
+#include <cstdint>
 namespace ir {
@@ -45,6 +46,7 @@ class Value;
 class User;
 class ConstantValue;
 class ConstantInt;
+class ConstantFloat;
 class GlobalValue;
 class Instruction;
 class BasicBlock;
@@ -83,17 +85,20 @@ class Context {
   ~Context();
   // Create a deduplicated i32 constant.
   ConstantInt* GetConstInt(int v);
+  // Create a deduplicated float constant.
+  ConstantFloat* GetConstFloat(float v);
   std::string NextTemp();
 private:
   std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
+  std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
   int temp_index_ = -1;
 };
-class Type {
+class Type : public std::enable_shared_from_this<Type> {
 public:
-  enum class Kind { Void, Int32, PtrInt32 };
+  enum class Kind { Void, Int32, PtrInt32, Float, PtrFloat, Label, Array };
   explicit Type(Kind k);
   // Types are obtained as static shared objects.
   // For the same type the returned values can be compared directly, for example:
@@ -101,15 +106,36 @@ class Type {
   static const std::shared_ptr<Type>& GetVoidType();
   static const std::shared_ptr<Type>& GetInt32Type();
   static const std::shared_ptr<Type>& GetPtrInt32Type();
+  static const std::shared_ptr<Type>& GetFloatType();
+  static const std::shared_ptr<Type>& GetPtrFloatType();
+  static const std::shared_ptr<Type>& GetLabelType();
   Kind GetKind() const;
   bool IsVoid() const;
   bool IsInt32() const;
   bool IsPtrInt32() const;
+  bool IsFloat() const;
+  bool IsPtrFloat() const;
+  bool IsLabel() const;
+  bool IsArray() const;
+  std::shared_ptr<class ArrayType> GetAsArrayType();
 private:
   Kind kind_;
 };
+class ArrayType : public Type {
+public:
+  ArrayType(std::shared_ptr<Type> element_type, uint32_t num_elements);
+  static std::shared_ptr<ArrayType> Get(std::shared_ptr<Type> element_type,
+                                        uint32_t num_elements);
+  std::shared_ptr<Type> GetElementType() const { return element_type_; }
+  uint32_t GetNumElements() const { return num_elements_; }
+private:
+  std::shared_ptr<Type> element_type_;
+  uint32_t num_elements_;
+};
 class Value {
 public:
   Value(std::shared_ptr<Type> ty, std::string name);
@@ -120,10 +146,15 @@ class Value {
   bool IsVoid() const;
   bool IsInt32() const;
   bool IsPtrInt32() const;
+  bool IsFloat() const;
+  bool IsPtrFloat() const;
+  bool IsLabel() const;
   bool IsConstant() const;
   bool IsInstruction() const;
   bool IsUser() const;
   bool IsFunction() const;
+  bool IsGlobalValue() const;
+  bool IsArgument() const;
   void AddUse(User* user, size_t operand_index);
   void RemoveUse(User* user, size_t operand_index);
   const std::vector<Use>& GetUses() const;
@@ -135,6 +166,19 @@ class Value {
   std::vector<Use> uses_;
 };
+// Argument represents a function parameter.
+class Argument : public Value {
+public:
+  Argument(std::shared_ptr<Type> ty, std::string name, Function* parent,
+           unsigned arg_no);
+  Function* GetParent() const { return parent_; }
+  unsigned GetArgNo() const { return arg_no_; }
+private:
+  Function* parent_;
+  unsigned arg_no_;
+};
 // ConstantValue is the base class of the constant hierarchy.
 // Only ConstantInt is implemented for now; more constant kinds can be added later.
 class ConstantValue : public Value {
@@ -151,8 +195,50 @@ class ConstantInt : public ConstantValue {
   int value_{};
 };
+class ConstantFloat : public ConstantValue {
+public:
+  ConstantFloat(std::shared_ptr<Type> ty, float v);
+  float GetValue() const { return value_; }
+private:
+  float value_{};
+};
 // More instruction kinds still need to be added later.
-enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret };
+enum class Opcode {
+  Add,
+  Sub,
+  Mul,
+  Div,
+  Mod,
+  FAdd,
+  FSub,
+  FMul,
+  FDiv,
+  ICmpEQ,
+  ICmpNE,
+  ICmpLT,
+  ICmpGT,
+  ICmpLE,
+  ICmpGE,
+  FCmpEQ,
+  FCmpNE,
+  FCmpLT,
+  FCmpGT,
+  FCmpLE,
+  FCmpGE,
+  Alloca,
+  Load,
+  Store,
+  Ret,
+  Br,
+  Call,
+  GEP,
+  ZExt,
+  SIToFP,
+  FPToSI,
+  Phi
+};
 // User is the abstract base class of every IR object that uses other Values as inputs.
 // In the current implementation only Instruction derives from User.
@@ -162,6 +248,7 @@ class User : public Value {
   size_t GetNumOperands() const;
   Value* GetOperand(size_t index) const;
   void SetOperand(size_t index, Value* value);
+  void ClearOperands();
 protected:
   // Single entry point for operands.
@@ -171,11 +258,15 @@ class User : public Value {
   std::vector<Value*> operands_;
 };
-// GlobalValue is an empty placeholder class for the global value / global variable hierarchy.
-// Only the class hierarchy is filled in for now; initializers, printing, and linkage semantics come later.
+// GlobalValue is the class for the global value / global variable hierarchy.
 class GlobalValue : public User {
 public:
-  GlobalValue(std::shared_ptr<Type> ty, std::string name);
+  GlobalValue(std::shared_ptr<Type> ty, std::string name, ConstantValue* init = nullptr);
+  ConstantValue* GetInitializer() const { return init_; }
+  void SetInitializer(ConstantValue* init) { init_ = init; }
+private:
+  ConstantValue* init_ = nullptr;
 };
 class Instruction : public User {
@@ -196,7 +287,40 @@ class BinaryInst : public Instruction {
   BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, Value* rhs,
              std::string name);
   Value* GetLhs() const;
   Value* GetRhs() const;
+};
+class BranchInst : public Instruction {
+public:
+  // Unconditional branch
+  explicit BranchInst(BasicBlock* dest);
+  // Conditional branch
+  BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false);
+  bool IsConditional() const;
+  Value* GetCondition() const;
+  BasicBlock* GetIfTrue() const;
+  BasicBlock* GetIfFalse() const;
+  BasicBlock* GetDest() const;
+};
+class CallInst : public Instruction {
+public:
+  CallInst(Function* func, const std::vector<Value*>& args, std::string name = "");
+  Function* GetFunction() const;
+};
+class GetElementPtrInst : public Instruction {
+public:
+  GetElementPtrInst(std::shared_ptr<Type> ptr_ty, Value* ptr,
+                    const std::vector<Value*>& indices, std::string name = "");
+  Value* GetPtr() const;
+};
+class CastInst : public Instruction {
+public:
+  CastInst(Opcode op, std::shared_ptr<Type> ty, Value* val, std::string name = "");
+  Value* GetValue() const;
 };
 class ReturnInst : public Instruction {
@@ -223,6 +347,18 @@ class StoreInst : public Instruction {
   Value* GetPtr() const;
 };
+class PhiInst : public Instruction {
+public:
+  PhiInst(std::shared_ptr<Type> ty, std::string name = "");
+  void AddIncoming(Value* val, BasicBlock* bb);
+  size_t GetNumIncoming() const;
+  Value* GetIncomingValue(size_t i) const;
+  BasicBlock* GetIncomingBlock(size_t i) const;
+  void SetIncomingValue(size_t i, Value* val);
+  void SetIncomingBlock(size_t i, BasicBlock* bb);
+  void RemoveIncomingBlock(BasicBlock* bb);
+};
 // BasicBlock is part of the Value hierarchy, which eases a later move toward a fuller IR class diagram.
 // Its type is still void as a placeholder; a dedicated label type can replace it later.
 class BasicBlock : public Value {
@@ -234,6 +370,15 @@ class BasicBlock : public Value {
   const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
   const std::vector<BasicBlock*>& GetPredecessors() const;
   const std::vector<BasicBlock*>& GetSuccessors() const;
+  void AddPredecessor(BasicBlock* pred) { predecessors_.push_back(pred); }
+  void AddSuccessor(BasicBlock* succ) { successors_.push_back(succ); }
+  void ClearPredecessors() { predecessors_.clear(); }
+  void ClearSuccessors() { successors_.clear(); }
+  void EraseInstruction(Instruction* inst);
+  void InsertInstructionBefore(std::unique_ptr<Instruction> inst, Instruction* before);
+  void InsertInstructionAtBegin(std::unique_ptr<Instruction> inst);
   template <typename T, typename... Args>
   T* Append(Args&&... args) {
     if (HasTerminator()) {
@@ -255,38 +400,41 @@ class BasicBlock : public Value {
 };
 // Function also uses a minimal implementation for now.
-// Note: since the project has no separate FunctionType yet,
-// the type_ that Function inherits from Value only stores the return type
-// and cannot express the full signature (return type plus parameter list).
-// This suffices for the minimal IR that only supports int main(); once ordinary
-// functions, parameters, and calls are added, a dedicated function type is usually needed.
 class Function : public Value {
 public:
-  // The constructor also takes the return type rather than a full function type.
-  Function(std::string name, std::shared_ptr<Type> ret_type);
+  Function(std::string name, std::shared_ptr<Type> ret_type,
+           std::vector<std::shared_ptr<Type>> param_types);
   BasicBlock* CreateBlock(const std::string& name);
   BasicBlock* GetEntry();
   const BasicBlock* GetEntry() const;
   const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
+  const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
 private:
   BasicBlock* entry_ = nullptr;
   std::vector<std::unique_ptr<BasicBlock>> blocks_;
+  std::vector<std::unique_ptr<Argument>> arguments_;
 };
 class Module {
 public:
   Module() = default;
   Context& GetContext();
   const Context& GetContext() const;
-  // Function creation currently passes only the return type; a full FunctionType is not wired in yet.
   Function* CreateFunction(const std::string& name,
-                           std::shared_ptr<Type> ret_type);
+                           std::shared_ptr<Type> ret_type,
+                           std::vector<std::shared_ptr<Type>> param_types = {});
   const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
+  GlobalValue* CreateGlobalValue(const std::string& name,
+                                 std::shared_ptr<Type> ty,
+                                 ConstantValue* init = nullptr);
+  const std::vector<std::unique_ptr<GlobalValue>>& GetGlobalValues() const;
 private:
   Context context_;
   std::vector<std::unique_ptr<Function>> functions_;
+  std::vector<std::unique_ptr<GlobalValue>> global_values_;
 };
 class IRBuilder {
@@ -297,13 +445,42 @@ class IRBuilder {
   // Minimal set for constructing constants, binary operations, and return instructions.
   ConstantInt* CreateConstInt(int v);
+  ConstantFloat* CreateConstFloat(float v);
   BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
                            const std::string& name);
   BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name);
+  BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs,
+                         const std::string& name);
+  BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs,
+                         const std::string& name);
+  AllocaInst* CreateAlloca(std::shared_ptr<Type> ty, const std::string& name);
   AllocaInst* CreateAllocaI32(const std::string& name);
   LoadInst* CreateLoad(Value* ptr, const std::string& name);
   StoreInst* CreateStore(Value* val, Value* ptr);
   ReturnInst* CreateRet(Value* v);
+  BranchInst* CreateBr(BasicBlock* dest);
+  BranchInst* CreateCondBr(Value* cond, BasicBlock* if_true,
+                           BasicBlock* if_false);
+  CallInst* CreateCall(Function* func, const std::vector<Value*>& args,
+                       const std::string& name = "");
+  GetElementPtrInst* CreateGEP(std::shared_ptr<Type> ptr_ty, Value* ptr,
+                               const std::vector<Value*>& indices,
+                               const std::string& name = "");
+  CastInst* CreateZExt(Value* val, std::shared_ptr<Type> ty,
+                       const std::string& name = "");
+  CastInst* CreateSIToFP(Value* val, std::shared_ptr<Type> ty,
+                         const std::string& name = "");
+  CastInst* CreateFPToSI(Value* val, std::shared_ptr<Type> ty,
+                         const std::string& name = "");
+  PhiInst* CreatePhi(std::shared_ptr<Type> ty, const std::string& name = "");
 private:
   Context& ctx_;

47
include/ir/PassManager.h Normal file

@@ -0,0 +1,47 @@
#pragma once
#include "ir/IR.h"
#include <vector>
#include <unordered_map>
#include <unordered_set>
namespace ir {
// Dominator Tree Analysis
class DominatorTree {
public:
explicit DominatorTree(Function* func);
void Run();
// Query interfaces
BasicBlock* GetIdom(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetDominatedBlocks(BasicBlock* bb) const;
const std::vector<BasicBlock*>& GetDominanceFrontier(BasicBlock* bb) const;
bool Dominates(BasicBlock* a, BasicBlock* b) const;
private:
Function* func_;
std::vector<BasicBlock*> rpo_;
std::unordered_map<BasicBlock*, BasicBlock*> idom_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> dom_tree_;
std::unordered_map<BasicBlock*, std::vector<BasicBlock*>> df_;
void ComputeRPO();
void ComputeIdom();
void ComputeDomTree();
void ComputeDF();
};
// Individual Pass Declarations
bool RunMem2Reg(Function* func, Context& ctx);
bool RunConstProp(Function* func, Context& ctx);
bool RunConstFold(Function* func, Context& ctx);
bool RunDCE(Function* func);
bool RunCFGSimplify(Function* func);
bool RunCSE(Function* func);
// Run the optimization pipeline on a Function or Module
void RunOptimizationPasses(Module& module);
void RunFunctionOptimizationPasses(Function* func, Context& ctx);
} // namespace ir
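
The `DominatorTree` declaration above lists `ComputeRPO` and `ComputeIdom` but this header does not show how they work. The following is a minimal sketch of one common approach (the iterative Cooper-Harvey-Kennedy scheme) on a toy CFG made of integer node ids. `ToyDomInfo` and `ComputeToyIdom` are hypothetical names introduced only for illustration; the actual implementation in this commit may differ.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Toy CFG: succ[i] lists the successors of node i; node 0 is the entry block.
// Illustrative only; not the DominatorTree implementation from this commit.
struct ToyDomInfo {
  std::vector<int> rpo;   // reachable nodes in reverse post-order
  std::vector<int> idom;  // idom[n] = immediate dominator, -1 if unreachable
};

inline ToyDomInfo ComputeToyIdom(const std::vector<std::vector<int>>& succ) {
  const int n = static_cast<int>(succ.size());
  ToyDomInfo info;
  info.idom.assign(n, -1);
  if (n == 0) return info;

  // Step 1: post-order DFS from the entry, then reverse it.
  std::vector<bool> seen(n, false);
  std::vector<int> post;
  std::function<void(int)> dfs = [&](int u) {
    seen[u] = true;
    for (int v : succ[u]) {
      if (!seen[v]) dfs(v);
    }
    post.push_back(u);
  };
  dfs(0);
  info.rpo.assign(post.rbegin(), post.rend());

  // rpo_index[b] = position of b in RPO (smaller means closer to the entry).
  std::vector<int> rpo_index(n, -1);
  for (int i = 0; i < static_cast<int>(info.rpo.size()); ++i) {
    rpo_index[info.rpo[i]] = i;
  }

  // Predecessor lists, built from reachable nodes only.
  std::vector<std::vector<int>> pred(n);
  for (int u : info.rpo) {
    for (int v : succ[u]) pred[v].push_back(u);
  }

  // Step 2: walk two candidates up the partial dominator tree until they meet.
  auto intersect = [&](int a, int b) {
    while (a != b) {
      while (rpo_index[a] > rpo_index[b]) a = info.idom[a];
      while (rpo_index[b] > rpo_index[a]) b = info.idom[b];
    }
    return a;
  };

  // Step 3: iterate to a fixed point. By convention here the entry is its
  // own idom; a real API might report nullptr for the entry instead.
  info.idom[0] = 0;
  bool changed = true;
  while (changed) {
    changed = false;
    for (int b : info.rpo) {
      if (b == 0) continue;
      int new_idom = -1;
      for (int p : pred[b]) {
        if (info.idom[p] == -1) continue;  // predecessor not processed yet
        new_idom = (new_idom == -1) ? p : intersect(p, new_idom);
      }
      if (new_idom != info.idom[b]) {
        info.idom[b] = new_idom;
        changed = true;
      }
    }
  }
  return info;
}
```

For a diamond-shaped CFG (0 branches to 1 and 2, both join at 3), the join block's immediate dominator is the entry, which is also the block whose dominance frontier logic a pass such as Mem2Reg would consult when placing phi nodes.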


@@ -5,8 +5,10 @@
#include <any> #include <any>
#include <memory> #include <memory>
#include <stack>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "SysYParser.h" #include "SysYParser.h"
@@ -18,24 +20,56 @@ class Module;
class Function; class Function;
class IRBuilder; class IRBuilder;
class Value; class Value;
class BasicBlock;
} }
class IRGenImpl final : public SysYBaseVisitor { class IRGenImpl final : public SysYBaseVisitor {
public: public:
IRGenImpl(ir::Module& module, const SemanticContext& sema); IRGenImpl(ir::Module& module, const SemanticContext& sema);
// Top-level rules
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override;
// Statement rules
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override;
std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override;
std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override;
std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override;
// Expression rules
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override; std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; std::any visitNotExp(SysYParser::NotExpContext* ctx) override;
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override;
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitDivExp(SysYParser::DivExpContext* ctx) override;
std::any visitModExp(SysYParser::ModExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitSubExp(SysYParser::SubExpContext* ctx) override;
std::any visitLtExp(SysYParser::LtExpContext* ctx) override;
std::any visitLeExp(SysYParser::LeExpContext* ctx) override;
std::any visitGtExp(SysYParser::GtExpContext* ctx) override;
std::any visitGeExp(SysYParser::GeExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitNeExp(SysYParser::NeExpContext* ctx) override;
std::any visitAndExp(SysYParser::AndExpContext* ctx) override;
std::any visitOrExp(SysYParser::OrExpContext* ctx) override;
private: private:
enum class BlockFlow { enum class BlockFlow {
@@ -43,15 +77,35 @@ class IRGenImpl final : public SysYBaseVisitor {
Terminated, Terminated,
}; };
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr); ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::ConstantValue* EvalConstExpr(SysYParser::ExpContext& expr);
ir::Value* GetLValuePtr(SysYParser::LValueContext* ctx);
ir::Value* DecayArrayPtr(SysYParser::LValueContext* ctx);
bool IsArrayLikeDef(antlr4::ParserRuleContext* def) const;
size_t GetArrayRank(antlr4::ParserRuleContext* def) const;
std::shared_ptr<ir::Type> GetDefType(antlr4::ParserRuleContext* def) const;
void ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr<ir::Type> ty);
void EmitLocalInitValue(ir::Value* ptr, std::shared_ptr<ir::Type> ty,
SysYParser::InitValueContext* init);
ir::Module& module_; ir::Module& module_;
const SemanticContext& sema_; const SemanticContext& sema_;
ir::Function* func_; ir::Function* func_;
ir::IRBuilder builder_; ir::IRBuilder builder_;
// Name binding is handled by Sema; IRGen only maintains the "declaration -> storage slot" code-generation state.
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_; // Maps a definition (VarDef, ConstDef, FuncFParam) to its IR value (Alloca or GlobalValue)
std::unordered_map<antlr4::ParserRuleContext*, ir::Value*> storage_map_;
// For global scope tracking
bool is_global_scope_ = true;
// For loop control
std::stack<ir::BasicBlock*> break_stack_;
std::stack<ir::BasicBlock*> continue_stack_;
// Helper to handle short-circuiting and comparison results
ir::Value* ToI1(ir::Value* v);
ir::Value* ToI32(ir::Value* v);
}; };
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree, std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,


@@ -19,7 +19,14 @@ class MIRContext {
MIRContext& DefaultContext(); MIRContext& DefaultContext();
enum class PhysReg { W0, W8, W9, X29, X30, SP }; enum class PhysReg {
W0, W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15,
W19, W20, W21, W22, W23, W24, W25, W26, W27, W28,
X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15,
X19, X20, X21, X22, X23, X24, X25, X26, X27, X28,
S0, S1, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, S12, S13, S14, S15,
X29, X30, SP
};
const char* PhysRegName(PhysReg reg); const char* PhysRegName(PhysReg reg);
@@ -30,28 +37,57 @@ enum class Opcode {
LoadStack, LoadStack,
StoreStack, StoreStack,
AddRR, AddRR,
SubRR,
MulRR,
SDivRR,
MSubRRRR,
FAddRRR,
FSubRRR,
FMulRRR,
FDivRRR,
CmpRR,
FCmpRR,
Cset,
B,
BCond,
Call,
Ret, Ret,
MovReg,
Adrp,
AddRegImm,
LdrRegReg,
StrRegReg,
SIToFP,
FPToSI,
ZExt
}; };
class Operand { class Operand {
public: public:
enum class Kind { Reg, Imm, FrameIndex }; enum class Kind { Reg, Imm, FrameIndex, Global, Label, Cond };
static Operand Reg(PhysReg reg); static Operand Reg(PhysReg reg);
static Operand Imm(int value); static Operand Imm(int value);
static Operand FrameIndex(int index); static Operand FrameIndex(int index);
static Operand Global(std::string name);
static Operand Label(std::string name);
static Operand Cond(std::string cond);
Kind GetKind() const { return kind_; } Kind GetKind() const { return kind_; }
PhysReg GetReg() const { return reg_; } PhysReg GetReg() const { return reg_; }
int GetImm() const { return imm_; } int GetImm() const { return imm_; }
int GetFrameIndex() const { return imm_; } int GetFrameIndex() const { return imm_; }
const std::string& GetGlobalName() const { return str_; }
const std::string& GetLabelName() const { return str_; }
const std::string& GetCondCode() const { return str_; }
private: private:
Operand(Kind kind, PhysReg reg, int imm); Operand(Kind kind, PhysReg reg, int imm, std::string str = "");
Kind kind_; Kind kind_;
PhysReg reg_; PhysReg reg_;
int imm_; int imm_;
std::string str_;
}; };
class MachineInstr { class MachineInstr {
@@ -93,9 +129,12 @@ class MachineFunction {
explicit MachineFunction(std::string name); explicit MachineFunction(std::string name);
const std::string& GetName() const { return name_; } const std::string& GetName() const { return name_; }
MachineBasicBlock& GetEntry() { return entry_; }
const MachineBasicBlock& GetEntry() const { return entry_; } MachineBasicBlock& CreateBlock(std::string name);
std::vector<MachineBasicBlock>& GetBlocks() { return blocks_; }
const std::vector<MachineBasicBlock>& GetBlocks() const { return blocks_; }
// Stack/Frame management
int CreateFrameIndex(int size = 4); int CreateFrameIndex(int size = 4);
FrameSlot& GetFrameSlot(int index); FrameSlot& GetFrameSlot(int index);
const FrameSlot& GetFrameSlot(int index) const; const FrameSlot& GetFrameSlot(int index) const;
@@ -106,14 +145,16 @@ class MachineFunction {
private: private:
std::string name_; std::string name_;
MachineBasicBlock entry_; std::vector<MachineBasicBlock> blocks_;
std::vector<FrameSlot> frame_slots_; std::vector<FrameSlot> frame_slots_;
int frame_size_ = 0; int frame_size_ = 0;
}; };
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module); std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module);
void RunRegAlloc(MachineFunction& function); void RunRegAlloc(MachineFunction& function);
void RunFrameLowering(MachineFunction& function); void RunFrameLowering(MachineFunction& function);
void RunPeephole(MachineFunction& function);
void PrintAsm(const MachineFunction& function, std::ostream& os); void PrintAsm(const MachineFunction& function, std::ostream& os);
void PrintGlobals(const ir::Module& module, std::ostream& os);
} // namespace mir } // namespace mir


@@ -1,30 +1,40 @@
// Semantic checking and name binding over the syntax tree.
#pragma once #pragma once
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
class SemanticContext { class SemanticContext {
public: public:
void BindVarUse(SysYParser::VarContext* use, void BindLValue(SysYParser::LValueContext* use,
SysYParser::VarDefContext* decl) { antlr4::ParserRuleContext* def) {
var_uses_[use] = decl; lvalue_defs_[use] = def;
} }
SysYParser::VarDefContext* ResolveVarUse( void BindFuncCall(SysYParser::FuncCallExpContext* use,
const SysYParser::VarContext* use) const { SysYParser::FuncDefContext* def) {
auto it = var_uses_.find(use); funccall_defs_[use] = def;
return it == var_uses_.end() ? nullptr : it->second; }
antlr4::ParserRuleContext* ResolveLValue(
const SysYParser::LValueContext* use) const {
auto it = lvalue_defs_.find(const_cast<SysYParser::LValueContext*>(use));
return it == lvalue_defs_.end() ? nullptr : it->second;
}
SysYParser::FuncDefContext* ResolveFuncCall(
const SysYParser::FuncCallExpContext* use) const {
auto it = funccall_defs_.find(const_cast<SysYParser::FuncCallExpContext*>(use));
return it == funccall_defs_.end() ? nullptr : it->second;
} }
private: private:
std::unordered_map<const SysYParser::VarContext*, std::unordered_map<SysYParser::LValueContext*, antlr4::ParserRuleContext*>
SysYParser::VarDefContext*> lvalue_defs_;
var_uses_; std::unordered_map<SysYParser::FuncCallExpContext*,
SysYParser::FuncDefContext*>
funccall_defs_;
}; };
// Currently checks only:
// - variables are declared before use
// - local variables are not redefined
SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit);


@@ -1,17 +1,30 @@
// Minimal symbol table: records where local variables are defined.
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "SysYParser.h" #include "SysYParser.h"
struct Symbol {
enum class Kind { Variable, Constant, Function, Parameter };
Kind kind;
antlr4::ParserRuleContext* def_ctx;
bool is_const = false;
bool is_array = false;
// For functions, we can store pointers to their parameter types or just the
// FuncDefContext*
};
class SymbolTable { class SymbolTable {
public: public:
void Add(const std::string& name, SysYParser::VarDefContext* decl); SymbolTable();
bool Contains(const std::string& name) const; void PushScope();
SysYParser::VarDefContext* Lookup(const std::string& name) const; void PopScope();
bool Add(const std::string& name, const Symbol& symbol);
Symbol* Lookup(const std::string& name);
bool IsInCurrentScope(const std::string& name) const;
private: private:
std::unordered_map<std::string, SysYParser::VarDefContext*> table_; std::vector<std::unordered_map<std::string, Symbol>> scopes_;
}; };
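
The new `SymbolTable` interface replaces a single map with a stack of per-scope maps. As a rough sketch of how such a table typically behaves (innermost scope wins on lookup, redefinition is rejected only within the current scope), consider the toy version below. `ToySymbol` and `ToySymbolTable` are hypothetical stand-ins, with an integer `id` in place of the parse-tree pointer; this is not the implementation from the commit.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for the Symbol struct above; `id` replaces the parse-tree pointer.
struct ToySymbol {
  int id;
  bool is_const;
};

class ToySymbolTable {
 public:
  ToySymbolTable() { PushScope(); }  // the global scope always exists

  void PushScope() { scopes_.emplace_back(); }

  void PopScope() {
    if (scopes_.size() > 1) scopes_.pop_back();  // never pop the global scope
  }

  // Returns false if `name` is already defined in the innermost scope.
  bool Add(const std::string& name, const ToySymbol& sym) {
    return scopes_.back().emplace(name, sym).second;
  }

  // Searches innermost to outermost, so inner declarations shadow outer ones.
  ToySymbol* Lookup(const std::string& name) {
    for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
      auto found = it->find(name);
      if (found != it->end()) return &found->second;
    }
    return nullptr;
  }

  bool IsInCurrentScope(const std::string& name) const {
    return scopes_.back().count(name) != 0;
  }

 private:
  std::vector<std::unordered_map<std::string, ToySymbol>> scopes_;
};
```

A semantic pass would presumably call `PushScope` on entering a block or function body and `PopScope` on leaving it, which is what lets a local `x` shadow a global `x` without a redefinition error.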

50
scripts/run_all_tests.sh Normal file

@@ -0,0 +1,50 @@
#!/bin/bash
# Batch-test syntax parsing for all .sy files
test_dir="/home/lingli/nudt-compiler-cpp/test/test_case"
compiler="/home/lingli/nudt-compiler-cpp/build/bin/compiler"
if [ ! -f "$compiler" ]; then
echo "错误:编译器不存在,请先构建项目"
exit 1
fi
success_count=0
failed_count=0
failed_tests=()
echo "开始测试所有.sy文件的语法解析..."
echo "="
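# Collect all .sy files in sorted order; mapfile keeps paths containing spaces intact
mapfile -t test_files < <(find "$test_dir" -name "*.sy" | sort)
for test_file in "${test_files[@]}"; do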
echo "测试: $(basename "$test_file")"
# Run the parse test and capture its output
output=$("$compiler" --emit-parse-tree "$test_file" 2>&1)
exit_code=$?
if [ $exit_code -eq 0 ]; then
echo " ✓ 成功"
((success_count++))
else
echo " ✗ 失败"
echo " 错误信息: $output"
((failed_count++))
failed_tests+=("$(basename "$test_file")")
fi
done
echo "="
echo "测试完成!"
echo "总测试数: $((success_count + failed_count))"
echo "成功: $success_count"
echo "失败: $failed_count"
if [ $failed_count -gt 0 ]; then
echo "失败的测试用例:"
for test in "${failed_tests[@]}"; do
echo " - $test"
done
fi


@@ -52,7 +52,7 @@ expected_file="$input_dir/$stem.out"
"$compiler" --emit-asm "$input" > "$asm_file" "$compiler" --emit-asm "$input" > "$asm_file"
echo "汇编已生成: $asm_file" echo "汇编已生成: $asm_file"
aarch64-linux-gnu-gcc "$asm_file" -o "$exe" aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe"
echo "可执行文件已生成: $exe" echo "可执行文件已生成: $exe"
if [[ "$run_exec" == true ]]; then if [[ "$run_exec" == true ]]; then


@@ -1,8 +1,4 @@
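// SysY subset grammar: supports minimal programs of the form // SysY grammar: extended to support more SysY features
// int main() { int a = 1; int b = 2; return a + b; }
// i.e. compiling a minimal return expression.
// Further rules need to be added later as required.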
grammar SysY; grammar SysY;
@@ -10,21 +6,101 @@ grammar SysY;
/* Lexer rules */ /* Lexer rules */
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
// Keywords
INT: 'int'; INT: 'int';
FLOAT: 'float';
VOID: 'void';
CONST: 'const';
RETURN: 'return'; RETURN: 'return';
IF: 'if';
ELSE: 'else';
WHILE: 'while';
BREAK: 'break';
CONTINUE: 'continue';
// Operators
ASSIGN: '='; ASSIGN: '=';
ADD: '+'; ADD: '+';
SUB: '-';
MUL: '*';
DIV: '/';
MOD: '%';
LT: '<';
LE: '<=';
GT: '>';
GE: '>=';
EQ: '==';
NE: '!=';
// Logical operators
NOT: '!';
AND: '&&';
OR: '||';
// Brackets
LPAREN: '('; LPAREN: '(';
RPAREN: ')'; RPAREN: ')';
LBRACE: '{'; LBRACE: '{';
RBRACE: '}'; RBRACE: '}';
LBRACK: '[';
RBRACK: ']';
// Punctuation
SEMICOLON: ';'; SEMICOLON: ';';
COMMA: ',';
// Identifiers and literals
ID: [a-zA-Z_][a-zA-Z_0-9]*; ID: [a-zA-Z_][a-zA-Z_0-9]*;
ILITERAL: [0-9]+; ILITERAL
: DECIMAL_LITERAL
| OCTAL_LITERAL
| HEX_LITERAL
;
fragment DECIMAL_LITERAL
: [0-9]+
;
fragment OCTAL_LITERAL
: '0' [0-7]+
;
fragment HEX_LITERAL
: '0' ('x' | 'X') [0-9a-fA-F]+
;
// Floating-point literals
FLITERAL
: (DECIMAL_FLOAT | HEX_FLOAT)
;
fragment DECIMAL_FLOAT
: ((DIGIT+ '.' DIGIT* | '.' DIGIT+)
(('E' | 'e') ('+' | '-')? DIGIT+)?)
| ((DIGIT+ '.' DIGIT* | '.' DIGIT+ | DIGIT+)
(('E' | 'e') ('+' | '-')? DIGIT+))
| ('0' [0-7]+ '.' [0-7]*
(('E' | 'e') ('+' | '-')? DIGIT+)?)
;
fragment HEX_FLOAT
: '0' ('x' | 'X')
(HEXDIGIT* '.' HEXDIGIT+ | HEXDIGIT+ '.')
(('P' | 'p') ('+' | '-')? DIGIT+)
| '0' ('x' | 'X')
HEXDIGIT+
(('P' | 'p') ('+' | '-')? DIGIT+)
;
fragment DIGIT
: [0-9]
;
fragment HEXDIGIT
: [0-9a-fA-F]
;
// Whitespace and comments
WS: [ \t\r\n] -> skip; WS: [ \t\r\n] -> skip;
LINECOMMENT: '//' ~[\r\n]* -> skip; LINECOMMENT: '//' ~[\r\n]* -> skip;
BLOCKCOMMENT: '/*' .*? '*/' -> skip; BLOCKCOMMENT: '/*' .*? '*/' -> skip;
@@ -34,33 +110,62 @@ BLOCKCOMMENT: '/*' .*? '*/' -> skip;
/*===-------------------------------------------===*/ /*===-------------------------------------------===*/
compUnit compUnit
: funcDef EOF : (decl | funcDef)* EOF
; ;
// Declarations
decl decl
: btype varDef SEMICOLON : constDecl
| varDecl
;
constDecl
: CONST btype constDef (COMMA constDef)* SEMICOLON
;
varDecl
: btype varDef (COMMA varDef)* SEMICOLON
; ;
btype btype
: INT : INT
| FLOAT
| VOID
;
constDef
: ID (LBRACK exp RBRACK)* ASSIGN initValue
; ;
varDef varDef
: lValue (ASSIGN initValue)? : ID (LBRACK exp RBRACK)* (ASSIGN initValue)?
; ;
initValue initValue
: exp : exp
| LBRACE (initValue (COMMA initValue)*)? RBRACE
; ;
// Function definitions
funcDef funcDef
: funcType ID LPAREN RPAREN blockStmt : funcType ID LPAREN (funcFParams)? RPAREN blockStmt
; ;
funcType funcType
: INT : INT
| FLOAT
| VOID
; ;
funcFParams
: funcFParam (COMMA funcFParam)*
;
funcFParam
: btype ID (LBRACK (exp)? RBRACK)*
;
// Statements
blockStmt blockStmt
: LBRACE blockItem* RBRACE : LBRACE blockItem* RBRACE
; ;
@@ -71,28 +176,77 @@ blockItem
; ;
stmt stmt
: returnStmt : assignStmt
| returnStmt
| blockStmt
| ifStmt
| whileStmt
| breakStmt
| continueStmt
| expStmt
;
expStmt
: exp SEMICOLON
;
assignStmt
: lValue ASSIGN exp SEMICOLON
; ;
returnStmt returnStmt
: RETURN exp SEMICOLON : RETURN (exp)? SEMICOLON
;
ifStmt
: IF LPAREN exp RPAREN stmt (ELSE stmt)?
;
whileStmt
: WHILE LPAREN exp RPAREN stmt
;
breakStmt
: BREAK SEMICOLON
;
continueStmt
: CONTINUE SEMICOLON
;
// Expressions
lValue
: ID (LBRACK exp RBRACK)*
; ;
exp exp
: LPAREN exp RPAREN # parenExp : LPAREN exp RPAREN # parenExp
| var # varExp | lValue # lValueExp
| number # numberExp | number # numberExp
| exp ADD exp # additiveExp | ID LPAREN (funcRParams)? RPAREN # funcCallExp
| NOT exp # notExp
| ADD exp # unaryAddExp
| SUB exp # unarySubExp
| exp MUL exp # mulExp
| exp DIV exp # divExp
| exp MOD exp # modExp
| exp ADD exp # addExp
| exp SUB exp # subExp
| exp LT exp # ltExp
| exp LE exp # leExp
| exp GT exp # gtExp
| exp GE exp # geExp
| exp EQ exp # eqExp
| exp NE exp # neExp
| exp AND exp # andExp
| exp OR exp # orExp
; ;
var funcRParams
: ID : exp (COMMA exp)*
;
lValue
: ID
; ;
number number
: ILITERAL : ILITERAL
| FLITERAL
; ;
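
The extended lexer now accepts decimal, octal, and hexadecimal integer literals as well as decimal and hexadecimal floating-point literals, so whatever consumes `number` has to convert those spellings to values. The sketch below shows one way this could be done with the C++ standard library; `ParseToyIntLiteral` and `ParseToyFloatLiteral` are hypothetical helpers, and the commit's own conversion code is not shown in this diff.

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Converts ILITERAL text to a value. Base 0 lets std::stoll pick the radix
// from the prefix: "0x"/"0X" means hex, a leading "0" means octal, otherwise
// decimal. A 64-bit result leaves room for literals such as 2147483648 that
// only fit in 32 bits after a unary minus is applied.
inline std::int64_t ParseToyIntLiteral(const std::string& text) {
  std::size_t pos = 0;
  const long long value = std::stoll(text, &pos, 0);
  if (pos != text.size()) {
    // e.g. "09": lexed as a decimal literal but not a valid octal number.
    throw std::invalid_argument("malformed integer literal: " + text);
  }
  return value;
}

// Converts FLITERAL text to a float. std::strtof accepts both decimal forms
// ("1e2", ".5") and C99 hexadecimal forms ("0x1.8p1").
inline float ParseToyFloatLiteral(const std::string& text) {
  char* end = nullptr;
  const float value = std::strtof(text.c_str(), &end);
  if (end != text.c_str() + text.size()) {
    throw std::invalid_argument("malformed float literal: " + text);
  }
  return value;
}
```

Note that the `DECIMAL_LITERAL` fragment also matches text such as `09`, which C-style radix detection treats as a malformed octal number; the sketch rejects it rather than guessing what the author of the test program meant.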


@@ -15,7 +15,7 @@ namespace ir {
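// BasicBlock does not yet have a dedicated label type, so void is used as a placeholder type for now. // BasicBlock does not yet have a dedicated label type, so void is used as a placeholder type for now.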
BasicBlock::BasicBlock(std::string name) BasicBlock::BasicBlock(std::string name)
: Value(Type::GetVoidType(), std::move(name)) {} : Value(Type::GetLabelType(), std::move(name)) {}
Function* BasicBlock::GetParent() const { return parent_; } Function* BasicBlock::GetParent() const { return parent_; }
@@ -42,4 +42,29 @@ const std::vector<BasicBlock*>& BasicBlock::GetSuccessors() const {
return successors_; return successors_;
} }
void BasicBlock::EraseInstruction(Instruction* inst) {
for (auto it = instructions_.begin(); it != instructions_.end(); ++it) {
if (it->get() == inst) {
inst->ClearOperands();
instructions_.erase(it);
break;
}
}
}
void BasicBlock::InsertInstructionBefore(std::unique_ptr<Instruction> inst, Instruction* before) {
for (auto it = instructions_.begin(); it != instructions_.end(); ++it) {
if (it->get() == before) {
inst->SetParent(this);
instructions_.insert(it, std::move(inst));
break;
}
}
}
void BasicBlock::InsertInstructionAtBegin(std::unique_ptr<Instruction> inst) {
inst->SetParent(this);
instructions_.insert(instructions_.begin(), std::move(inst));
}
} // namespace ir } // namespace ir


@@ -3,7 +3,6 @@ add_library(ir_core STATIC
Module.cpp Module.cpp
Function.cpp Function.cpp
BasicBlock.cpp BasicBlock.cpp
GlobalValue.cpp
Type.cpp Type.cpp
Value.cpp Value.cpp
Instruction.cpp Instruction.cpp


@@ -15,10 +15,18 @@ ConstantInt* Context::GetConstInt(int v) {
return inserted->second.get(); return inserted->second.get();
} }
ConstantFloat* Context::GetConstFloat(float v) {
auto it = const_floats_.find(v);
if (it != const_floats_.end()) return it->second.get();
auto inserted =
const_floats_
.emplace(v, std::make_unique<ConstantFloat>(Type::GetFloatType(), v))
.first;
return inserted->second.get();
}
std::string Context::NextTemp() { std::string Context::NextTemp() {
std::ostringstream oss; return "t" + std::to_string(++temp_index_);
oss << "%" << ++temp_index_;
return oss.str();
} }
} // namespace ir } // namespace ir


@@ -5,9 +5,14 @@
namespace ir { namespace ir {
Function::Function(std::string name, std::shared_ptr<Type> ret_type) Function::Function(std::string name, std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types)
: Value(std::move(ret_type), std::move(name)) { : Value(std::move(ret_type), std::move(name)) {
entry_ = CreateBlock("entry"); for (size_t i = 0; i < param_types.size(); ++i) {
arguments_.push_back(std::make_unique<Argument>(
param_types[i], "a" + std::to_string(i), this,
static_cast<unsigned>(i)));
}
} }
BasicBlock* Function::CreateBlock(const std::string& name) { BasicBlock* Function::CreateBlock(const std::string& name) {
@@ -29,4 +34,8 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_; return blocks_;
} }
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
return arguments_;
}
} // namespace ir } // namespace ir


@@ -5,7 +5,7 @@
namespace ir { namespace ir {
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name) GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name, ConstantValue* init)
: User(std::move(ty), std::move(name)) {} : User(std::move(ty), std::move(name)), init_(init) {}
} // namespace ir } // namespace ir


@@ -21,6 +21,11 @@ ConstantInt* IRBuilder::CreateConstInt(int v) {
return ctx_.GetConstInt(v); return ctx_.GetConstInt(v);
} }
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
// Constants are not attached to a basic block; Context handles deduplication and lifetime.
return ctx_.GetConstFloat(v);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) { const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
@@ -42,11 +47,74 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name); return CreateBinary(Opcode::Add, lhs, rhs, name);
} }
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Sub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Mul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Div, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::Mod, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FAdd, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FSub, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FMul, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs,
const std::string& name) {
return CreateBinary(Opcode::FDiv, lhs, rhs, name);
}
BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
return insert_block_->Append<AllocaInst>(Type::GetPtrInt32Type(), name); return insert_block_->Append<BinaryInst>(op, Type::GetInt32Type(), lhs, rhs,
name);
}
BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BinaryInst>(op, Type::GetInt32Type(), lhs, rhs,
name);
}
AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(ty, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
return CreateAlloca(Type::GetPtrInt32Type(), name);
} }
LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
@@ -57,7 +125,15 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) {
throw std::runtime_error( throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr"));
} }
return insert_block_->Append<LoadInst>(Type::GetInt32Type(), ptr, name); std::shared_ptr<Type> val_ty;
if (ptr->GetType()->IsPtrInt32()) {
val_ty = Type::GetInt32Type();
} else if (ptr->GetType()->IsPtrFloat()) {
val_ty = Type::GetFloatType();
} else {
throw std::runtime_error(FormatError("ir", "LoadInst 不支持的指针类型"));
}
return insert_block_->Append<LoadInst>(val_ty, ptr, name);
} }
StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) {
@@ -79,11 +155,70 @@ ReturnInst* IRBuilder::CreateRet(Value* v) {
if (!insert_block_) { if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
} }
if (!v) {
throw std::runtime_error(
FormatError("ir", "IRBuilder::CreateRet 缺少返回值"));
}
return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v); return insert_block_->Append<ReturnInst>(Type::GetVoidType(), v);
} }
BranchInst* IRBuilder::CreateBr(BasicBlock* dest) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BranchInst>(dest);
}
BranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* if_true,
BasicBlock* if_false) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<BranchInst>(cond, if_true, if_false);
}
CallInst* IRBuilder::CreateCall(Function* func, const std::vector<Value*>& args,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CallInst>(func, args, name);
}
GetElementPtrInst* IRBuilder::CreateGEP(std::shared_ptr<Type> ptr_ty, Value* ptr,
const std::vector<Value*>& indices,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<GetElementPtrInst>(ptr_ty, ptr, indices, name);
}
CastInst* IRBuilder::CreateZExt(Value* val, std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::ZExt, ty, val, name);
}
CastInst* IRBuilder::CreateSIToFP(Value* val, std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::SIToFP, ty, val, name);
}
CastInst* IRBuilder::CreateFPToSI(Value* val, std::shared_ptr<Type> ty,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<CastInst>(Opcode::FPToSI, ty, val, name);
}
PhiInst* IRBuilder::CreatePhi(std::shared_ptr<Type> ty, const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<PhiInst>(ty, name);
}
} // namespace ir } // namespace ir


@@ -4,7 +4,11 @@
#include "ir/IR.h" #include "ir/IR.h"
#include <cstdio>
#include <iomanip>
#include <sstream>
#include <ostream> #include <ostream>
#include <cstdint>
#include <cstring>
#include <stdexcept> #include <stdexcept>
#include <string> #include <string>
@@ -12,7 +16,7 @@
namespace ir { namespace ir {
static const char* TypeToString(const Type& ty) { static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) { switch (ty.GetKind()) {
case Type::Kind::Void: case Type::Kind::Void:
return "void"; return "void";
@@ -20,11 +24,22 @@ static const char* TypeToString(const Type& ty) {
return "i32"; return "i32";
case Type::Kind::PtrInt32: case Type::Kind::PtrInt32:
return "i32*"; return "i32*";
case Type::Kind::Float:
return "float";
case Type::Kind::PtrFloat:
return "float*";
case Type::Kind::Label:
return "label";
case Type::Kind::Array: {
const auto* arr_ty = static_cast<const ArrayType*>(&ty);
return "[" + std::to_string(arr_ty->GetNumElements()) + " x " +
TypeToString(*arr_ty->GetElementType()) + "]";
}
} }
throw std::runtime_error(FormatError("ir", "未知类型")); return "unknown";
} }
static const char* OpcodeToString(Opcode op) { static std::string OpcodeToString(Opcode op) {
switch (op) { switch (op) {
case Opcode::Add: case Opcode::Add:
return "add"; return "add";
@@ -32,6 +47,42 @@ static const char* OpcodeToString(Opcode op) {
return "sub"; return "sub";
case Opcode::Mul: case Opcode::Mul:
return "mul"; return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Mod:
return "srem";
case Opcode::FAdd:
return "fadd";
case Opcode::FSub:
return "fsub";
case Opcode::FMul:
return "fmul";
case Opcode::FDiv:
return "fdiv";
case Opcode::ICmpEQ:
return "icmp eq";
case Opcode::ICmpNE:
return "icmp ne";
case Opcode::ICmpLT:
return "icmp slt";
case Opcode::ICmpGT:
return "icmp sgt";
case Opcode::ICmpLE:
return "icmp sle";
case Opcode::ICmpGE:
return "icmp sge";
case Opcode::FCmpEQ:
return "fcmp oeq";
case Opcode::FCmpNE:
return "fcmp une";
case Opcode::FCmpLT:
return "fcmp olt";
case Opcode::FCmpGT:
return "fcmp ogt";
case Opcode::FCmpLE:
return "fcmp ole";
case Opcode::FCmpGE:
return "fcmp oge";
case Opcode::Alloca: case Opcode::Alloca:
return "alloca"; return "alloca";
case Opcode::Load: case Opcode::Load:
@@ -40,21 +91,116 @@ static const char* OpcodeToString(Opcode op) {
return "store"; return "store";
case Opcode::Ret: case Opcode::Ret:
return "ret"; return "ret";
case Opcode::Br:
return "br";
case Opcode::Call:
return "call";
case Opcode::GEP:
return "getelementptr";
case Opcode::ZExt:
return "zext";
case Opcode::SIToFP:
return "sitofp";
case Opcode::FPToSI:
return "fptosi";
case Opcode::Phi:
return "phi";
} }
return "?"; return "?";
} }
static std::string ValueToString(const Value* v) { static std::string ValueToString(const Value* v) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) { if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue()); return std::to_string(ci->GetValue());
} }
return v ? v->GetName() : "<null>"; if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
const double as_double = static_cast<double>(cf->GetValue());
uint64_t bits = 0;
std::memcpy(&bits, &as_double, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << std::setw(16)
<< std::setfill('0') << bits;
return oss.str();
}
if (v->IsGlobalValue() || v->IsFunction()) {
return "@" + v->GetName();
}
if (v->IsInstruction() || v->IsArgument() || v->GetType()->IsLabel()) {
return "%" + v->GetName();
}
return v->GetName();
}
static bool IsBoolLikeValue(const Value* v) {
if (auto* inst = dynamic_cast<const Instruction*>(v)) {
switch (inst->GetOpcode()) {
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
return true;
default:
break;
}
}
return false;
}
static std::string PrintedValueType(const Value* v) {
if (IsBoolLikeValue(v)) return "i1";
return TypeToString(*v->GetType());
} }
void IRPrinter::Print(const Module& module, std::ostream& os) { void IRPrinter::Print(const Module& module, std::ostream& os) {
// Print global variables
for (const auto& gv : module.GetGlobalValues()) {
os << "@" << gv->GetName() << " = global ";
if (gv->GetType()->IsPtrInt32()) {
os << "i32";
} else if (gv->GetType()->IsPtrFloat()) {
os << "float";
} else {
os << TypeToString(*gv->GetType());
}
if (gv->GetInitializer()) {
os << " " << ValueToString(gv->GetInitializer());
} else {
os << " zeroinitializer";
}
os << "\n";
}
if (!module.GetGlobalValues().empty()) os << "\n";
for (const auto& func : module.GetFunctions()) { for (const auto& func : module.GetFunctions()) {
if (func->GetBlocks().empty()) {
os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "(";
// For declarations, we just need types. But Argument objects might exist.
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType());
if (i + 1 < args.size()) os << ", ";
}
os << ")\n\n";
continue;
}
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n"; << "(";
const auto& args = func->GetArguments();
for (size_t i = 0; i < args.size(); ++i) {
os << TypeToString(*args[i]->GetType()) << " %" << args[i]->GetName();
if (i + 1 < args.size()) os << ", ";
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) { for (const auto& bb : func->GetBlocks()) {
if (!bb) { if (!bb) {
continue; continue;
@@ -65,36 +211,152 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) { switch (inst->GetOpcode()) {
case Opcode::Add: case Opcode::Add:
case Opcode::Sub: case Opcode::Sub:
case Opcode::Mul: { case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv: {
auto* bin = static_cast<const BinaryInst*>(inst); auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = " os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " " << OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " " << PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " %" << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< PrintedValueType(bin->GetLhs()) << " "
<< ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n"; << ValueToString(bin->GetRhs()) << "\n";
break; break;
} }
case Opcode::Alloca: { case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst); auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca i32\n"; os << " %" << alloca->GetName() << " = alloca ";
if (alloca->GetType()->IsPtrInt32())
os << "i32";
else if (alloca->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*alloca->GetType());
os << "\n";
break; break;
} }
case Opcode::Load: { case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst); auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load i32, i32* " os << " %" << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n"; << ValueToString(load->GetPtr()) << "\n";
break; break;
} }
case Opcode::Store: { case Opcode::Store: {
auto* store = static_cast<const StoreInst*>(inst); auto* store = static_cast<const StoreInst*>(inst);
os << " store i32 " << ValueToString(store->GetValue()) os << " store " << TypeToString(*store->GetValue()->GetType())
<< ", i32* " << ValueToString(store->GetPtr()) << "\n"; << " " << ValueToString(store->GetValue()) << ", "
<< TypeToString(*store->GetPtr()->GetType()) << " "
<< ValueToString(store->GetPtr()) << "\n";
break; break;
} }
case Opcode::Ret: { case Opcode::Ret: {
auto* ret = static_cast<const ReturnInst*>(inst); auto* ret = static_cast<const ReturnInst*>(inst);
os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " os << " ret ";
<< ValueToString(ret->GetValue()) << "\n"; if (ret->GetValue()) {
os << TypeToString(*ret->GetValue()->GetType()) << " "
<< ValueToString(ret->GetValue());
} else {
os << "void";
}
os << "\n";
break;
}
case Opcode::Br: {
auto* br = static_cast<const BranchInst*>(inst);
if (br->IsConditional()) {
os << " br i1 " << ValueToString(br->GetCondition())
<< ", label " << ValueToString(br->GetIfTrue()) << ", label "
<< ValueToString(br->GetIfFalse()) << "\n";
} else {
os << " br label " << ValueToString(br->GetDest()) << "\n";
}
break;
}
case Opcode::Call: {
auto* call = static_cast<const CallInst*>(inst);
auto* func = call->GetFunction();
if (!call->GetType()->IsVoid()) {
os << " %" << call->GetName() << " = ";
} else {
os << " ";
}
os << "call " << TypeToString(*call->GetType()) << " "
<< ValueToString(func) << "(";
for (size_t i = 1; i < call->GetNumOperands(); ++i) {
auto* arg = call->GetOperand(i);
os << PrintedValueType(arg) << " " << ValueToString(arg);
if (i + 1 < call->GetNumOperands()) os << ", ";
}
os << ")\n";
break;
}
case Opcode::GEP: {
auto* gep = static_cast<const GetElementPtrInst*>(inst);
os << " %" << gep->GetName() << " = getelementptr ";
if (gep->GetPtr()->GetType()->IsPtrInt32())
os << "i32";
else if (gep->GetPtr()->GetType()->IsPtrFloat())
os << "float";
else
os << TypeToString(*gep->GetPtr()->GetType());
os << ", ";
if (gep->GetPtr()->GetType()->IsArray()) {
os << TypeToString(*gep->GetPtr()->GetType()) << "* ";
} else {
os << TypeToString(*gep->GetPtr()->GetType()) << " ";
}
os << ValueToString(gep->GetPtr());
for (size_t i = 1; i < gep->GetNumOperands(); ++i) {
os << ", " << TypeToString(*gep->GetOperand(i)->GetType()) << " "
<< ValueToString(gep->GetOperand(i));
}
os << "\n";
break;
}
case Opcode::ZExt:
case Opcode::SIToFP:
case Opcode::FPToSI: {
auto* cast = static_cast<const CastInst*>(inst);
os << " %" << cast->GetName() << " = "
<< OpcodeToString(cast->GetOpcode()) << " "
<< PrintedValueType(cast->GetValue()) << " "
<< ValueToString(cast->GetValue()) << " to "
<< TypeToString(*cast->GetType()) << "\n";
break;
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " %" << phi->GetName() << " = phi " << TypeToString(*phi->GetType()) << " ";
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (i > 0) os << ", ";
os << "[ " << ValueToString(phi->GetIncomingValue(i)) << ", %" << phi->GetIncomingBlock(i)->GetName() << " ]";
}
os << "\n";
break; break;
} }
} }
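
The `ConstantFloat` branch of `ValueToString` above widens the value to `double` and prints its 64-bit pattern in hexadecimal, which is the form LLVM's textual IR uses for `float` constants. A standalone sketch of that conversion, using a hypothetical helper name `FloatToLLVMHex` that does not appear in the diff:

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <string>

// Hypothetical helper: formats a float as the bit pattern of the value
// widened to double, e.g. 1.0f -> 0x3FF0000000000000.
std::string FloatToLLVMHex(float value) {
  static_assert(sizeof(double) == sizeof(std::uint64_t),
                "this sketch assumes a 64-bit double");
  const double widened = static_cast<double>(value);
  std::uint64_t bits = 0;
  std::memcpy(&bits, &widened, sizeof(bits));  // well-defined type punning
  std::ostringstream oss;
  oss << "0x" << std::hex << std::uppercase << std::setw(16)
      << std::setfill('0') << bits;
  const std::string text = oss.str();
  assert(text.size() == 18);  // "0x" plus 16 hex digits
  return text;
}
```

Because the value starts as a `float`, the widened double is always exactly representable as a float again, so the printed constant should not be rejected for losing precision.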


@@ -47,12 +47,24 @@ void User::AddOperand(Value* value) {
value->AddUse(this, operand_index); value->AddUse(this, operand_index);
} }
void User::ClearOperands() {
for (size_t i = 0; i < operands_.size(); ++i) {
auto* old = operands_[i];
if (old) {
old->RemoveUse(this, i);
}
}
operands_.clear();
}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name) Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {} : User(std::move(ty), std::move(name)), opcode_(op) {}
Opcode Instruction::GetOpcode() const { return opcode_; } Opcode Instruction::GetOpcode() const { return opcode_; }
bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } bool Instruction::IsTerminator() const {
return opcode_ == Opcode::Ret || opcode_ == Opcode::Br;
}
BasicBlock* Instruction::GetParent() const { return parent_; } BasicBlock* Instruction::GetParent() const { return parent_; }
@@ -61,22 +73,9 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs, BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name) Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) { : Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst currently supports only Add"));
}
if (!lhs || !rhs) { if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst is missing an operand")); throw std::runtime_error(FormatError("ir", "BinaryInst is missing an operand"));
} }
if (!type_ || !lhs->GetType() || !rhs->GetType()) {
throw std::runtime_error(FormatError("ir", "BinaryInst is missing type information"));
}
if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() ||
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst type mismatch"));
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst currently supports only i32"));
}
AddOperand(lhs); AddOperand(lhs);
AddOperand(rhs); AddOperand(rhs);
} }
@@ -85,38 +84,85 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); }
Value* BinaryInst::GetRhs() const { return GetOperand(1); } Value* BinaryInst::GetRhs() const { return GetOperand(1); }
ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val) BranchInst::BranchInst(BasicBlock* dest)
: Instruction(Opcode::Ret, std::move(void_ty), "") { : Instruction(Opcode::Br, Type::GetVoidType(), "") {
if (!val) { AddOperand(dest);
throw std::runtime_error(FormatError("ir", "ReturnInst is missing a return value")); }
BranchInst::BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false)
: Instruction(Opcode::Br, Type::GetVoidType(), "") {
AddOperand(cond);
AddOperand(if_true);
AddOperand(if_false);
}
bool BranchInst::IsConditional() const { return GetNumOperands() == 3; }
Value* BranchInst::GetCondition() const {
return IsConditional() ? GetOperand(0) : nullptr;
}
BasicBlock* BranchInst::GetIfTrue() const {
return IsConditional() ? static_cast<BasicBlock*>(GetOperand(1)) : nullptr;
}
BasicBlock* BranchInst::GetIfFalse() const {
return IsConditional() ? static_cast<BasicBlock*>(GetOperand(2)) : nullptr;
}
BasicBlock* BranchInst::GetDest() const {
return !IsConditional() ? static_cast<BasicBlock*>(GetOperand(0)) : nullptr;
}
CallInst::CallInst(Function* func, const std::vector<Value*>& args,
std::string name)
: Instruction(Opcode::Call, func->GetType(), std::move(name)) {
AddOperand(func);
for (auto* arg : args) {
AddOperand(arg);
} }
if (!type_ || !type_->IsVoid()) { }
throw std::runtime_error(FormatError("ir", "ReturnInst result type must be void"));
Function* CallInst::GetFunction() const {
return static_cast<Function*>(GetOperand(0));
}
GetElementPtrInst::GetElementPtrInst(std::shared_ptr<Type> ptr_ty, Value* ptr,
const std::vector<Value*>& indices,
std::string name)
: Instruction(Opcode::GEP, std::move(ptr_ty), std::move(name)) {
AddOperand(ptr);
for (auto* idx : indices) {
AddOperand(idx);
} }
}
Value* GetElementPtrInst::GetPtr() const { return GetOperand(0); }
CastInst::CastInst(Opcode op, std::shared_ptr<Type> ty, Value* val,
std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
AddOperand(val); AddOperand(val);
} }
Value* ReturnInst::GetValue() const { return GetOperand(0); } Value* CastInst::GetValue() const { return GetOperand(0); }
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name) ReturnInst::ReturnInst(std::shared_ptr<Type> void_ty, Value* val)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { : Instruction(Opcode::Ret, std::move(void_ty), "") {
if (!type_ || !type_->IsPtrInt32()) { if (val) {
throw std::runtime_error(FormatError("ir", "AllocaInst currently supports only i32*")); AddOperand(val);
} }
} }
Value* ReturnInst::GetValue() const {
return GetNumOperands() > 0 ? GetOperand(0) : nullptr;
}
AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, std::string name)
: Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {}
LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name) LoadInst::LoadInst(std::shared_ptr<Type> val_ty, Value* ptr, std::string name)
: Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) {
if (!ptr) {
throw std::runtime_error(FormatError("ir", "LoadInst is missing ptr"));
}
if (!type_ || !type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "LoadInst currently supports loading only i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "LoadInst currently supports loading only from i32*"));
}
AddOperand(ptr); AddOperand(ptr);
} }
@@ -124,22 +170,6 @@ Value* LoadInst::GetPtr() const { return GetOperand(0); }
StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr) StoreInst::StoreInst(std::shared_ptr<Type> void_ty, Value* val, Value* ptr)
: Instruction(Opcode::Store, std::move(void_ty), "") { : Instruction(Opcode::Store, std::move(void_ty), "") {
if (!val) {
throw std::runtime_error(FormatError("ir", "StoreInst is missing value"));
}
if (!ptr) {
throw std::runtime_error(FormatError("ir", "StoreInst is missing ptr"));
}
if (!type_ || !type_->IsVoid()) {
throw std::runtime_error(FormatError("ir", "StoreInst result type must be void"));
}
if (!val->GetType() || !val->GetType()->IsInt32()) {
throw std::runtime_error(FormatError("ir", "StoreInst currently supports storing only i32"));
}
if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) {
throw std::runtime_error(
FormatError("ir", "StoreInst currently supports writing only to i32*"));
}
AddOperand(val); AddOperand(val);
AddOperand(ptr); AddOperand(ptr);
} }
@@ -148,4 +178,46 @@ Value* StoreInst::GetValue() const { return GetOperand(0); }
Value* StoreInst::GetPtr() const { return GetOperand(1); } Value* StoreInst::GetPtr() const { return GetOperand(1); }
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
void PhiInst::AddIncoming(Value* val, BasicBlock* bb) {
AddOperand(val);
AddOperand(bb);
}
size_t PhiInst::GetNumIncoming() const {
return GetNumOperands() / 2;
}
Value* PhiInst::GetIncomingValue(size_t i) const {
return GetOperand(2 * i);
}
BasicBlock* PhiInst::GetIncomingBlock(size_t i) const {
return static_cast<BasicBlock*>(GetOperand(2 * i + 1));
}
void PhiInst::SetIncomingValue(size_t i, Value* val) {
SetOperand(2 * i, val);
}
void PhiInst::SetIncomingBlock(size_t i, BasicBlock* bb) {
SetOperand(2 * i + 1, bb);
}
void PhiInst::RemoveIncomingBlock(BasicBlock* bb) {
std::vector<Value*> new_ops;
for (size_t i = 0; i < GetNumIncoming(); ++i) {
if (GetIncomingBlock(i) != bb) {
new_ops.push_back(GetIncomingValue(i));
new_ops.push_back(GetIncomingBlock(i));
}
}
ClearOperands();
for (auto* op : new_ops) {
AddOperand(op);
}
}
} // namespace ir } // namespace ir
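
`PhiInst` above stores its incoming edges as a flat operand list of alternating (value, block) entries, so incoming edge `i` lives at operand indices `2 * i` and `2 * i + 1`. A minimal sketch of that layout, assuming a hypothetical `SketchPhi` that uses strings in place of `Value*` and `BasicBlock*`:

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for ir::PhiInst. Operands are stored flat as
// [value0, block0, value1, block1, ...].
struct SketchPhi {
  std::vector<std::string> operands;

  void AddIncoming(const std::string& value, const std::string& block) {
    operands.push_back(value);
    operands.push_back(block);
  }

  std::size_t NumIncoming() const { return operands.size() / 2; }

  const std::string& IncomingValue(std::size_t i) const {
    assert(i < NumIncoming());
    return operands[2 * i];
  }

  const std::string& IncomingBlock(std::size_t i) const {
    assert(i < NumIncoming());
    return operands[2 * i + 1];
  }

  // Rebuilds the operand list without the pairs that come from `block`,
  // following the same keep-and-re-add approach as RemoveIncomingBlock.
  void RemoveIncomingBlock(const std::string& block) {
    std::vector<std::string> kept;
    for (std::size_t i = 0; i < NumIncoming(); ++i) {
      if (IncomingBlock(i) != block) {
        kept.push_back(IncomingValue(i));
        kept.push_back(IncomingBlock(i));
      }
    }
    operands = std::move(kept);
  }
};
```

Rebuilding the whole list keeps the pairing invariant intact at the cost of touching every operand; the real class additionally has to update use lists, which this sketch omits.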


@@ -9,8 +9,10 @@ Context& Module::GetContext() { return context_; }
const Context& Module::GetContext() const { return context_; } const Context& Module::GetContext() const { return context_; }
Function* Module::CreateFunction(const std::string& name, Function* Module::CreateFunction(const std::string& name,
std::shared_ptr<Type> ret_type) { std::shared_ptr<Type> ret_type,
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type))); std::vector<std::shared_ptr<Type>> param_types) {
functions_.push_back(std::make_unique<Function>(name, std::move(ret_type),
std::move(param_types)));
return functions_.back().get(); return functions_.back().get();
} }
@@ -18,4 +20,15 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
return functions_; return functions_;
} }
GlobalValue* Module::CreateGlobalValue(const std::string& name,
std::shared_ptr<Type> ty,
ConstantValue* init) {
global_values_.push_back(std::make_unique<GlobalValue>(std::move(ty), name, init));
return global_values_.back().get();
}
const std::vector<std::unique_ptr<GlobalValue>>& Module::GetGlobalValues() const {
return global_values_;
}
} // namespace ir } // namespace ir


@@ -20,6 +20,21 @@ const std::shared_ptr<Type>& Type::GetPtrInt32Type() {
return type; return type;
} }
const std::shared_ptr<Type>& Type::GetFloatType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Float);
return type;
}
const std::shared_ptr<Type>& Type::GetPtrFloatType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::PtrFloat);
return type;
}
const std::shared_ptr<Type>& Type::GetLabelType() {
static const std::shared_ptr<Type> type = std::make_shared<Type>(Kind::Label);
return type;
}
Type::Kind Type::GetKind() const { return kind_; } Type::Kind Type::GetKind() const { return kind_; }
bool Type::IsVoid() const { return kind_ == Kind::Void; } bool Type::IsVoid() const { return kind_ == Kind::Void; }
@@ -28,4 +43,29 @@ bool Type::IsInt32() const { return kind_ == Kind::Int32; }
bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; }
bool Type::IsFloat() const { return kind_ == Kind::Float; }
bool Type::IsPtrFloat() const { return kind_ == Kind::PtrFloat; }
bool Type::IsLabel() const { return kind_ == Kind::Label; }
bool Type::IsArray() const { return kind_ == Kind::Array; }
std::shared_ptr<class ArrayType> Type::GetAsArrayType() {
if (IsArray()) {
return std::static_pointer_cast<ArrayType>(shared_from_this());
}
return nullptr;
}
ArrayType::ArrayType(std::shared_ptr<Type> element_type, uint32_t num_elements)
: Type(Kind::Array),
element_type_(std::move(element_type)),
num_elements_(num_elements) {}
std::shared_ptr<ArrayType> ArrayType::Get(std::shared_ptr<Type> element_type,
uint32_t num_elements) {
return std::make_shared<ArrayType>(std::move(element_type), num_elements);
}
} // namespace ir } // namespace ir
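
Since `ArrayType` holds an element type that may itself be an array, a multi-dimensional SysY array such as `int a[2][3]` becomes a nested type that prints as `[2 x [3 x i32]]`. A small sketch of that recursion, assuming hypothetical `SketchType`, `MakeScalar`, and `MakeArray` names rather than the project's real classes:

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

// Hypothetical stand-in for ir::Type / ir::ArrayType.
struct SketchType {
  std::string scalar_name;              // e.g. "i32"; unused for arrays
  std::shared_ptr<SketchType> element;  // non-null only for array types
  std::uint32_t num_elements = 0;
};

std::shared_ptr<SketchType> MakeScalar(std::string name) {
  auto ty = std::make_shared<SketchType>();
  ty->scalar_name = std::move(name);
  return ty;
}

std::shared_ptr<SketchType> MakeArray(std::shared_ptr<SketchType> element,
                                      std::uint32_t num_elements) {
  assert(element);
  auto ty = std::make_shared<SketchType>();
  ty->element = std::move(element);
  ty->num_elements = num_elements;
  return ty;
}

// Same recursive shape as the Array case of TypeToString.
std::string SketchTypeToString(const SketchType& ty) {
  if (!ty.element) return ty.scalar_name;
  return "[" + std::to_string(ty.num_elements) + " x " +
         SketchTypeToString(*ty.element) + "]";
}
```

The outermost dimension is built last, so `int a[2][3]` corresponds to `MakeArray(MakeArray(MakeScalar("i32"), 3), 2)`.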


@@ -22,6 +22,12 @@ bool Value::IsInt32() const { return type_ && type_->IsInt32(); }
bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); }
bool Value::IsFloat() const { return type_ && type_->IsFloat(); }
bool Value::IsPtrFloat() const { return type_ && type_->IsPtrFloat(); }
bool Value::IsLabel() const { return type_ && type_->IsLabel(); }
bool Value::IsConstant() const { bool Value::IsConstant() const {
return dynamic_cast<const ConstantValue*>(this) != nullptr; return dynamic_cast<const ConstantValue*>(this) != nullptr;
} }
@@ -38,6 +44,14 @@ bool Value::IsFunction() const {
return dynamic_cast<const Function*>(this) != nullptr; return dynamic_cast<const Function*>(this) != nullptr;
} }
bool Value::IsGlobalValue() const {
return dynamic_cast<const GlobalValue*>(this) != nullptr;
}
bool Value::IsArgument() const {
return dynamic_cast<const Argument*>(this) != nullptr;
}
void Value::AddUse(User* user, size_t operand_index) { void Value::AddUse(User* user, size_t operand_index) {
if (!user) return; if (!user) return;
uses_.push_back(Use(this, user, operand_index)); uses_.push_back(Use(this, user, operand_index));
@@ -74,10 +88,27 @@ void Value::ReplaceAllUsesWith(Value* new_value) {
} }
} }
Argument::Argument(std::shared_ptr<Type> ty, std::string name, Function* parent,
unsigned arg_no)
: Value(std::move(ty), std::move(name)),
parent_(parent),
arg_no_(arg_no) {}
ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name) ConstantValue::ConstantValue(std::shared_ptr<Type> ty, std::string name)
: Value(std::move(ty), std::move(name)) {} : Value(std::move(ty), std::move(name)) {}
ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v) ConstantInt::ConstantInt(std::shared_ptr<Type> ty, int v)
: ConstantValue(std::move(ty), ""), value_(v) {} : ConstantValue(std::move(ty), ""), value_(v) {}
ConstantFloat::ConstantFloat(std::shared_ptr<Type> ty, float v)
: ConstantValue(std::move(ty), ""), value_(v) {}
GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name,
ConstantValue* init)
: User(std::move(ty), std::move(name)), init_(init) {
if (init_) {
AddOperand(init_);
}
}
} // namespace ir } // namespace ir


@@ -1,4 +1,192 @@
// Dominator tree analysis: #include "ir/PassManager.h"
// - Build/query the Dominator Tree and related relations #include <algorithm>
// - Provide the foundation for mem2reg, CFG optimization, and loop analysis #include <iostream>
#include <queue>
#include <unordered_set>
namespace ir {
// Helper to rebuild CFG predecessors and successors.
void RebuildCFG(Function* func) {
for (auto& bbPtr : func->GetBlocks()) {
bbPtr->ClearPredecessors();
bbPtr->ClearSuccessors();
}
for (auto& bbPtr : func->GetBlocks()) {
auto* bb = bbPtr.get();
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* term = insts.back().get();
if (auto* br = dynamic_cast<BranchInst*>(term)) {
if (br->IsConditional()) {
auto* t = br->GetIfTrue();
auto* f = br->GetIfFalse();
if (t) {
bb->AddSuccessor(t);
t->AddPredecessor(bb);
}
if (f) {
bb->AddSuccessor(f);
f->AddPredecessor(bb);
}
} else {
auto* dest = br->GetDest();
if (dest) {
bb->AddSuccessor(dest);
dest->AddPredecessor(bb);
}
}
}
}
}
static void PostOrderDFS(BasicBlock* bb, std::unordered_set<BasicBlock*>& visited,
std::vector<BasicBlock*>& post_order) {
visited.insert(bb);
for (auto* succ : bb->GetSuccessors()) {
if (visited.find(succ) == visited.end()) {
PostOrderDFS(succ, visited, post_order);
}
}
post_order.push_back(bb);
}
DominatorTree::DominatorTree(Function* func) : func_(func) {}
void DominatorTree::Run() {
RebuildCFG(func_);
ComputeRPO();
ComputeIdom();
ComputeDomTree();
ComputeDF();
}
void DominatorTree::ComputeRPO() {
rpo_.clear();
if (func_->GetBlocks().empty()) return;
std::unordered_set<BasicBlock*> visited;
std::vector<BasicBlock*> post_order;
PostOrderDFS(func_->GetEntry(), visited, post_order);
rpo_ = std::vector<BasicBlock*>(post_order.rbegin(), post_order.rend());
}
void DominatorTree::ComputeIdom() {
idom_.clear();
if (rpo_.empty()) return;
BasicBlock* entry = rpo_.front();
idom_[entry] = entry;
std::unordered_map<BasicBlock*, int> rpo_index;
for (size_t i = 0; i < rpo_.size(); ++i) {
rpo_index[rpo_[i]] = i;
}
bool changed = true;
while (changed) {
changed = false;
for (size_t i = 1; i < rpo_.size(); ++i) {
BasicBlock* b = rpo_[i];
BasicBlock* new_idom = nullptr;
// Find first predecessor with a defined idom
for (auto* pred : b->GetPredecessors()) {
if (idom_.find(pred) != idom_.end()) {
new_idom = pred;
break;
}
}
if (new_idom) {
for (auto* pred : b->GetPredecessors()) {
if (pred != new_idom && idom_.find(pred) != idom_.end()) {
// Intersect
auto* finger1 = pred;
auto* finger2 = new_idom;
while (finger1 != finger2) {
while (rpo_index.at(finger1) > rpo_index.at(finger2)) {
finger1 = idom_.at(finger1);
}
while (rpo_index.at(finger2) > rpo_index.at(finger1)) {
finger2 = idom_.at(finger2);
}
}
new_idom = finger1;
}
}
if (idom_.find(b) == idom_.end() || idom_[b] != new_idom) {
idom_[b] = new_idom;
changed = true;
}
}
}
}
}
void DominatorTree::ComputeDomTree() {
dom_tree_.clear();
for (auto* b : rpo_) {
dom_tree_[b] = {};
}
for (auto* b : rpo_) {
if (b != rpo_.front()) {
auto* parent = idom_[b];
dom_tree_[parent].push_back(b);
}
}
}
void DominatorTree::ComputeDF() {
df_.clear();
for (auto* b : rpo_) {
df_[b] = {};
}
for (auto* b : rpo_) {
if (b->GetPredecessors().size() >= 2) {
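for (auto* pred : b->GetPredecessors()) {
// Skip predecessors without an idom entry (unreachable from the entry);
// following idom_[runner] from them would never reach idom_b.
if (idom_.find(pred) == idom_.end()) continue;
auto* runner = pred;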
auto* idom_b = idom_[b];
while (runner != idom_b) {
// If runner's df doesn't contain b already, add it
auto& runner_df = df_[runner];
if (std::find(runner_df.begin(), runner_df.end(), b) == runner_df.end()) {
runner_df.push_back(b);
}
runner = idom_[runner];
}
}
}
}
}
BasicBlock* DominatorTree::GetIdom(BasicBlock* bb) const {
auto it = idom_.find(bb);
return it != idom_.end() ? it->second : nullptr;
}
const std::vector<BasicBlock*>& DominatorTree::GetDominatedBlocks(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = dom_tree_.find(bb);
return it != dom_tree_.end() ? it->second : empty;
}
const std::vector<BasicBlock*>& DominatorTree::GetDominanceFrontier(BasicBlock* bb) const {
static const std::vector<BasicBlock*> empty;
auto it = df_.find(bb);
return it != df_.end() ? it->second : empty;
}
bool DominatorTree::Dominates(BasicBlock* a, BasicBlock* b) const {
if (a == b) return true;
if (rpo_.empty()) return false;  // Run() has not populated the tree yet
auto* runner = b;
while (runner != rpo_.front()) {
auto it = idom_.find(runner);
if (it == idom_.end()) return false;
runner = it->second;
if (runner == a) return true;
}
return false;
}
} // namespace ir
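
`ComputeIdom` above follows the iterative scheme commonly attributed to Cooper, Harvey, and Kennedy: visit blocks in reverse post-order and intersect the dominators of already-processed predecessors until nothing changes. A compact sketch on integer block ids, assuming block 0 is the entry and using a hypothetical `ComputeIdoms` function with `-1` as the "no idom" marker:

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical sketch: succs[b] lists the successors of block b.
// Returns idom per block; the entry maps to itself, unreachable blocks to -1.
std::vector<int> ComputeIdoms(const std::vector<std::vector<int>>& succs) {
  const int n = static_cast<int>(succs.size());
  assert(n > 0);
  std::vector<std::vector<int>> preds(n);
  for (int b = 0; b < n; ++b) {
    for (int s : succs[b]) preds[s].push_back(b);
  }
  // Post-order DFS from the entry, then reverse it to get RPO.
  std::vector<int> post_order;
  std::vector<bool> visited(n, false);
  std::function<void(int)> dfs = [&](int b) {
    visited[b] = true;
    for (int s : succs[b]) {
      if (!visited[s]) dfs(s);
    }
    post_order.push_back(b);
  };
  dfs(0);
  std::vector<int> rpo(post_order.rbegin(), post_order.rend());
  std::vector<int> rpo_index(n, -1);
  for (std::size_t i = 0; i < rpo.size(); ++i) {
    rpo_index[rpo[i]] = static_cast<int>(i);
  }
  std::vector<int> idom(n, -1);
  idom[0] = 0;
  // Walk both fingers up the partially built tree until they meet.
  auto intersect = [&](int a, int b) {
    while (a != b) {
      while (rpo_index[a] > rpo_index[b]) a = idom[a];
      while (rpo_index[b] > rpo_index[a]) b = idom[b];
    }
    return a;
  };
  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 1; i < rpo.size(); ++i) {
      const int b = rpo[i];
      int new_idom = -1;
      for (int p : preds[b]) {
        if (idom[p] == -1) continue;  // not processed yet, or unreachable
        new_idom = (new_idom == -1) ? p : intersect(p, new_idom);
      }
      if (new_idom != -1 && idom[b] != new_idom) {
        idom[b] = new_idom;
        changed = true;
      }
    }
  }
  return idom;
}
```

For a diamond `0 -> {1, 2} -> 3`, every block ends up with idom 0; for a loop `0 -> 1 -> 2 -> {1, 3}`, the chain is 0, 1, 2. Predecessors that are unreachable from the entry never receive an idom and are simply skipped.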


@@ -1,4 +1,128 @@
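// CFG simplification: #include "ir/PassManager.h"
// - Remove unreachable blocks, merge empty blocks, simplify branches, etc. #include <algorithm>
// - Improve the IR structure to ease later optimization and backend code generation #include <iostream>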
#include <queue>
#include <unordered_set>
#include <vector>
namespace ir {
// Forward declaration of the CFG rebuild helper
void RebuildCFG(Function* func);
bool RunCFGSimplify(Function* func) {
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
RebuildCFG(func);
// 1. Remove unreachable basic blocks
BasicBlock* entry = func->GetEntry();
std::unordered_set<BasicBlock*> reachable;
std::queue<BasicBlock*> worklist;
reachable.insert(entry);
worklist.push(entry);
while (!worklist.empty()) {
auto* curr = worklist.front();
worklist.pop();
for (auto* succ : curr->GetSuccessors()) {
if (reachable.find(succ) == reachable.end()) {
reachable.insert(succ);
worklist.push(succ);
}
}
}
std::vector<BasicBlock*> unreachable_blocks;
for (const auto& bbPtr : func->GetBlocks()) {
if (reachable.find(bbPtr.get()) == reachable.end()) {
unreachable_blocks.push_back(bbPtr.get());
}
}
if (!unreachable_blocks.empty()) {
changed = true;
local_changed = true;
// First detach every unreachable block from the phi nodes of its successors.
// This happens before any block is destroyed, so a successor that is itself
// unreachable is never accessed after it has been freed.
for (auto* bb : unreachable_blocks) {
for (auto* succ : bb->GetSuccessors()) {
for (const auto& instPtr : succ->GetInstructions()) {
if (instPtr->GetOpcode() == Opcode::Phi) {
auto* phi = static_cast<PhiInst*>(instPtr.get());
phi->RemoveIncomingBlock(bb);
}
}
}
}
// Then remove all unreachable blocks from func's block list in one pass
auto& blocks = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<BasicBlock>& b) {
return reachable.find(b.get()) == reachable.end();
}),
blocks.end());
continue; // Restart simplification loop safely
}
// 2. Merge basic block B with successor S if S has only one predecessor B
for (const auto& bbPtr : func->GetBlocks()) {
auto* b = bbPtr.get();
if (b->GetSuccessors().size() == 1) {
auto* s = b->GetSuccessors().front();
if (s != entry && s->GetPredecessors().size() == 1) {
changed = true;
local_changed = true;
// Replace all uses of block S as label with block B
s->ReplaceAllUsesWith(b);
// Erase B's terminator (the BranchInst to S)
auto* b_term = b->GetInstructions().back().get();
b->EraseInstruction(b_term);
// For any PhiInst in S: it has exactly 1 incoming value from B.
// Replace all uses of the PhiInst with its single incoming value.
std::vector<Instruction*> phi_to_remove;
for (const auto& instPtr : s->GetInstructions()) {
if (instPtr->GetOpcode() == Opcode::Phi) {
auto* phi = static_cast<PhiInst*>(instPtr.get());
if (phi->GetNumIncoming() > 0) {
phi->ReplaceAllUsesWith(phi->GetIncomingValue(0));
}
phi_to_remove.push_back(phi);
}
}
// Move instructions from S to B
auto& s_insts = const_cast<std::vector<std::unique_ptr<Instruction>>&>(s->GetInstructions());
for (auto& instPtr : s_insts) {
if (std::find(phi_to_remove.begin(), phi_to_remove.end(), instPtr.get()) == phi_to_remove.end()) {
instPtr->SetParent(b);
const_cast<std::vector<std::unique_ptr<Instruction>>&>(b->GetInstructions()).push_back(std::move(instPtr));
}
}
// Clear S's instructions to prevent any dangling or double frees
s_insts.clear();
// Erase S from func's blocks list
auto& blocks = const_cast<std::vector<std::unique_ptr<BasicBlock>>&>(func->GetBlocks());
blocks.erase(std::remove_if(blocks.begin(), blocks.end(),
[&](const std::unique_ptr<BasicBlock>& b) {
return b.get() == s;
}),
blocks.end());
break; // Break to restart loop safely
}
}
}
}
return changed;
}
} // namespace ir
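
The first step of `RunCFGSimplify` above is a breadth-first walk from the entry block; anything it never reaches is unreachable. A minimal sketch of that walk on integer block ids, assuming block 0 is the entry and a hypothetical `ReachableFromEntry` function:

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical sketch: succs[b] lists the successors of block b.
// Returns one flag per block telling whether the entry (block 0) reaches it.
std::vector<bool> ReachableFromEntry(
    const std::vector<std::vector<std::size_t>>& succs) {
  assert(!succs.empty());
  std::vector<bool> reachable(succs.size(), false);
  std::queue<std::size_t> worklist;
  reachable[0] = true;
  worklist.push(0);
  while (!worklist.empty()) {
    const std::size_t curr = worklist.front();
    worklist.pop();
    for (std::size_t succ : succs[curr]) {
      if (!reachable[succ]) {
        reachable[succ] = true;  // mark before pushing to avoid duplicates
        worklist.push(succ);
      }
    }
  }
  return reachable;
}
```

Note that unreachable blocks can still branch to each other (blocks 2 and 3 in a graph like `0 -> 1`, `2 -> 3`, `3 -> 2`), so a pass that deletes them has to be careful about the order in which it inspects and frees them.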


@@ -1,4 +1,88 @@
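// Common subexpression elimination (CSE): #include "ir/PassManager.h"
// - Identify and reuse equivalent expressions that are computed repeatedly #include <iostream>
// - Typically placed after ConstFold and before DCE #include <vector>
// - Currently a framework placeholder for Lab4; the concrete algorithm is implemented in the lab #include <tuple>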
namespace ir {
static bool IsEquivalent(Instruction* a, Instruction* b) {
if (a->GetOpcode() != b->GetOpcode()) return false;
if (a->GetNumOperands() != b->GetNumOperands()) return false;
// Skip load, store, alloca, call, phi, branch, ret (since they have side-effects or special states)
switch (a->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmpEQ:
case Opcode::ICmpNE:
case Opcode::ICmpLT:
case Opcode::ICmpGT:
case Opcode::ICmpLE:
case Opcode::ICmpGE:
case Opcode::FCmpEQ:
case Opcode::FCmpNE:
case Opcode::FCmpLT:
case Opcode::FCmpGT:
case Opcode::FCmpLE:
case Opcode::FCmpGE:
case Opcode::GEP:
case Opcode::ZExt:
case Opcode::SIToFP:
case Opcode::FPToSI:
break;
default:
return false; // Skip all other opcodes
}
// Compare all operands
for (size_t i = 0; i < a->GetNumOperands(); ++i) {
if (a->GetOperand(i) != b->GetOperand(i)) {
return false;
}
}
return true;
}
bool RunCSE(Function* func) {
bool changed = false;
for (const auto& bbPtr : func->GetBlocks()) {
std::vector<Instruction*> seen_instructions;
std::vector<Instruction*> to_erase;
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* inst = instPtr.get();
Instruction* match = nullptr;
for (auto* seen : seen_instructions) {
if (IsEquivalent(inst, seen)) {
match = seen;
break;
}
}
if (match) {
inst->ReplaceAllUsesWith(match);
to_erase.push_back(inst);
changed = true;
} else {
seen_instructions.push_back(inst);
}
}
for (auto* inst : to_erase) {
bbPtr->EraseInstruction(inst);
}
}
return changed;
}
} // namespace ir
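
The block-local matching above compares operands pointer by pointer and in order, so `a + b` and `b + a` count as different expressions. As a rough illustration only, the sketch below applies the same idea to a hypothetical toy instruction type and also sorts the operands of commutative opcodes before the lookup. `ToyInst`, `LocalCse`, and the integer value ids are assumptions made for this example and are not part of the repository.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Hypothetical toy instruction: value `result` = op(lhs, rhs), all named by integer ids.
struct ToyInst {
  std::string op;
  int lhs;
  int rhs;
  int result;
};

// For each instruction, returns the id of the value that can stand in for its
// result: its own id for a first occurrence, or the id of an earlier
// equivalent instruction in the same block.
std::vector<int> LocalCse(const std::vector<ToyInst>& block) {
  std::map<std::tuple<std::string, int, int>, int> seen;  // expression -> first result id
  std::map<int, int> replaced;                            // duplicate id -> representative id
  std::vector<int> representative;
  for (const ToyInst& inst : block) {
    // Look through operands that were themselves replaced earlier.
    auto resolve = [&replaced](int id) -> int {
      auto it = replaced.find(id);
      return it == replaced.end() ? id : it->second;
    };
    int a = resolve(inst.lhs);
    int b = resolve(inst.rhs);
    // Canonical operand order for commutative opcodes (assumed here: add, mul).
    if ((inst.op == "add" || inst.op == "mul") && a > b) std::swap(a, b);
    auto res = seen.emplace(std::make_tuple(inst.op, a, b), inst.result);
    if (!res.second) replaced[inst.result] = res.first->second;
    representative.push_back(res.first->second);
  }
  return representative;
}
```

Calling `LocalCse` on a block returns one representative id per instruction; an instruction whose representative differs from its own result id is a duplicate that a real pass would replace and erase.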


@@ -1,4 +1,105 @@
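// IR constant folding:
// - Folds constant expressions whose value can be determined.
// - Simplifies control-flow branches on constants (trimmed to the implemented scope).
#include "ir/PassManager.h"
#include <iostream>
#include <cmath>
#include <vector>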
namespace ir {
ConstantValue* FoldInstruction(Instruction* inst, Context& ctx) {
if (inst->GetOpcode() == Opcode::ZExt) {
auto* cast = static_cast<CastInst*>(inst);
if (auto* ci = dynamic_cast<ConstantInt*>(cast->GetValue())) {
return ctx.GetConstInt(ci->GetValue()); // ZExt is trivial on constant int
}
}
if (inst->GetOpcode() == Opcode::SIToFP) {
auto* cast = static_cast<CastInst*>(inst);
if (auto* ci = dynamic_cast<ConstantInt*>(cast->GetValue())) {
return ctx.GetConstFloat(static_cast<float>(ci->GetValue()));
}
}
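if (inst->GetOpcode() == Opcode::FPToSI) {
auto* cast = static_cast<CastInst*>(inst);
if (auto* cf = dynamic_cast<ConstantFloat*>(cast->GetValue())) {
const float v = cf->GetValue();
// Fold only when the value fits in a 32-bit int; converting an out-of-range
// or NaN float to int is undefined behavior in C++.
if (v >= -2147483648.0f && v < 2147483648.0f) {
return ctx.GetConstInt(static_cast<int>(v));
}
}
}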
// Binary operations
if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
auto* lhs = bin->GetLhs();
auto* rhs = bin->GetRhs();
auto* lhs_i = dynamic_cast<ConstantInt*>(lhs);
auto* rhs_i = dynamic_cast<ConstantInt*>(rhs);
auto* lhs_f = dynamic_cast<ConstantFloat*>(lhs);
auto* rhs_f = dynamic_cast<ConstantFloat*>(rhs);
if (lhs_i && rhs_i) {
int l = lhs_i->GetValue();
int r = rhs_i->GetValue();
switch (bin->GetOpcode()) {
case Opcode::Add: return ctx.GetConstInt(l + r);
case Opcode::Sub: return ctx.GetConstInt(l - r);
case Opcode::Mul: return ctx.GetConstInt(l * r);
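// INT_MIN / -1 and INT_MIN % -1 overflow (undefined behavior in C++), so that
// case is left unfolded, like division by zero (assumes 32-bit int).
case Opcode::Div: return (r != 0 && !(r == -1 && l < -2147483647)) ? ctx.GetConstInt(l / r) : nullptr;
case Opcode::Mod: return (r != 0 && !(r == -1 && l < -2147483647)) ? ctx.GetConstInt(l % r) : nullptr;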
case Opcode::ICmpEQ: return ctx.GetConstInt(l == r ? 1 : 0);
case Opcode::ICmpNE: return ctx.GetConstInt(l != r ? 1 : 0);
case Opcode::ICmpLT: return ctx.GetConstInt(l < r ? 1 : 0);
case Opcode::ICmpGT: return ctx.GetConstInt(l > r ? 1 : 0);
case Opcode::ICmpLE: return ctx.GetConstInt(l <= r ? 1 : 0);
case Opcode::ICmpGE: return ctx.GetConstInt(l >= r ? 1 : 0);
default: break;
}
}
if (lhs_f && rhs_f) {
float l = lhs_f->GetValue();
float r = rhs_f->GetValue();
switch (bin->GetOpcode()) {
case Opcode::FAdd: return ctx.GetConstFloat(l + r);
case Opcode::FSub: return ctx.GetConstFloat(l - r);
case Opcode::FMul: return ctx.GetConstFloat(l * r);
case Opcode::FDiv: return (r != 0.0f) ? ctx.GetConstFloat(l / r) : nullptr;
case Opcode::FCmpEQ: return ctx.GetConstInt(l == r ? 1 : 0);
case Opcode::FCmpNE: return ctx.GetConstInt(l != r ? 1 : 0);
case Opcode::FCmpLT: return ctx.GetConstInt(l < r ? 1 : 0);
case Opcode::FCmpGT: return ctx.GetConstInt(l > r ? 1 : 0);
case Opcode::FCmpLE: return ctx.GetConstInt(l <= r ? 1 : 0);
case Opcode::FCmpGE: return ctx.GetConstInt(l >= r ? 1 : 0);
default: break;
}
}
}
return nullptr;
}
bool RunConstFold(Function* func, Context& ctx) {
bool changed = false;
std::vector<Instruction*> to_erase;
for (const auto& bbPtr : func->GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* inst = instPtr.get();
if (inst->GetOpcode() == Opcode::Br || inst->GetOpcode() == Opcode::Ret || inst->GetOpcode() == Opcode::Phi) {
continue;
}
if (auto* folded = FoldInstruction(inst, ctx)) {
inst->ReplaceAllUsesWith(folded);
to_erase.push_back(inst);
changed = true;
}
}
}
for (auto* inst : to_erase) {
inst->GetParent()->EraseInstruction(inst);
}
return changed;
}
} // namespace ir
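
Integer folding has a few corner cases that are easy to miss: division or remainder by zero, `INT_MIN / -1`, and signed overflow in `+`, `-`, `*`. The sketch below is one possible way to handle them on plain 32-bit integers, under the assumption that the IR's integer arithmetic wraps around. `ToyOp` and `FoldInt32` are hypothetical names used only for this example.

```cpp
#include <cassert>
#include <cstdint>

enum class ToyOp { Add, Sub, Mul, Div, Mod };

// Folds `l op r` into *out and returns true, or returns false when the
// operation should be left in the IR (division by zero, INT32_MIN / -1).
bool FoldInt32(ToyOp op, int32_t l, int32_t r, int32_t* out) {
  // Unsigned arithmetic wraps by definition, which avoids signed-overflow UB.
  // Converting the result back to int32_t is modular on mainstream compilers
  // (and guaranteed to be since C++20).
  const uint32_t ul = static_cast<uint32_t>(l);
  const uint32_t ur = static_cast<uint32_t>(r);
  switch (op) {
    case ToyOp::Add: *out = static_cast<int32_t>(ul + ur); return true;
    case ToyOp::Sub: *out = static_cast<int32_t>(ul - ur); return true;
    case ToyOp::Mul: *out = static_cast<int32_t>(ul * ur); return true;
    case ToyOp::Div:
    case ToyOp::Mod:
      if (r == 0) return false;                     // runtime trap, do not fold
      if (l == INT32_MIN && r == -1) return false;  // quotient does not fit
      *out = (op == ToyOp::Div) ? l / r : l % r;    // truncates toward zero
      return true;
  }
  return false;
}
```

Returning `false` mirrors the `nullptr` result of `FoldInstruction`: the caller simply keeps the original instruction.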


@@ -1,5 +1,75 @@
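// Constant propagation:
// - Propagates known constants along use-def relations.
// - Rewrites replaceable SSA values as constants, exposing more folding opportunities.
// - Usually iterated together with ConstFold, DCE, and CFGSimplify.
#include "ir/PassManager.h"
#include <iostream>
#include <vector>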
namespace ir {
// Declare FoldInstruction from ConstFold.cpp
ConstantValue* FoldInstruction(Instruction* inst, Context& ctx);
bool RunConstProp(Function* func, Context& ctx) {
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
std::vector<Instruction*> to_erase;
// 1. Fold instructions
for (const auto& bbPtr : func->GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* inst = instPtr.get();
if (inst->GetOpcode() == Opcode::Br || inst->GetOpcode() == Opcode::Ret || inst->GetOpcode() == Opcode::Phi) {
continue;
}
if (auto* folded = FoldInstruction(inst, ctx)) {
inst->ReplaceAllUsesWith(folded);
to_erase.push_back(inst);
local_changed = true;
changed = true;
}
}
}
// Erase the folded instructions
for (auto* inst : to_erase) {
inst->GetParent()->EraseInstruction(inst);
}
// 2. Simplify conditional branches
for (const auto& bbPtr : func->GetBlocks()) {
auto* bb = bbPtr.get();
const auto& insts = bb->GetInstructions();
if (insts.empty()) continue;
auto* term = insts.back().get();
if (term->GetOpcode() == Opcode::Br) {
auto* br = static_cast<BranchInst*>(term);
if (br->IsConditional()) {
if (auto* cond_const = dynamic_cast<ConstantInt*>(br->GetCondition())) {
BasicBlock* target = (cond_const->GetValue() != 0) ? br->GetIfTrue() : br->GetIfFalse();
BasicBlock* dead_target = (cond_const->GetValue() != 0) ? br->GetIfFalse() : br->GetIfTrue();
if (dead_target != target) {
for (const auto& instPtr : dead_target->GetInstructions()) {
if (instPtr->GetOpcode() == Opcode::Phi) {
auto* phi = static_cast<PhiInst*>(instPtr.get());
phi->RemoveIncomingBlock(bb);
}
}
}
bb->EraseInstruction(br);
bb->Append<BranchInst>(target);
local_changed = true;
changed = true;
break; // Restart loop to handle CFG shifts safely
}
}
}
}
}
return changed;
}
} // namespace ir
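
The second step of the pass, turning a branch on a constant into an unconditional branch, also has to drop the vanished edge from the Phi nodes of the target that is no longer reached. A minimal sketch of that bookkeeping on a toy CFG follows. `ToyBlock`, `ToyPhi`, and `SimplifyConstBranch` are hypothetical, and blocks are assumed to be identified by their index in the vector.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

struct ToyPhi {
  std::vector<int> incoming_blocks;  // predecessor block indices
};

struct ToyBlock {
  bool const_cond;  // terminator is a conditional branch on a known constant
  int cond_value;
  int if_true;
  int if_false;
  std::vector<ToyPhi> phis;
};

// Rewrites the constant conditional branch of blocks[from] into an
// unconditional one. Returns true if the block was changed.
bool SimplifyConstBranch(std::vector<ToyBlock>& blocks, int from) {
  ToyBlock& bb = blocks[from];
  if (!bb.const_cond) return false;
  const int target = (bb.cond_value != 0) ? bb.if_true : bb.if_false;
  const int dead = (bb.cond_value != 0) ? bb.if_false : bb.if_true;
  if (dead != target) {
    // The edge from -> dead disappears, so Phis in `dead` lose that entry.
    for (ToyPhi& phi : blocks[dead].phis) {
      std::vector<int>& in = phi.incoming_blocks;
      in.erase(std::remove(in.begin(), in.end(), from), in.end());
    }
  }
  bb.const_cond = false;
  bb.if_true = target;
  bb.if_false = target;
  return true;
}
```

If the dead target ends up with no predecessors at all, removing the block itself is left to a later cleanup pass, which matches the division of work between the passes in this change.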


@@ -1,4 +1,75 @@
// Dead code elimination (DCE):
// - Removes unused instructions and unused basic blocks.
// - Usually combined with CFG simplification.
#include "ir/PassManager.h"
#include <iostream>
#include <unordered_set>
#include <queue>
#include <vector>
namespace ir {
bool RunDCE(Function* func) {
std::unordered_set<Instruction*> live_instructions;
std::queue<Instruction*> worklist;
// 1. Mark inherently live instructions
for (const auto& bbPtr : func->GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* inst = instPtr.get();
bool inherently_live = false;
switch (inst->GetOpcode()) {
case Opcode::Ret:
case Opcode::Br:
case Opcode::Store:
case Opcode::Call:
inherently_live = true;
break;
default:
break;
}
if (inherently_live) {
live_instructions.insert(inst);
worklist.push(inst);
}
}
}
// 2. Propagate liveness backwards along use-def chains (from each live instruction to its operands)
while (!worklist.empty()) {
auto* inst = worklist.front();
worklist.pop();
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
auto* operand = inst->GetOperand(i);
if (auto* op_inst = dynamic_cast<Instruction*>(operand)) {
if (live_instructions.find(op_inst) == live_instructions.end()) {
live_instructions.insert(op_inst);
worklist.push(op_inst);
}
}
}
}
// 3. Sweep dead instructions
bool changed = false;
for (const auto& bbPtr : func->GetBlocks()) {
std::vector<Instruction*> dead_instructions;
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* inst = instPtr.get();
if (live_instructions.find(inst) == live_instructions.end()) {
dead_instructions.push_back(inst);
}
}
if (!dead_instructions.empty()) {
changed = true;
for (auto* inst : dead_instructions) {
bbPtr->EraseInstruction(inst);
}
}
}
return changed;
}
} // namespace ir
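
The mark-and-sweep structure of this pass can be tried out in isolation. The sketch below works on a hypothetical flat instruction list in which operands are indices of other instructions; `ToyInst` and `ComputeLive` are assumptions for the example, and "side effect" stands in for the `Ret`/`Br`/`Store`/`Call` roots used above.

```cpp
#include <cassert>
#include <cstddef>
#include <queue>
#include <vector>

struct ToyInst {
  bool has_side_effect;               // root of liveness (store, call, terminator, ...)
  std::vector<std::size_t> operands;  // indices of the instructions this one uses
};

// Returns live[i] == true for every instruction that is a root or is
// reachable from a root through operand (use-def) edges.
std::vector<bool> ComputeLive(const std::vector<ToyInst>& insts) {
  std::vector<bool> live(insts.size(), false);
  std::queue<std::size_t> worklist;
  for (std::size_t i = 0; i < insts.size(); ++i) {
    if (insts[i].has_side_effect) {
      live[i] = true;
      worklist.push(i);
    }
  }
  while (!worklist.empty()) {
    const std::size_t cur = worklist.front();
    worklist.pop();
    for (std::size_t op : insts[cur].operands) {
      if (!live[op]) {
        live[op] = true;
        worklist.push(op);
      }
    }
  }
  return live;
}
```

Everything left unmarked can be erased in a single sweep. Note that a chain of instructions that only feed each other is removed as a whole, which a simple "no uses" check would need several rounds to achieve.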


@@ -1,4 +1,228 @@
// Mem2Reg (SSA construction):
// - Promotes alloca/load/store of local variables to SSA form.
// - Inserts PHI nodes and rewrites uses; relies on analyses such as the dominator tree.
#include "ir/PassManager.h"
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <stack>
#include <algorithm>
#include <queue>
#include <functional>
namespace ir {
// Predeclaration of rebuild CFG helper
void RebuildCFG(Function* func);
bool RunMem2Reg(Function* func, Context& ctx) {
// 1. Build dominator tree
DominatorTree dom_tree(func);
dom_tree.Run();
// 2. Identify promotable allocas
std::vector<AllocaInst*> promotable_allocas;
for (const auto& bbPtr : func->GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
if (instPtr->GetOpcode() == Opcode::Alloca) {
auto* alloca = static_cast<AllocaInst*>(instPtr.get());
// Alloca of scalar type: i32 or float (pointers to i32/float in minimum IR)
if (alloca->GetType()->IsPtrInt32() || alloca->GetType()->IsPtrFloat()) {
// Verify all uses are load/store
bool promotable = true;
for (const auto& use : alloca->GetUses()) {
auto* user = use.GetUser();
auto* inst_user = dynamic_cast<Instruction*>(user);
if (!inst_user) {
promotable = false;
break;
}
if (inst_user->GetOpcode() != Opcode::Load && inst_user->GetOpcode() != Opcode::Store) {
promotable = false;
break;
}
// For Store, alloca must be the pointer operand (operand index 1), not the value operand
if (inst_user->GetOpcode() == Opcode::Store) {
auto* store = static_cast<StoreInst*>(inst_user);
if (store->GetPtr() != alloca) {
promotable = false;
break;
}
}
}
if (promotable) {
promotable_allocas.push_back(alloca);
}
}
}
}
}
if (promotable_allocas.empty()) {
return false;
}
// 3. For each alloca, find definition blocks and place Phi nodes
// Maps each basic block and alloca to the inserted Phi instruction
std::unordered_map<BasicBlock*, std::unordered_map<AllocaInst*, PhiInst*>> phi_nodes;
std::unordered_set<Instruction*> instructions_to_erase;
for (auto* alloca : promotable_allocas) {
std::vector<BasicBlock*> def_blocks;
for (const auto& use : alloca->GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
if (inst && inst->GetOpcode() == Opcode::Store) {
def_blocks.push_back(inst->GetParent());
}
}
// DF-based Phi placement. "Already has a Phi" and "already enqueued" are tracked
// separately: a block that stores to the variable can itself lie in a dominance
// frontier (a join block or a loop header) and then still needs a Phi.
std::queue<BasicBlock*> worklist;
std::unordered_set<BasicBlock*> has_phi;
std::unordered_set<BasicBlock*> enqueued;
for (auto* bb : def_blocks) {
if (enqueued.insert(bb).second) {
worklist.push(bb);
}
}
while (!worklist.empty()) {
auto* x = worklist.front();
worklist.pop();
for (auto* y : dom_tree.GetDominanceFrontier(x)) {
if (has_phi.insert(y).second) {
// Place Phi node in Y
std::shared_ptr<Type> ty = alloca->GetType()->IsPtrFloat() ? Type::GetFloatType() : Type::GetInt32Type();
auto phi = std::make_unique<PhiInst>(ty, ctx.NextTemp());
auto* phi_ptr = phi.get();
// Insert Phi at the start of block Y
y->InsertInstructionAtBegin(std::move(phi));
phi_nodes[y][alloca] = phi_ptr;
// The Phi is itself a new definition, so Y's frontier must be processed too.
if (enqueued.insert(y).second) {
worklist.push(y);
}
}
}
}
}
// 4. Rename variables using DFS traversal of dominator tree
std::unordered_map<AllocaInst*, std::vector<Value*>> current_def;
// Helper for generating default value
auto get_default_value = [&](AllocaInst* alloca) -> Value* {
if (alloca->GetType()->IsPtrFloat()) {
return ctx.GetConstFloat(0.0f);
} else {
return ctx.GetConstInt(0);
}
};
// For each block, records which allocas had definitions pushed while visiting it,
// so they can be popped once the block's dominator subtree has been processed.
std::unordered_map<BasicBlock*, std::vector<std::pair<AllocaInst*, size_t>>> pushed_defs;
// DFS function
std::function<void(BasicBlock*)> rename_dfs = [&](BasicBlock* bb) {
auto& pushes = pushed_defs[bb];
// Push Phis in this block to current_def
auto phi_it = phi_nodes.find(bb);
if (phi_it != phi_nodes.end()) {
for (const auto& pair : phi_it->second) {
auto* alloca = pair.first;
auto* phi = pair.second;
current_def[alloca].push_back(phi);
pushes.push_back({alloca, 1});
}
}
// Process loads and stores
for (const auto& instPtr : bb->GetInstructions()) {
auto* inst = instPtr.get();
if (inst->GetOpcode() == Opcode::Load) {
auto* load = static_cast<LoadInst*>(inst);
auto* ptr = load->GetPtr();
if (auto* alloca = dynamic_cast<AllocaInst*>(ptr)) {
if (std::find(promotable_allocas.begin(), promotable_allocas.end(), alloca) != promotable_allocas.end()) {
auto& defs = current_def[alloca];
Value* val = defs.empty() ? get_default_value(alloca) : defs.back();
load->ReplaceAllUsesWith(val);
instructions_to_erase.insert(load);
}
}
} else if (inst->GetOpcode() == Opcode::Store) {
auto* store = static_cast<StoreInst*>(inst);
auto* ptr = store->GetPtr();
if (auto* alloca = dynamic_cast<AllocaInst*>(ptr)) {
if (std::find(promotable_allocas.begin(), promotable_allocas.end(), alloca) != promotable_allocas.end()) {
current_def[alloca].push_back(store->GetValue());
pushes.push_back({alloca, 1});
instructions_to_erase.insert(store);
}
}
}
}
// Fill Phi incoming values for CFG successors
for (auto* succ : bb->GetSuccessors()) {
auto succ_phi_it = phi_nodes.find(succ);
if (succ_phi_it != phi_nodes.end()) {
for (const auto& pair : succ_phi_it->second) {
auto* alloca = pair.first;
auto* phi = pair.second;
auto& defs = current_def[alloca];
Value* val = defs.empty() ? get_default_value(alloca) : defs.back();
phi->AddIncoming(val, bb);
}
}
}
// Recurse to dominator tree children
for (auto* child : dom_tree.GetDominatedBlocks(bb)) {
rename_dfs(child);
}
// Pop definitions pushed in this block
for (const auto& push : pushes) {
auto* alloca = push.first;
for (size_t k = 0; k < push.second; ++k) {
if (!current_def[alloca].empty()) {
current_def[alloca].pop_back();
}
}
}
};
if (!func->GetBlocks().empty()) {
rename_dfs(func->GetEntry());
}
// 5. Clean up loads, stores and allocas
for (auto* alloca : promotable_allocas) {
instructions_to_erase.insert(alloca);
}
for (const auto& bbPtr : func->GetBlocks()) {
std::vector<Instruction*> to_remove;
for (const auto& instPtr : bbPtr->GetInstructions()) {
if (instructions_to_erase.find(instPtr.get()) != instructions_to_erase.end()) {
to_remove.push_back(instPtr.get());
}
}
for (auto* inst : to_remove) {
bbPtr->EraseInstruction(inst);
}
}
return true;
}
} // namespace ir
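
Phi placement is the step of this pass that is easiest to get subtly wrong, so it can help to look at it in isolation. The sketch below computes the set of blocks that need a Phi for one variable from a precomputed dominance-frontier map. It is an illustration under assumptions: `PlacePhis` is a hypothetical helper, blocks are plain integers, and the frontier map is supplied by the caller rather than derived from a dominator tree. One detail it makes explicit is that a block containing a store can itself sit in a dominance frontier, so "already enqueued" and "already has a Phi" are kept in separate sets.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <vector>

// df: block id -> its dominance frontier. def_blocks: blocks that store to
// the variable. Returns the blocks that need a Phi for that variable.
std::set<int> PlacePhis(const std::map<int, std::vector<int>>& df,
                        const std::vector<int>& def_blocks) {
  std::set<int> has_phi;
  std::set<int> enqueued;
  std::queue<int> worklist;
  for (int b : def_blocks) {
    if (enqueued.insert(b).second) worklist.push(b);
  }
  while (!worklist.empty()) {
    const int x = worklist.front();
    worklist.pop();
    auto it = df.find(x);
    if (it == df.end()) continue;
    for (int y : it->second) {
      if (!has_phi.insert(y).second) continue;  // Phi already placed in y
      // A Phi is a new definition, so y's own frontier must be visited too.
      if (enqueued.insert(y).second) worklist.push(y);
    }
  }
  return has_phi;
}
```

For a diamond `0 -> {1, 2} -> 3` with stores in blocks 1, 2, and 3, the result is `{3}`: the join block needs a Phi even though it also defines the variable. A single-block loop whose body stores to the variable likewise gets a Phi in that block.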


@@ -1 +1,35 @@
// IR pass manager skeleton.
#include "ir/PassManager.h"
#include <iostream>
namespace ir {
void RunFunctionOptimizationPasses(Function* func, Context& ctx) {
// 1. Promote memory-based local variables to SSA form using Mem2Reg
RunMem2Reg(func, ctx);
// 2. Run scalar optimizations iteratively until convergence (no changes observed)
bool changed = true;
int iterations = 0;
const int max_iterations = 16; // Safe limit to prevent compile-time infinite loops
while (changed && iterations < max_iterations) {
changed = false;
iterations++;
changed |= RunConstProp(func, ctx);
changed |= RunConstFold(func, ctx);
changed |= RunCSE(func);
changed |= RunDCE(func);
changed |= RunCFGSimplify(func);
}
}
void RunOptimizationPasses(Module& module) {
for (const auto& funcPtr : module.GetFunctions()) {
if (!funcPtr->GetBlocks().empty()) {
RunFunctionOptimizationPasses(funcPtr.get(), module.GetContext());
}
}
}
} // namespace ir
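
The driver loop is a bounded fixed-point iteration: every pass runs in each round, and the loop stops once a full round reports no change or the round limit is reached. A stripped-down sketch of that control flow, with passes modelled as hypothetical `std::function<bool()>` callbacks (`RunToFixedPoint` is not part of the repository):

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Runs all passes round by round until a round changes nothing or
// max_rounds is reached. Returns the number of rounds executed.
int RunToFixedPoint(const std::vector<std::function<bool()>>& passes, int max_rounds) {
  int rounds = 0;
  bool changed = true;
  while (changed && rounds < max_rounds) {
    changed = false;
    ++rounds;
    for (const auto& pass : passes) {
      // Every pass runs in every round, even if an earlier one already changed something.
      if (pass()) changed = true;
    }
  }
  return rounds;
}
```

The cap matters for passes that could keep undoing each other's work; without it a disagreement between two passes would turn into a compile-time hang.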


@@ -1,6 +1,7 @@
 #include "irgen/IRGen.h"
 #include <stdexcept>
+#include <vector>
 #include "SysYParser.h"
 #include "ir/IR.h"
@@ -8,100 +9,209 @@
 namespace {
-std::string GetLValueName(SysYParser::LValueContext& lvalue) {
-  if (!lvalue.ID()) {
-    throw std::runtime_error(FormatError("irgen", "invalid lvalue"));
-  }
-  return lvalue.ID()->getText();
+std::shared_ptr<ir::Type> BaseTypeFromDecl(SysYParser::BtypeContext* btype) {
+  return (btype && btype->FLOAT()) ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
+}
+std::shared_ptr<ir::Type> StorageType(std::shared_ptr<ir::Type> ty) {
+  if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
+  if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
+  return ty;
+}
+size_t CountScalars(const std::shared_ptr<ir::Type>& ty) {
+  if (!ty->IsArray()) return 1;
+  auto arr_ty = ty->GetAsArrayType();
+  return arr_ty->GetNumElements() * CountScalars(arr_ty->GetElementType());
 }
 } // namespace
-std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
-  if (!ctx) {
-    throw std::runtime_error(FormatError("irgen", "missing block statement"));
+void IRGenImpl::ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr<ir::Type> ty) {
+  if (ty->IsArray()) {
+    auto arr_ty = ty->GetAsArrayType();
+    for (uint32_t i = 0; i < arr_ty->GetNumElements(); ++i) {
+      auto* elem_ptr = builder_.CreateGEP(StorageType(arr_ty->GetElementType()), ptr,
+                                          {builder_.CreateConstInt(0),
+                                           builder_.CreateConstInt(static_cast<int>(i))},
+                                          module_.GetContext().NextTemp());
+      ZeroInitializeLocal(elem_ptr, arr_ty->GetElementType());
+    }
+    return;
   }
-  for (auto* item : ctx->blockItem()) {
-    if (item) {
-      if (VisitBlockItemResult(*item) == BlockFlow::Terminated) {
-        // The current grammar requires return to be the last statement in a block, so generation can stop once it is hit.
-        break;
+  ir::Value* zero = ty->IsFloat() ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
+                                  : static_cast<ir::Value*>(builder_.CreateConstInt(0));
+  builder_.CreateStore(zero, ptr);
+}
+void IRGenImpl::EmitLocalInitValue(ir::Value* ptr, std::shared_ptr<ir::Type> ty,
+                                   SysYParser::InitValueContext* init) {
+  if (!init) return;
+  auto build_flat_scalar_ptr = [&](ir::Value* base_ptr,
+                                   const std::shared_ptr<ir::Type>& base_ty,
+                                   size_t flat_index) -> ir::Value* {
+    if (!base_ty->IsArray()) return base_ptr;
+    std::vector<ir::Value*> indices;
+    indices.push_back(builder_.CreateConstInt(0));
+    auto cur_ty = base_ty;
+    size_t offset = flat_index;
+    while (cur_ty->IsArray()) {
+      auto arr_ty = cur_ty->GetAsArrayType();
+      const size_t step = CountScalars(arr_ty->GetElementType());
+      const size_t idx = step == 0 ? 0 : offset / step;
+      offset = step == 0 ? 0 : offset % step;
+      indices.push_back(builder_.CreateConstInt(static_cast<int>(idx)));
+      cur_ty = arr_ty->GetElementType();
+    }
+    return builder_.CreateGEP(StorageType(cur_ty), base_ptr, indices,
+                              module_.GetContext().NextTemp());
+  };
+  if (ty->IsArray()) {
+    auto arr_ty = ty->GetAsArrayType();
+    const auto elem_ty = arr_ty->GetElementType();
+    const size_t elem_step = CountScalars(elem_ty);
+    if (init->exp()) {
+      auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, 0);
+      EmitLocalInitValue(elem_ptr, elem_ty, init);
+      return;
+    }
+    const auto& children = init->initValue();
+    size_t scalar_cursor = 0;
+    for (auto* child : children) {
+      if (scalar_cursor >= CountScalars(ty)) break;
+      if (child->exp()) {
+        auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, scalar_cursor);
+        auto scalar_ty = elem_ty;
+        while (scalar_ty->IsArray()) {
+          scalar_ty = scalar_ty->GetAsArrayType()->GetElementType();
+        }
+        EmitLocalInitValue(elem_ptr, scalar_ty, child);
+        ++scalar_cursor;
+        continue;
+      }
+      const size_t elem_index = elem_step == 0 ? 0 : scalar_cursor / elem_step;
+      if (elem_index >= arr_ty->GetNumElements()) break;
+      auto* elem_ptr = builder_.CreateGEP(StorageType(elem_ty), ptr,
+                                          {builder_.CreateConstInt(0),
+                                           builder_.CreateConstInt(static_cast<int>(elem_index))},
+                                          module_.GetContext().NextTemp());
+      EmitLocalInitValue(elem_ptr, elem_ty, child);
+      scalar_cursor = (elem_index + 1) * elem_step;
+    }
+    return;
+  }
+  if (!init->exp()) {
+    return;
+  }
+  ir::Value* value = EvalExpr(*init->exp());
+  if (ty->IsFloat() && value->GetType()->IsInt32()) {
+    value = builder_.CreateSIToFP(value, ty, module_.GetContext().NextTemp());
+  } else if (ty->IsInt32() && value->GetType()->IsFloat()) {
+    value = builder_.CreateFPToSI(value, ty, module_.GetContext().NextTemp());
+  }
+  builder_.CreateStore(value, ptr);
+}
+std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
+  if (ctx->constDecl()) return ctx->constDecl()->accept(this);
+  if (ctx->varDecl()) return ctx->varDecl()->accept(this);
+  return {};
+}
+std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
+  for (auto* def : ctx->constDef()) {
+    def->accept(this);
+  }
+  return {};
+}
+std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) {
+  for (auto* def : ctx->varDef()) {
+    def->accept(this);
+  }
+  return {};
+}
+std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
+  const std::string name = ctx->ID()->getText();
+  auto ty = BaseTypeFromDecl(
+      dynamic_cast<SysYParser::ConstDeclContext*>(ctx->parent)->btype());
+  const auto dims = ctx->exp();
+  for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
+    auto* dim = EvalConstExpr(**it);
+    if (auto* ci = dynamic_cast<ir::ConstantInt*>(dim)) {
+      ty = ir::ArrayType::Get(ty, ci->GetValue());
+      continue;
+    }
+    throw std::runtime_error(FormatError("irgen", "array dimension must be an integer constant"));
+  }
+  ir::Value* slot = nullptr;
+  if (is_global_scope_) {
+    ir::ConstantValue* init = nullptr;
+    if (ctx->initValue() && ctx->initValue()->exp()) {
+      init = EvalConstExpr(*ctx->initValue()->exp());
+      if (ty->IsInt32() && init->GetType()->IsFloat()) {
+        init = module_.GetContext().GetConstInt(
+            static_cast<int>(static_cast<ir::ConstantFloat*>(init)->GetValue()));
+      } else if (ty->IsFloat() && init->GetType()->IsInt32()) {
+        init = module_.GetContext().GetConstFloat(
+            static_cast<float>(static_cast<ir::ConstantInt*>(init)->GetValue()));
       }
     }
-  }
-  return {};
-}
-IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
-    SysYParser::BlockItemContext& item) {
-  return std::any_cast<BlockFlow>(item.accept(this));
-}
-std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
-  if (!ctx) {
-    throw std::runtime_error(FormatError("irgen", "missing block item"));
-  }
-  if (ctx->decl()) {
-    ctx->decl()->accept(this);
-    return BlockFlow::Continue;
-  }
-  if (ctx->stmt()) {
-    return ctx->stmt()->accept(this);
-  }
-  throw std::runtime_error(FormatError("irgen", "unsupported statement or declaration"));
-}
-// IR generation for variable declarations is also a minimal implementation for now:
-// - first check the base type of the declaration; only local int is currently supported;
-// - then hand the variable definition in the Decl over to visitVarDef.
-//
-// Compared with a more complete version, this does not yet have:
-// - in-order handling of multiple variable definitions in one Decl;
-// - other declaration forms such as const, arrays, and global variables;
-// - a richer type system.
-std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
-  if (!ctx) {
-    throw std::runtime_error(FormatError("irgen", "missing variable declaration"));
-  }
-  if (!ctx->btype() || !ctx->btype()->INT()) {
-    throw std::runtime_error(FormatError("irgen", "only local int variable declarations are currently supported"));
-  }
-  auto* var_def = ctx->varDef();
-  if (!var_def) {
-    throw std::runtime_error(FormatError("irgen", "invalid variable declaration"));
-  }
-  var_def->accept(this);
-  return {};
-}
-// This is still a minimal teaching version, so only the following are supported:
-// - local int variables;
-// - scalar initialization;
-// - one slot per VarDef.
-std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
-  if (!ctx) {
-    throw std::runtime_error(FormatError("irgen", "missing variable definition"));
-  }
-  if (!ctx->lValue()) {
-    throw std::runtime_error(FormatError("irgen", "variable declaration is missing a name"));
-  }
-  GetLValueName(*ctx->lValue());
-  if (storage_map_.find(ctx) != storage_map_.end()) {
-    throw std::runtime_error(FormatError("irgen", "storage slot generated twice for the same declaration"));
-  }
-  auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
-  storage_map_[ctx] = slot;
-  ir::Value* init = nullptr;
-  if (auto* init_value = ctx->initValue()) {
-    if (!init_value->exp()) {
-      throw std::runtime_error(FormatError("irgen", "aggregate initialization is not currently supported"));
-    }
-    init = EvalExpr(*init_value->exp());
+    slot = module_.CreateGlobalValue(name, StorageType(ty), init);
   } else {
-    init = builder_.CreateConstInt(0);
+    slot = builder_.CreateAlloca(StorageType(ty), name);
+    ZeroInitializeLocal(slot, ty);
+    EmitLocalInitValue(slot, ty, ctx->initValue());
   }
-  builder_.CreateStore(init, slot);
+  storage_map_[ctx] = slot;
+  return {};
+}
+std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
+  const std::string name = ctx->ID()->getText();
+  auto ty = BaseTypeFromDecl(
+      dynamic_cast<SysYParser::VarDeclContext*>(ctx->parent)->btype());
+  const auto dims = ctx->exp();
+  for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
+    auto* dim = EvalConstExpr(**it);
+    if (auto* ci = dynamic_cast<ir::ConstantInt*>(dim)) {
+      ty = ir::ArrayType::Get(ty, ci->GetValue());
+      continue;
+    }
+    throw std::runtime_error(FormatError("irgen", "array dimension must be an integer constant"));
+  }
+  ir::Value* slot = nullptr;
+  if (is_global_scope_) {
+    ir::ConstantValue* init = nullptr;
+    if (ctx->initValue() && ctx->initValue()->exp()) {
+      init = EvalConstExpr(*ctx->initValue()->exp());
+      if (ty->IsInt32() && init->GetType()->IsFloat()) {
+        init = module_.GetContext().GetConstInt(
+            static_cast<int>(static_cast<ir::ConstantFloat*>(init)->GetValue()));
+      } else if (ty->IsFloat() && init->GetType()->IsInt32()) {
+        init = module_.GetContext().GetConstFloat(
+            static_cast<float>(static_cast<ir::ConstantInt*>(init)->GetValue()));
+      }
+    }
+    slot = module_.CreateGlobalValue(name, StorageType(ty), init);
+  } else {
+    slot = builder_.CreateAlloca(StorageType(ty), name);
+    ZeroInitializeLocal(slot, ty);
+    if (ctx->initValue()) EmitLocalInitValue(slot, ty, ctx->initValue());
+  }
+  storage_map_[ctx] = slot;
   return {};
 }
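
`build_flat_scalar_ptr` turns a flat scalar position inside an initializer into one GEP index per array dimension by repeatedly dividing by the number of scalars in the remaining sub-array. The same arithmetic, detached from the IR builder, looks roughly like the sketch below. `FlatToIndices` is a hypothetical helper, and the dimension list is assumed to be ordered from the outermost to the innermost dimension.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Converts a row-major flat scalar offset into one index per dimension.
// dims is ordered from the outermost to the innermost dimension.
std::vector<std::size_t> FlatToIndices(const std::vector<std::size_t>& dims,
                                       std::size_t flat) {
  std::size_t step = 1;
  for (std::size_t d : dims) step *= d;  // total number of scalars
  std::vector<std::size_t> indices;
  for (std::size_t d : dims) {
    // After this division, step is the number of scalars in one element at this level.
    step = (d == 0) ? 0 : step / d;
    indices.push_back(step == 0 ? 0 : flat / step);
    flat = (step == 0) ? 0 : flat % step;
  }
  return indices;
}
```

For `int a[2][3][4]`, flat offset 17 maps to `a[1][1][1]`, since 17 = 1·12 + 1·4 + 1. The zero checks mirror the `step == 0` guards in the lambda and only matter for degenerate zero-sized dimensions.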


@@ -6,9 +6,27 @@
 #include "ir/IR.h"
 #include "utils/Log.h"
+static void PredeclareLibraryFunctions(ir::Module& module) {
+  module.CreateFunction("getint", ir::Type::GetInt32Type(), {});
+  module.CreateFunction("getch", ir::Type::GetInt32Type(), {});
+  module.CreateFunction("getfloat", ir::Type::GetFloatType(), {});
+  module.CreateFunction("getarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrInt32Type()});
+  module.CreateFunction("getfarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrFloatType()});
+  module.CreateFunction("putint", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()});
+  module.CreateFunction("putch", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()});
+  module.CreateFunction("putfloat", ir::Type::GetVoidType(), {ir::Type::GetFloatType()});
+  module.CreateFunction("putarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type()});
+  module.CreateFunction("putfarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()});
+  module.CreateFunction("starttime", ir::Type::GetVoidType(), {});
+  module.CreateFunction("stoptime", ir::Type::GetVoidType(), {});
+  // putf is special, but for now we might not support it fully or just declare it simply
+  // module.CreateFunction("putf", ...);
+}
 std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,
                                        const SemanticContext& sema) {
   auto module = std::make_unique<ir::Module>();
+  PredeclareLibraryFunctions(*module);
   IRGenImpl gen(*module, sema);
   tree.accept(&gen);
   return module;


@@ -1,80 +1,734 @@
#include "irgen/IRGen.h"
#include <cmath>
#include <stdexcept>
#include <type_traits>
#include <vector>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
// Before this change, expression generation covered only a very small subset.
// Supported at that point:
// - integer literals
// - reads of ordinary local variables
// - parenthesized expressions
// - binary addition
//
// Not yet supported at that point:
// - subtraction, multiplication, division, and unary operators
// - assignment expressions
// - function calls
// - arrays, pointers, and subscript access
// - conditional and comparison expressions
// - ...
namespace {
bool IsZero(const ir::ConstantValue* value) {
  if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
    return ci->GetValue() == 0;
  }
  if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
    return cf->GetValue() == 0.0f;
  }
  return false;
}
bool IsTruthy(const ir::ConstantValue* value) {
  return !IsZero(value);
}
int AsInt(const ir::ConstantValue* value) {
  if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
    return ci->GetValue();
  }
  if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
    return static_cast<int>(cf->GetValue());
  }
  throw std::runtime_error(FormatError("irgen", "cannot convert constant to int"));
}
float AsFloat(const ir::ConstantValue* value) {
  if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(value)) {
    return cf->GetValue();
  }
  if (auto* ci = dynamic_cast<const ir::ConstantInt*>(value)) {
    return static_cast<float>(ci->GetValue());
  }
  throw std::runtime_error(FormatError("irgen", "cannot convert constant to float"));
}
std::shared_ptr<ir::Type> ScalarPointerType(std::shared_ptr<ir::Type> ty) {
  if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
  if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
  return ty;
}
std::shared_ptr<ir::Type> CommonArithType(ir::Value* lhs, ir::Value* rhs) {
  if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
    return ir::Type::GetFloatType();
  }
  return ir::Type::GetInt32Type();
}
} // namespace
ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
  return std::any_cast<ir::Value*>(expr.accept(this));
}
ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) {
class ConstExprVisitor final : public SysYBaseVisitor {
public:
ConstExprVisitor(ir::Module& module, const SemanticContext& sema)
: module_(module), sema_(sema) {}
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
return Eval(*ctx->exp());
}
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (ctx->number()->ILITERAL()) {
const std::string text = ctx->number()->ILITERAL()->getText();
// Base 0 selects hex (0x prefix), octal (leading 0), or decimal from the literal itself.
// Parsing as unsigned 64-bit first avoids std::out_of_range from std::stoi on literals
// such as 2147483648 (as written in -2147483648) or 0x80000000; the value then wraps to 32 bits.
const int value = static_cast<int>(std::stoull(text, nullptr, 0));
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(value));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(static_cast<float>(std::stod(ctx->number()->FLITERAL()->getText()))));
}
std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override {
auto* def = sema_.ResolveLValue(ctx->lValue());
if (!def) {
throw std::runtime_error(FormatError("irgen", "constant expression references an unbound lvalue"));
}
if (!ctx->lValue()->exp().empty()) {
throw std::runtime_error(
FormatError("irgen", "array element access in constant expressions is not supported yet"));
}
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
if (!const_def->initValue() || !const_def->initValue()->exp()) {
throw std::runtime_error(
FormatError("irgen", "constant is missing a scalar initializer expression"));
}
auto* init = Eval(*const_def->initValue()->exp());
auto* decl = dynamic_cast<SysYParser::ConstDeclContext*>(const_def->parent);
bool is_float = (decl && decl->btype() && decl->btype()->FLOAT());
if (!is_float && init->GetType()->IsFloat()) {
init = module_.GetContext().GetConstInt(
static_cast<int>(static_cast<ir::ConstantFloat*>(init)->GetValue()));
} else if (is_float && init->GetType()->IsInt32()) {
init = module_.GetContext().GetConstFloat(
static_cast<float>(static_cast<ir::ConstantInt*>(init)->GetValue()));
}
return init;
}
throw std::runtime_error(
FormatError("irgen", "global/constant expressions must be compile-time constants"));
}
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override {
return Eval(*ctx->exp());
}
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override {
auto* value = Eval(*ctx->exp());
if (value->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(-AsFloat(value)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(-AsInt(value)));
}
std::any visitNotExp(SysYParser::NotExpContext* ctx) override {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs)));
}
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs)));
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs)));
}
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float rv = AsFloat(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(
rv == 0.0f ? 0.0f : AsFloat(lhs) / rv));
}
const int rv = AsInt(rhs);
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv));
}
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
}
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT);
}
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE);
}
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT);
}
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE);
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ);
}
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE);
}
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
if (!IsTruthy(lhs)) {
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(0));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0));
}
std::any visitOrExp(SysYParser::OrExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
if (IsTruthy(lhs)) {
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(1));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0));
}
ir::ConstantValue* Eval(SysYParser::ExpContext& ctx) {
return std::any_cast<ir::ConstantValue*>(ctx.accept(this));
}
private:
ir::ConstantValue* EvalCmpImpl(SysYParser::ExpContext& lhs_ctx,
SysYParser::ExpContext& rhs_ctx,
ir::Opcode op) {
auto* lhs = Eval(lhs_ctx);
auto* rhs = Eval(rhs_ctx);
bool result = false;
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float a = AsFloat(lhs);
const float b = AsFloat(rhs);
switch (op) {
case ir::Opcode::ICmpLT: result = a < b; break;
case ir::Opcode::ICmpLE: result = a <= b; break;
case ir::Opcode::ICmpGT: result = a > b; break;
case ir::Opcode::ICmpGE: result = a >= b; break;
case ir::Opcode::ICmpEQ: result = a == b; break;
case ir::Opcode::ICmpNE: result = a != b; break;
default: break;
}
} else {
const int a = AsInt(lhs);
const int b = AsInt(rhs);
switch (op) {
case ir::Opcode::ICmpLT: result = a < b; break;
case ir::Opcode::ICmpLE: result = a <= b; break;
case ir::Opcode::ICmpGT: result = a > b; break;
case ir::Opcode::ICmpGE: result = a >= b; break;
case ir::Opcode::ICmpEQ: result = a == b; break;
case ir::Opcode::ICmpNE: result = a != b; break;
default: break;
}
}
return module_.GetContext().GetConstInt(result ? 1 : 0);
}
ir::Module& module_;
const SemanticContext& sema_;
};
ConstExprVisitor visitor(module_, sema_);
return visitor.Eval(expr);
}
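
The folding rules used by the constant evaluator above (int/float promotion, division by zero folding to 0, and short-circuit `&&`) can be illustrated in isolation. The following is a minimal standalone sketch, not the project's code: `Const`, `MakeInt`, `MakeFloat`, `FoldDiv` and `FoldAnd` are hypothetical names, and the right-hand side of `&&` is passed as a callable only to make the laziness visible.

```cpp
#include <cassert>
#include <functional>

// Hypothetical stand-in for ir::ConstantInt / ir::ConstantFloat.
struct Const {
  bool is_float;
  int i;
  float f;
};

Const MakeInt(int v) { return Const{false, v, 0.0f}; }
Const MakeFloat(float v) { return Const{true, 0, v}; }

float AsFloat(const Const& c) { return c.is_float ? c.f : static_cast<float>(c.i); }
int AsInt(const Const& c) { return c.is_float ? static_cast<int>(c.f) : c.i; }
bool IsTruthy(const Const& c) { return c.is_float ? c.f != 0.0f : c.i != 0; }

// Division: promote to float if either side is float; a zero divisor folds to 0.
Const FoldDiv(const Const& lhs, const Const& rhs) {
  if (lhs.is_float || rhs.is_float) {
    const float rv = AsFloat(rhs);
    return MakeFloat(rv == 0.0f ? 0.0f : AsFloat(lhs) / rv);
  }
  const int rv = AsInt(rhs);
  return MakeInt(rv == 0 ? 0 : AsInt(lhs) / rv);
}

// Logical and: the right operand is only evaluated when the left one is true.
Const FoldAnd(const Const& lhs, const std::function<Const()>& rhs) {
  if (!IsTruthy(lhs)) return MakeInt(0);
  return MakeInt(IsTruthy(rhs()) ? 1 : 0);
}
```

Because nothing here touches a builder, the same logic can run at global scope where no insertion point exists.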
static ir::Value* CastValue(IRGenImpl& gen, ir::IRBuilder& builder, ir::Module& module,
ir::Value* value, std::shared_ptr<ir::Type> target_ty) {
if (value->GetType() == target_ty) return value;
if (target_ty->IsFloat() && value->GetType()->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(value)) {
return module.GetContext().GetConstFloat(static_cast<float>(ci->GetValue()));
}
return builder.CreateSIToFP(value, target_ty, module.GetContext().NextTemp());
}
if (target_ty->IsInt32() && value->GetType()->IsFloat()) {
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(value)) {
return module.GetContext().GetConstInt(static_cast<int>(cf->GetValue()));
}
return builder.CreateFPToSI(value, target_ty, module.GetContext().NextTemp());
}
return value;
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "invalid parenthesized expression"));
}
return EvalExpr(*ctx->exp());
}
std::any IRGenImpl::visitLValueExp(SysYParser::LValueExpContext* ctx) {
auto* def = sema_.ResolveLValue(ctx->lValue());
if (def && IsArrayLikeDef(def) && ctx->lValue()->exp().size() < GetArrayRank(def)) {
return DecayArrayPtr(ctx->lValue());
}
ir::Value* ptr = GetLValuePtr(ctx->lValue());
return static_cast<ir::Value*>(
builder_.CreateLoad(ptr, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
}
std::any IRGenImpl::visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) {
ir::Function* target_func = nullptr;
if (auto* def = sema_.ResolveFuncCall(ctx)) {
const std::string func_name = def->ID()->getText();
for (const auto& f : module_.GetFunctions()) {
if (f->GetName() == func_name) {
target_func = f.get();
break;
}
}
} else {
const std::string func_name = ctx->ID()->getText();
for (const auto& f : module_.GetFunctions()) {
if (f->GetName() == func_name) {
target_func = f.get();
break;
}
}
}
if (!target_func) {
throw std::runtime_error(FormatError("irgen", "function not found: " + ctx->ID()->getText()));
}
std::vector<ir::Value*> args;
const auto& arg_types = target_func->GetArguments();
if (ctx->funcRParams()) {
const auto& exps = ctx->funcRParams()->exp();
for (size_t i = 0; i < exps.size(); ++i) {
ir::Value* arg = EvalExpr(*exps[i]);
if (i < arg_types.size()) {
arg = CastValue(*this, builder_, module_, arg, arg_types[i]->GetType());
}
args.push_back(arg);
}
}
return static_cast<ir::Value*>(
builder_.CreateCall(target_func, args,
target_func->GetType()->IsVoid()
? ""
: module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitNotExp(SysYParser::NotExpContext* ctx) {
ir::Value* val = EvalExpr(*ctx->exp());
if (auto* cv = dynamic_cast<ir::ConstantValue*>(val)) {
return static_cast<ir::Value*>(
module_.GetContext().GetConstInt(IsTruthy(cv) ? 0 : 1));
}
ir::Value* cond = ToI1(val);
ir::Value* res =
builder_.CreateICmp(ir::Opcode::ICmpEQ, cond, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
return ToI32(res);
}
std::any IRGenImpl::visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) {
return EvalExpr(*ctx->exp());
}
std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) {
ir::Value* val = EvalExpr(*ctx->exp());
if (auto* ci = dynamic_cast<ir::ConstantInt*>(val)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(-ci->GetValue()));
}
if (auto* cf = dynamic_cast<ir::ConstantFloat*>(val)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(-cf->GetValue()));
}
if (val->GetType()->IsFloat()) {
return static_cast<ir::Value*>(builder_.CreateFSub(
builder_.CreateConstFloat(0.0f), val, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateSub(
builder_.CreateConstInt(0), val, module_.GetContext().NextTemp()));
}
#define DEFINE_ARITH_VISITOR(name, int_opcode, float_opcode) \
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
const auto common_ty = CommonArithType(lhs, rhs); \
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
if (common_ty->IsFloat()) { \
const float lv = AsFloat(lconst); \
const float rv = AsFloat(rconst); \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv + rv)); \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv - rv)); \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv * rv)); \
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv)); \
} \
const int lv = AsInt(lconst); \
const int rv = AsInt(rconst); \
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Add) \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv + rv)); \
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Sub) \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv - rv)); \
if constexpr (ir::Opcode::int_opcode == ir::Opcode::Mul) \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv * rv)); \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv)); \
} \
} \
if (common_ty->IsFloat()) { \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \
return static_cast<ir::Value*>(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp())); \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \
return static_cast<ir::Value*>(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp())); \
if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \
return static_cast<ir::Value*>(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp())); \
return static_cast<ir::Value*>(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp())); \
} \
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
}
DEFINE_ARITH_VISITOR(Add, Add, FAdd)
DEFINE_ARITH_VISITOR(Sub, Sub, FSub)
DEFINE_ARITH_VISITOR(Mul, Mul, FMul)
DEFINE_ARITH_VISITOR(Div, Div, FDiv)
std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) {
ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)),
ir::Type::GetInt32Type());
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
}
}
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
const auto common_ty = CommonArithType(lhs, rhs); \
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \
: (AsInt(lconst) cmp_op AsInt(rconst)); \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0)); \
} \
} \
if (common_ty->IsFloat()) { \
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
} \
return static_cast<ir::Value*>(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
}
DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <)
DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=)
DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >)
DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=)
DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==)
DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=)
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
if (!builder_.GetInsertBlock()) {
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
if (auto* c = dynamic_cast<ir::ConstantValue*>(lhs); c && !IsTruthy(c)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(0));
}
const std::string suffix = module_.GetContext().NextTemp();
ir::BasicBlock* rhs_bb = func_->CreateBlock("and.rhs." + suffix);
ir::BasicBlock* merge_bb = func_->CreateBlock("and.merge." + suffix);
auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
// Convert lhs to i1 once so the comparison is not emitted twice.
ir::Value* lhs_i1 = ToI1(lhs);
builder_.CreateStore(ToI32(lhs_i1), res_ptr);
builder_.CreateCondBr(lhs_i1, rhs_bb, merge_bb);
builder_.SetInsertPoint(rhs_bb);
ir::Value* rhs = EvalExpr(*ctx->exp(1));
builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr);
builder_.CreateBr(merge_bb);
builder_.SetInsertPoint(merge_bb);
return static_cast<ir::Value*>(
builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitOrExp(SysYParser::OrExpContext* ctx) {
if (!builder_.GetInsertBlock()) {
return static_cast<ir::Value*>(EvalConstExpr(*ctx));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
if (auto* c = dynamic_cast<ir::ConstantValue*>(lhs); c && IsTruthy(c)) {
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(1));
}
const std::string suffix = module_.GetContext().NextTemp();
ir::BasicBlock* rhs_bb = func_->CreateBlock("or.rhs." + suffix);
ir::BasicBlock* merge_bb = func_->CreateBlock("or.merge." + suffix);
auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
// Convert lhs to i1 once so the comparison is not emitted twice.
ir::Value* lhs_i1 = ToI1(lhs);
builder_.CreateStore(ToI32(lhs_i1), res_ptr);
builder_.CreateCondBr(lhs_i1, merge_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb);
ir::Value* rhs = EvalExpr(*ctx->exp(1));
builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr);
builder_.CreateBr(merge_bb);
builder_.SetInsertPoint(merge_bb);
return static_cast<ir::Value*>(
builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp()));
}
bool IRGenImpl::IsArrayLikeDef(antlr4::ParserRuleContext* def) const {
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
return !const_def->exp().empty();
}
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
return !var_def->exp().empty();
}
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
return !param->LBRACK().empty();
}
return false;
}
size_t IRGenImpl::GetArrayRank(antlr4::ParserRuleContext* def) const {
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
return const_def->exp().size();
}
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
return var_def->exp().size();
}
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
return param->LBRACK().size();
}
return 0;
}
std::shared_ptr<ir::Type> IRGenImpl::GetDefType(antlr4::ParserRuleContext* def) const {
std::shared_ptr<ir::Type> ty = ir::Type::GetInt32Type();
auto apply_dims = [&](const auto& dims) {
auto result = ty;
for (auto it = dims.rbegin(); it != dims.rend(); ++it) {
if constexpr (std::is_pointer_v<typename std::decay_t<decltype(*it)>>) {
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(**it);
result = ir::ArrayType::Get(result, AsInt(dim_val));
} else {
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(*it);
result = ir::ArrayType::Get(result, AsInt(dim_val));
}
}
return result;
};
if (auto* const_def = dynamic_cast<SysYParser::ConstDefContext*>(def)) {
auto* decl = dynamic_cast<SysYParser::ConstDeclContext*>(const_def->parent);
ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType()
: ir::Type::GetInt32Type();
return apply_dims(const_def->exp());
}
if (auto* var_def = dynamic_cast<SysYParser::VarDefContext*>(def)) {
auto* decl = dynamic_cast<SysYParser::VarDeclContext*>(var_def->parent);
ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType()
: ir::Type::GetInt32Type();
return apply_dims(var_def->exp());
}
if (auto* param = dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
ty = param->btype()->FLOAT() ? ir::Type::GetFloatType() : ir::Type::GetInt32Type();
if (param->LBRACK().empty()) return ty;
for (int i = static_cast<int>(param->exp().size()) - 1; i >= 0; --i) {
auto* dim_val = const_cast<IRGenImpl*>(this)->EvalConstExpr(*param->exp(i));
ty = ir::ArrayType::Get(ty, AsInt(dim_val));
}
return ty;
}
return ty;
}
ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
auto* def = sema_.ResolveLValue(ctx);
if (!def) {
throw std::runtime_error(FormatError("irgen", "array decay failed: no bound definition"));
}
ir::Value* base_ptr = storage_map_.at(def);
const auto base_ty = GetDefType(def);
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
if (ctx->exp().empty()) return base_ptr;
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
auto cur_ty = base_ty;
for (size_t i = 1; i < ctx->exp().size(); ++i) {
if (!cur_ty->IsArray()) break;
auto arr_ty = cur_ty->GetAsArrayType();
ir::Value* stride = builder_.CreateConstInt(static_cast<int>(arr_ty->GetNumElements()));
offset = builder_.CreateMul(offset, stride, module_.GetContext().NextTemp());
offset = builder_.CreateAdd(
offset,
CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(i)),
ir::Type::GetInt32Type()),
module_.GetContext().NextTemp());
cur_ty = arr_ty->GetElementType();
}
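// Scale by any dimensions that were not subscripted, so a partial index such as
// a[i] on a[][3][4] still yields an element offset (i * 3 * 4).
while (cur_ty->IsArray()) {
auto arr_ty = cur_ty->GetAsArrayType();
offset = builder_.CreateMul(
offset, builder_.CreateConstInt(static_cast<int>(arr_ty->GetNumElements())),
module_.GetContext().NextTemp());
cur_ty = arr_ty->GetElementType();
}
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset},
module_.GetContext().NextTemp());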
}
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
indices.push_back(
CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type()));
}
auto cur_ty = base_ty;
while (cur_ty->IsArray()) {
cur_ty = cur_ty->GetAsArrayType()->GetElementType();
}
if (!ctx->exp().empty()) {
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
module_.GetContext().NextTemp());
}
indices.push_back(builder_.CreateConstInt(0));
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
module_.GetContext().NextTemp());
}
ir::Value* IRGenImpl::GetLValuePtr(SysYParser::LValueContext* ctx) {
auto* def = sema_.ResolveLValue(ctx);
if (!def) {
throw std::runtime_error(FormatError("irgen", "undefined lvalue: " + ctx->ID()->getText()));
}
auto it = storage_map_.find(def);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen", "lvalue has no storage slot: " + ctx->ID()->getText()));
}
ir::Value* base_ptr = it->second;
if (ctx->exp().empty()) return base_ptr;
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
return DecayArrayPtr(ctx);
}
const auto base_ty = GetDefType(def);
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* exp : ctx->exp()) {
indices.push_back(
CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type()));
}
auto cur_ty = base_ty;
for (size_t i = 0; i < ctx->exp().size(); ++i) {
if (!cur_ty->IsArray()) {
throw std::runtime_error(FormatError("irgen", "array subscript depth exceeds the declared rank"));
}
cur_ty = cur_ty->GetAsArrayType()->GetElementType();
}
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices,
module_.GetContext().NextTemp());
}
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (auto* cv = dynamic_cast<ir::ConstantValue*>(v)) {
return module_.GetContext().GetConstInt(IsTruthy(cv) ? 1 : 0);
}
if (auto* inst = dynamic_cast<ir::Instruction*>(v)) {
switch (inst->GetOpcode()) {
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE:
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE:
return v;
default:
break;
}
}
if (v->GetType()->IsFloat()) {
return builder_.CreateFCmp(ir::Opcode::FCmpNE, v, builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
if (v->GetType()->IsInt32()) {
return builder_.CreateICmp(ir::Opcode::ICmpNE, v, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
return v;
}
ir::Value* IRGenImpl::ToI32(ir::Value* v) {
if (v->GetType()->IsInt32()) return v;
if (v->GetType()->IsFloat()) {
return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
return builder_.CreateZExt(v, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
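
The subscript arithmetic that `DecayArrayPtr` performs for array parameters is a Horner-style row-major flattening. The standalone sketch below is only an illustration under stated assumptions: `FlattenOffset` is a hypothetical helper, `inner_dims` stands for the known dimensions after the first unsized `[]`, and plain integers replace IR values. It also scales a partial subscript by the dimensions that were not indexed, which is what a decayed sub-array pointer needs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: element offset for subscripts on `T a[][d1][d2]...`.
// `inner_dims` = {d1, d2, ...}; `indices` may be shorter than the full rank.
long FlattenOffset(const std::vector<long>& inner_dims,
                   const std::vector<long>& indices) {
  // At most one index per dimension (the unsized first one plus inner_dims).
  assert(indices.size() <= inner_dims.size() + 1);
  if (indices.empty()) return 0;

  long offset = indices[0];
  std::size_t used = 0;  // inner dimensions consumed so far
  for (std::size_t i = 1; i < indices.size(); ++i) {
    offset = offset * inner_dims[used] + indices[i];
    ++used;
  }
  // Partial subscript: scale by every dimension that was not indexed.
  for (; used < inner_dims.size(); ++used) {
    offset *= inner_dims[used];
  }
  return offset;
}
```

For `int a[][3][4]`, `a[1][2][3]` lands on element 1*12 + 2*4 + 3 = 23, while `a[1]` decays to a pointer 12 elements past the base.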


@@ -8,10 +8,18 @@
namespace {
std::shared_ptr<ir::Type> StorageType(std::shared_ptr<ir::Type> ty) {
if (ty->IsInt32()) return ir::Type::GetPtrInt32Type();
if (ty->IsFloat()) return ir::Type::GetPtrFloatType();
return ty;
}
void VerifyFunctionStructure(const ir::Function& func) {
for (const auto& bb : func.GetBlocks()) {
if (!bb || !bb->HasTerminator()) {
// If a block doesn't have a terminator, it might be an empty function or
// missing a return in a path. For SysY, we should at least have a default return for void.
// But IRGen should have handled it.
throw std::runtime_error(
FormatError("irgen", "basic block is not properly terminated: " +
(bb ? bb->GetName() : std::string("<null>"))));
@@ -25,63 +33,83 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
: module_(module),
sema_(sema),
func_(nullptr),
builder_(module.GetContext(), nullptr),
is_global_scope_(true) {}
std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) return {};
is_global_scope_ = true;
for (auto* decl : ctx->decl()) {
decl->accept(this);
}
for (auto* funcDef : ctx->funcDef()) {
funcDef->accept(this);
}
return {};
}
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
is_global_scope_ = false;
std::shared_ptr<ir::Type> ret_ty;
if (ctx->funcType()->INT()) {
ret_ty = ir::Type::GetInt32Type();
} else if (ctx->funcType()->FLOAT()) {
ret_ty = ir::Type::GetFloatType();
} else {
ret_ty = ir::Type::GetVoidType();
}
std::string func_name = ctx->ID()->getText();
std::vector<std::shared_ptr<ir::Type>> param_types;
if (ctx->funcFParams()) {
for (auto* fparam : ctx->funcFParams()->funcFParam()) {
if (fparam->LBRACK().empty()) {
param_types.push_back(fparam->btype()->INT() ? ir::Type::GetInt32Type() : ir::Type::GetFloatType());
} else {
// Array param is a pointer
param_types.push_back(fparam->btype()->INT() ? ir::Type::GetPtrInt32Type() : ir::Type::GetPtrFloatType());
}
}
}
func_ = module_.CreateFunction(func_name, ret_ty, param_types);
builder_.SetInsertPoint(func_->CreateBlock("entry"));
// Handle parameters: alloca and store
if (ctx->funcFParams()) {
const auto& fparams = ctx->funcFParams()->funcFParam();
const auto& args = func_->GetArguments();
for (size_t i = 0; i < fparams.size(); ++i) {
auto* fparam = fparams[i];
auto* arg = args[i].get();
auto* slot = builder_.CreateAlloca(StorageType(arg->GetType()), fparam->ID()->getText());
builder_.CreateStore(arg, slot);
storage_map_[fparam] = slot;
}
}
ctx->blockStmt()->accept(this);
// Emit a default return if the current block is still open:
// `ret void` for void functions, 0 / 0.0f for int / float functions.
if (!builder_.GetInsertBlock()->HasTerminator()) {
if (ret_ty->IsVoid()) {
builder_.CreateRet(nullptr);
} else if (ret_ty->IsInt32()) {
builder_.CreateRet(builder_.CreateConstInt(0));
} else if (ret_ty->IsFloat()) {
builder_.CreateRet(builder_.CreateConstFloat(0.0f));
}
}
VerifyFunctionStructure(*func_);
is_global_scope_ = true;
return {};
}
std::any IRGenImpl::visitFuncFParam(SysYParser::FuncFParamContext* ctx) {
// We handle fparams in visitFuncDef directly.
return {};
}


@@ -6,34 +6,146 @@
#include "ir/IR.h"
#include "utils/Log.h"
// Statement generation
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) return BlockFlow::Continue;
if (ctx->assignStmt()) return ctx->assignStmt()->accept(this);
if (ctx->returnStmt()) return ctx->returnStmt()->accept(this);
if (ctx->blockStmt()) return ctx->blockStmt()->accept(this);
if (ctx->ifStmt()) return ctx->ifStmt()->accept(this);
if (ctx->whileStmt()) return ctx->whileStmt()->accept(this);
if (ctx->breakStmt()) return ctx->breakStmt()->accept(this);
if (ctx->continueStmt()) return ctx->continueStmt()->accept(this);
if (ctx->expStmt()) return ctx->expStmt()->accept(this);
return BlockFlow::Continue;
}
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
for (auto* item : ctx->blockItem()) {
if (std::any_cast<BlockFlow>(item->accept(this)) == BlockFlow::Terminated) {
return BlockFlow::Terminated;
}
}
return BlockFlow::Continue;
}
std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
if (ctx->decl()) {
ctx->decl()->accept(this);
return BlockFlow::Continue;
}
if (ctx->stmt()) {
return ctx->stmt()->accept(this);
}
return BlockFlow::Continue;
}
std::any IRGenImpl::visitAssignStmt(SysYParser::AssignStmtContext* ctx) {
ir::Value* ptr = GetLValuePtr(ctx->lValue());
ir::Value* val = EvalExpr(*ctx->exp());
if (ptr->GetType()->IsPtrFloat() && val->GetType()->IsInt32()) {
val = builder_.CreateSIToFP(val, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (ptr->GetType()->IsPtrInt32() && val->GetType()->IsFloat()) {
val = builder_.CreateFPToSI(val, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
builder_.CreateStore(val, ptr);
return BlockFlow::Continue;
}
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
if (func_->GetType()->IsFloat() && v->GetType()->IsInt32()) {
v = builder_.CreateSIToFP(v, ir::Type::GetFloatType(),
module_.GetContext().NextTemp());
} else if (func_->GetType()->IsInt32() && v->GetType()->IsFloat()) {
v = builder_.CreateFPToSI(v, ir::Type::GetInt32Type(),
module_.GetContext().NextTemp());
}
builder_.CreateRet(v);
} else {
builder_.CreateRet(nullptr);
}
return BlockFlow::Terminated;
}
std::any IRGenImpl::visitIfStmt(SysYParser::IfStmtContext* ctx) {
const std::string suffix = module_.GetContext().NextTemp();
ir::BasicBlock* true_bb = func_->CreateBlock("if.true." + suffix);
ir::BasicBlock* false_bb =
ctx->ELSE() ? func_->CreateBlock("if.false." + suffix) : nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge." + suffix);
ir::Value* cond = EvalExpr(*ctx->exp());
builder_.CreateCondBr(ToI1(cond), true_bb, false_bb ? false_bb : merge_bb);
// True block
builder_.SetInsertPoint(true_bb);
if (std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this)) == BlockFlow::Continue) {
builder_.CreateBr(merge_bb);
}
// False block
if (false_bb) {
builder_.SetInsertPoint(false_bb);
if (std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this)) == BlockFlow::Continue) {
builder_.CreateBr(merge_bb);
}
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
std::any IRGenImpl::visitWhileStmt(SysYParser::WhileStmtContext* ctx) {
const std::string suffix = module_.GetContext().NextTemp();
ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond." + suffix);
ir::BasicBlock* body_bb = func_->CreateBlock("while.body." + suffix);
ir::BasicBlock* end_bb = func_->CreateBlock("while.end." + suffix);
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
ir::Value* cond = EvalExpr(*ctx->exp());
builder_.CreateCondBr(ToI1(cond), body_bb, end_bb);
break_stack_.push(end_bb);
continue_stack_.push(cond_bb);
builder_.SetInsertPoint(body_bb);
if (std::any_cast<BlockFlow>(ctx->stmt()->accept(this)) == BlockFlow::Continue) {
builder_.CreateBr(cond_bb);
}
break_stack_.pop();
continue_stack_.pop();
builder_.SetInsertPoint(end_bb);
return BlockFlow::Continue;
}
std::any IRGenImpl::visitBreakStmt(SysYParser::BreakStmtContext* ctx) {
if (break_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "break statement is not inside a loop"));
}
builder_.CreateBr(break_stack_.top());
return BlockFlow::Terminated;
}
std::any IRGenImpl::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) {
if (continue_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "continue statement is not inside a loop"));
}
builder_.CreateBr(continue_stack_.top());
return BlockFlow::Terminated;
}
std::any IRGenImpl::visitExpStmt(SysYParser::ExpStmtContext* ctx) {
if (ctx->exp()) {
EvalExpr(*ctx->exp());
}
return BlockFlow::Continue;
}
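
The `break_stack_` / `continue_stack_` bookkeeping used by `visitWhileStmt`, `visitBreakStmt` and `visitContinueStmt` follows a common pattern: push the loop's exit and condition blocks on entry, pop them on exit, and branch to the top of the stack. The sketch below models this with label strings instead of basic blocks; `LoopTargets` and its member names are hypothetical.

```cpp
#include <cassert>
#include <stack>
#include <stdexcept>
#include <string>

// Toy model of the break/continue bookkeeping; all names are illustrative.
class LoopTargets {
 public:
  // Called when code generation enters a while loop.
  void Enter(const std::string& cond_label, const std::string& end_label) {
    continue_.push(cond_label);
    break_.push(end_label);
  }

  // Called after the loop body has been generated.
  void Leave() {
    assert(!break_.empty() && !continue_.empty());
    continue_.pop();
    break_.pop();
  }

  // Branch target for `break`: the end block of the innermost loop.
  const std::string& BreakTarget() const {
    if (break_.empty()) throw std::runtime_error("break outside of a loop");
    return break_.top();
  }

  // Branch target for `continue`: the condition block of the innermost loop.
  const std::string& ContinueTarget() const {
    if (continue_.empty()) throw std::runtime_error("continue outside of a loop");
    return continue_.top();
  }

 private:
  std::stack<std::string> break_;
  std::stack<std::string> continue_;
};
```

Using a stack rather than a single pair of pointers is what lets nested loops resolve `break` and `continue` to the innermost enclosing loop.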


@@ -6,6 +6,7 @@
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "ir/PassManager.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
@@ -36,6 +37,7 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit); auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema); auto module = GenerateIR(*comp_unit, sema);
ir::RunOptimizationPasses(*module);
if (opts.emit_ir) { if (opts.emit_ir) {
ir::IRPrinter printer; ir::IRPrinter printer;
if (need_blank_line) { if (need_blank_line) {
@@ -46,13 +48,18 @@ int main(int argc, char** argv) {
} }
if (opts.emit_asm) { if (opts.emit_asm) {
auto machine_func = mir::LowerToMIR(*module); mir::PrintGlobals(*module, std::cout);
mir::RunRegAlloc(*machine_func); auto machine_funcs = mir::LowerToMIR(*module);
mir::RunFrameLowering(*machine_func); for (auto& machine_func : machine_funcs) {
if (need_blank_line) { mir::RunRegAlloc(*machine_func);
std::cout << "\n"; mir::RunFrameLowering(*machine_func);
mir::RunPeephole(*machine_func);
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_func, std::cout);
need_blank_line = true;
} }
mir::PrintAsm(*machine_func, std::cout);
} }
#else #else
if (opts.emit_ir || opts.emit_asm) { if (opts.emit_ir || opts.emit_asm) {


@@ -1,7 +1,11 @@
#include "mir/MIR.h" #include "mir/MIR.h"
#include "ir/IR.h"
#include <ostream> #include <ostream>
#include <stdexcept> #include <stdexcept>
#include <cstdint>
#include <vector>
#include <cstring>
#include "utils/Log.h" #include "utils/Log.h"
@@ -16,10 +20,34 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function,
return function.GetFrameSlot(operand.GetFrameIndex()); return function.GetFrameSlot(operand.GetFrameIndex());
} }
bool IsFloatReg(PhysReg reg) {
return reg >= PhysReg::S0 && reg <= PhysReg::S15;
}
void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg,
int offset) { int offset) {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset bool is_float = IsFloatReg(reg);
<< "]\n"; const char* ldr_cmd = is_float ? "ldr" : "ldr";
const char* str_cmd = is_float ? "str" : "str";
const char* base_mnemonic = (std::strcmp(mnemonic, "ldur") == 0) ? ldr_cmd : str_cmd;
if (offset >= -256 && offset <= 255) {
if (is_float) {
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
} else {
os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n";
}
} else {
os << " mov x10, #" << offset << "\n";
os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n";
}
}
std::string GetBlockLabel(const std::string& func_name, const std::string& block_name) {
if (block_name == "entry" || block_name.empty()) {
return func_name;
}
return ".L_" + func_name + "_" + block_name;
} }
} // namespace } // namespace
@@ -28,51 +56,269 @@ void PrintAsm(const MachineFunction& function, std::ostream& os) {
os << ".text\n"; os << ".text\n";
os << ".global " << function.GetName() << "\n"; os << ".global " << function.GetName() << "\n";
os << ".type " << function.GetName() << ", %function\n"; os << ".type " << function.GetName() << ", %function\n";
os << function.GetName() << ":\n";
for (const auto& inst : function.GetEntry().GetInstructions()) { struct FloatConstant {
const auto& ops = inst.GetOperands(); std::string label;
switch (inst.GetOpcode()) { float value;
case Opcode::Prologue: };
os << " stp x29, x30, [sp, #-16]!\n"; std::vector<FloatConstant> float_constants;
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) { for (size_t b = 0; b < function.GetBlocks().size(); ++b) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; const auto& block = function.GetBlocks()[b];
// Print the block label
if (b == 0) {
os << function.GetName() << ":\n";
} else {
os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n";
}
for (const auto& inst : block.GetInstructions()) {
const auto& ops = inst.GetOperands();
switch (inst.GetOpcode()) {
case Opcode::Prologue:
os << " stp x29, x30, [sp, #-16]!\n";
os << " mov x29, sp\n";
if (function.GetFrameSize() > 0) {
os << " sub sp, sp, #" << function.GetFrameSize() << "\n";
}
break;
case Opcode::Epilogue:
if (function.GetFrameSize() > 0) {
os << " add sp, sp, #" << function.GetFrameSize() << "\n";
}
os << " ldp x29, x30, [sp], #16\n";
break;
case Opcode::MovImm: {
PhysReg dst = ops.at(0).GetReg();
if (IsFloatReg(dst)) {
// Load the float constant from a per-function literal pool entry (printed in .rodata after the function)
int bits = ops.at(1).GetImm();
float val;
std::memcpy(&val, &bits, sizeof(float));
std::string flabel = ".LC_" + function.GetName() + "_" + std::to_string(float_constants.size());
float_constants.push_back({flabel, val});
os << " adrp x8, " << flabel << "\n";
os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel << "]\n";
} else {
os << " mov " << PhysRegName(dst) << ", #" << ops.at(1).GetImm() << "\n";
}
break;
} }
break; case Opcode::LoadStack: {
case Opcode::Epilogue: const auto& slot = GetFrameSlot(function, ops.at(1));
if (function.GetFrameSize() > 0) { PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset);
os << " add sp, sp, #" << function.GetFrameSize() << "\n"; break;
} }
os << " ldp x29, x30, [sp], #16\n"; case Opcode::StoreStack: {
break; const auto& slot = GetFrameSlot(function, ops.at(1));
case Opcode::MovImm: PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" break;
<< ops.at(1).GetImm() << "\n"; }
break; case Opcode::AddRR:
case Opcode::LoadStack: { os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
const auto& slot = GetFrameSlot(function, ops.at(1)); << PhysRegName(ops.at(1).GetReg()) << ", "
PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); << PhysRegName(ops.at(2).GetReg()) << "\n";
break; break;
case Opcode::SubRR:
os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MulRR:
os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::SDivRR:
os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::MSubRRRR:
os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << ", "
<< PhysRegName(ops.at(3).GetReg()) << "\n";
break;
case Opcode::FAddRRR:
os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FSubRRR:
os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FMulRRR:
os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::FDivRRR:
os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::CmpRR:
os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FCmpRR:
os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::Cset:
os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetCondCode() << "\n";
break;
case Opcode::B:
os << " b " << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName()) << "\n";
break;
case Opcode::BCond:
os << " b." << ops.at(0).GetCondCode() << " "
<< GetBlockLabel(function.GetName(), ops.at(1).GetLabelName()) << "\n";
break;
case Opcode::Call:
os << " bl " << ops.at(0).GetGlobalName() << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
case Opcode::MovReg:
if (IsFloatReg(ops.at(0).GetReg()) || IsFloatReg(ops.at(1).GetReg())) {
os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
} else {
os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
}
break;
case Opcode::Adrp:
os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", "
<< ops.at(1).GetGlobalName() << "\n";
break;
case Opcode::AddRegImm: {
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", ";
if (ops.at(2).GetKind() == Operand::Kind::FrameIndex) {
const auto& slot = function.GetFrameSlot(ops.at(2).GetFrameIndex());
os << "#" << slot.offset << "\n";
} else if (ops.at(2).GetKind() == Operand::Kind::Global) {
os << ":lo12:" << ops.at(2).GetGlobalName() << "\n";
} else {
os << "#" << ops.at(2).GetImm() << "\n";
}
break;
}
case Opcode::LdrRegReg: {
PhysReg reg = ops.at(0).GetReg();
const char* ldr_cmd = IsFloatReg(reg) ? "ldr" : "ldr";
os << " " << ldr_cmd << " " << PhysRegName(reg) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::StrRegReg: {
PhysReg reg = ops.at(0).GetReg();
const char* str_cmd = IsFloatReg(reg) ? "str" : "str";
os << " " << str_cmd << " " << PhysRegName(reg) << ", ["
<< PhysRegName(ops.at(1).GetReg()) << "]\n";
break;
}
case Opcode::SIToFP:
os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::FPToSI:
os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << "\n";
break;
case Opcode::ZExt:
if (ops.at(0).GetReg() >= PhysReg::X0 && ops.at(0).GetReg() <= PhysReg::X28) {
os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n";
} else {
os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n";
}
break;
} }
case Opcode::StoreStack: {
const auto& slot = GetFrameSlot(function, ops.at(1));
PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset);
break;
}
case Opcode::AddRR:
os << " add " << PhysRegName(ops.at(0).GetReg()) << ", "
<< PhysRegName(ops.at(1).GetReg()) << ", "
<< PhysRegName(ops.at(2).GetReg()) << "\n";
break;
case Opcode::Ret:
os << " ret\n";
break;
} }
} }
os << ".size " << function.GetName() << ", .-" << function.GetName() os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n";
<< "\n";
// Print read-only data segment if there are float constants
if (!float_constants.empty()) {
os << ".section .rodata\n";
os << ".align 2\n";
for (const auto& fc : float_constants) {
os << fc.label << ":\n";
uint32_t bits;
std::memcpy(&bits, &fc.value, sizeof(float));
os << " .word " << bits << " // float " << fc.value << "\n";
}
}
}
static uint32_t GetTypeSize(const ir::Type* type) {
if (type->IsInt32() || type->IsFloat()) {
return 4;
}
if (type->IsPtrInt32() || type->IsPtrFloat()) {
return 8;
}
if (type->IsArray()) {
auto* arr_ty = const_cast<ir::Type*>(type)->GetAsArrayType().get();
return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get());
}
return 4;
}
void PrintGlobals(const ir::Module& module, std::ostream& os) {
for (const auto& gv : module.GetGlobalValues()) {
os << ".global " << gv->GetName() << "\n";
std::shared_ptr<ir::Type> actual_ty = gv->GetType();
if (actual_ty->IsPtrInt32()) actual_ty = ir::Type::GetInt32Type();
else if (actual_ty->IsPtrFloat()) actual_ty = ir::Type::GetFloatType();
uint32_t actual_size = GetTypeSize(actual_ty.get());
if (gv->GetInitializer()) {
os << ".data\n";
os << ".align 2\n";
os << ".size " << gv->GetName() << ", " << actual_size << "\n";
os << gv->GetName() << ":\n";
if (actual_ty->IsFloat()) {
float val = 0.0f;
if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
val = cf->GetValue();
} else if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
val = static_cast<float>(ci->GetValue());
}
uint32_t bits;
std::memcpy(&bits, &val, sizeof(float));
os << " .word " << bits << " // float " << val << "\n";
} else {
int val = 0;
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(gv->GetInitializer())) {
val = ci->GetValue();
} else if (auto* cf = dynamic_cast<const ir::ConstantFloat*>(gv->GetInitializer())) {
val = static_cast<int>(cf->GetValue());
}
os << " .word " << val << "\n";
}
} else {
os << ".bss\n";
os << ".align 4\n";
os << ".size " << gv->GetName() << ", " << actual_size << "\n";
os << gv->GetName() << ":\n";
os << " .zero " << actual_size << "\n";
}
os << "\n";
}
} }
} // namespace mir } // namespace mir
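Both the literal pool in `PrintAsm` and the initializers in `PrintGlobals` write a float as the 32-bit integer that shares its storage. A minimal sketch of that conversion follows, assuming a 32-bit IEEE 754 `float`; `FloatBits` and `FormatFloatEntry` are hypothetical helper names, not functions from this change.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <sstream>
#include <string>

// Reinterpret a float's storage as a 32-bit integer. memcpy avoids the
// undefined behaviour of reading the float through an incompatible pointer.
inline std::uint32_t FloatBits(float value) {
  static_assert(sizeof(float) == sizeof(std::uint32_t), "expects 32-bit float");
  std::uint32_t bits = 0;
  std::memcpy(&bits, &value, sizeof(bits));
  return bits;
}

// Hypothetical helper: format one literal-pool entry as assembler text.
inline std::string FormatFloatEntry(const std::string& label, float value) {
  assert(!label.empty());
  std::ostringstream os;
  os << label << ":\n"
     << "  .word " << FloatBits(value) << "\n";
  return os.str();
}
```

Under the IEEE 754 assumption, `FloatBits(1.0f)` is `0x3F800000`, so the entry for `1.0f` would carry `.word 1065353216`.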


@@ -18,10 +18,10 @@ void RunFrameLowering(MachineFunction& function) {
int cursor = 0; int cursor = 0;
for (const auto& slot : function.GetFrameSlots()) { for (const auto& slot : function.GetFrameSlots()) {
cursor += slot.size; cursor += slot.size;
if (-cursor < -256) {
throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧"));
}
} }
// Align stack frames to 16 bytes for AArch64
cursor = AlignTo(cursor, 16);
cursor = 0; cursor = 0;
for (const auto& slot : function.GetFrameSlots()) { for (const auto& slot : function.GetFrameSlots()) {
@@ -30,16 +30,25 @@ void RunFrameLowering(MachineFunction& function) {
} }
function.SetFrameSize(AlignTo(cursor, 16)); function.SetFrameSize(AlignTo(cursor, 16));
auto& insts = function.GetEntry().GetInstructions(); auto& blocks = function.GetBlocks();
std::vector<MachineInstr> lowered; if (blocks.empty()) return;
lowered.emplace_back(Opcode::Prologue);
for (const auto& inst : insts) { // Insert Prologue at the start of the first block
if (inst.GetOpcode() == Opcode::Ret) { auto& entry_insts = blocks.front().GetInstructions();
lowered.emplace_back(Opcode::Epilogue); entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue));
// Insert Epilogue before every Ret in all blocks
for (auto& block : blocks) {
auto& insts = block.GetInstructions();
std::vector<MachineInstr> lowered;
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::Ret) {
lowered.emplace_back(Opcode::Epilogue);
}
lowered.push_back(inst);
} }
lowered.push_back(inst); insts = std::move(lowered);
} }
insts = std::move(lowered);
} }
} // namespace mir } // namespace mir
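The frame-lowering change above drops the 256-byte limit and rounds the frame to 16 bytes, which AArch64 requires of `sp` whenever it is used as a base address. A minimal sketch of that layout step follows. It assumes slots grow downward from the frame pointer and receive negative offsets; the part of `RunFrameLowering` that assigns offsets is not visible in this hunk, so `Slot`, `AlignUp`, and `AssignOffsets` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <vector>

// Round value up to the next multiple of align (align must be positive).
inline int AlignUp(int value, int align) {
  assert(align > 0);
  return (value + align - 1) / align * align;
}

struct Slot {
  int size = 0;    // bytes requested for this slot
  int offset = 0;  // filled in by AssignOffsets, relative to the frame pointer
};

// Assumed layout: slots grow downward from the frame pointer, so each slot
// gets a negative offset. No per-slot alignment is applied; only the total
// frame size is rounded up to 16 bytes.
inline int AssignOffsets(std::vector<Slot>& slots) {
  int cursor = 0;
  for (Slot& slot : slots) {
    cursor += slot.size;
    slot.offset = -cursor;
  }
  return AlignUp(cursor, 16);
}
```

With slots of 4, 8, and 4 bytes this yields offsets -4, -12, and -16 and a 16-byte frame; adding one more 4-byte slot pushes the frame to 32 bytes.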


@@ -2,122 +2,535 @@
#include <stdexcept> #include <stdexcept>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include <cstring>
#include "ir/IR.h" #include "ir/IR.h"
#include "utils/Log.h" #include "utils/Log.h"
#include <iostream>
namespace mir { namespace mir {
namespace { namespace {
using ValueSlotMap = std::unordered_map<const ir::Value*, int>; using ValueSlotMap = std::unordered_map<const ir::Value*, int>;
uint32_t GetTypeSize(const ir::Type* type) {
if (type->IsInt32() || type->IsFloat()) {
return 4;
}
if (type->IsPtrInt32() || type->IsPtrFloat()) {
return 8; // 64-bit pointers
}
if (type->IsArray()) {
auto* arr_ty = const_cast<ir::Type*>(type)->GetAsArrayType().get();
return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get());
}
return 4;
}
uint32_t GetAllocaSize(const ir::Instruction& inst) {
auto type = inst.GetType();
if (type->IsPtrInt32() || type->IsPtrFloat()) {
// Check whether any StoreInst in the parent function stores a pointer value into this alloca
auto* parent_bb = inst.GetParent();
if (parent_bb) {
auto* parent_func = parent_bb->GetParent();
if (parent_func) {
for (const auto& bbPtr : parent_func->GetBlocks()) {
for (const auto& other_inst : bbPtr->GetInstructions()) {
if (other_inst->GetOpcode() == ir::Opcode::Store) {
auto* store = static_cast<const ir::StoreInst*>(other_inst.get());
if (store->GetPtr() == &inst) {
auto val_ty = store->GetValue()->GetType();
if (val_ty->IsPtrInt32() || val_ty->IsPtrFloat()) {
return 8; // Stores a 64-bit pointer
}
}
}
}
}
}
}
return 4;
}
return GetTypeSize(type.get());
}
std::vector<uint32_t> GetGepStrides(const ir::GetElementPtrInst& gep) {
std::vector<uint32_t> strides;
auto curr_type = gep.GetPtr()->GetType();
if (curr_type->IsPtrInt32() || curr_type->IsPtrFloat()) {
strides.push_back(4);
} else if (curr_type->IsArray()) {
strides.push_back(GetTypeSize(curr_type.get()));
for (size_t i = 2; i < gep.GetNumOperands(); ++i) {
curr_type = curr_type->GetAsArrayType()->GetElementType();
strides.push_back(GetTypeSize(curr_type.get()));
}
}
return strides;
}
void EmitAddressToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* alloca = dynamic_cast<const ir::Instruction*>(value)) {
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(FormatError("mir", "找不到局部变量的栈槽: " + value->GetName()));
}
block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(PhysReg::X29), Operand::FrameIndex(it->second)});
return;
}
}
if (value->IsGlobalValue()) {
block.Append(Opcode::Adrp, {Operand::Reg(target), Operand::Global(value->GetName())});
block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(target), Operand::Global(value->GetName())});
return;
}
// Otherwise, the address itself is stored in a stack slot
auto it = slots.find(value);
if (it == slots.end()) {
throw std::runtime_error(FormatError("mir", "找不到指针的值槽: " + value->GetName()));
}
block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
}
void EmitValueToReg(const ir::Value* value, PhysReg target, void EmitValueToReg(const ir::Value* value, PhysReg target,
const ValueSlotMap& slots, MachineBasicBlock& block) { const ValueSlotMap& slots, MachineBasicBlock& block) {
if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) { if (auto* constant = dynamic_cast<const ir::ConstantInt*>(value)) {
block.Append(Opcode::MovImm, block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(constant->GetValue())});
{Operand::Reg(target), Operand::Imm(constant->GetValue())}); return;
}
if (auto* constant = dynamic_cast<const ir::ConstantFloat*>(value)) {
float fval = constant->GetValue();
int bits;
std::memcpy(&bits, &fval, sizeof(float));
block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(bits)});
return;
}
if (value->IsGlobalValue()) {
EmitAddressToReg(value, target, slots, block);
return; return;
} }
auto it = slots.find(value); auto it = slots.find(value);
if (it == slots.end()) { if (it == slots.end()) {
throw std::runtime_error( throw std::runtime_error(FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
FormatError("mir", "找不到值对应的栈槽: " + value->GetName()));
} }
block.Append(Opcode::LoadStack, block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)});
{Operand::Reg(target), Operand::FrameIndex(it->second)});
} }
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots) { ValueSlotMap& slots, MachineBasicBlock& block) {
auto& block = function.GetEntry();
switch (inst.GetOpcode()) { switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: { case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex()); slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst)));
return; return;
} }
case ir::Opcode::Store: { case ir::Opcode::Store: {
auto& store = static_cast<const ir::StoreInst&>(inst); auto& store = static_cast<const ir::StoreInst&>(inst);
auto dst = slots.find(store.GetPtr());
if (dst == slots.end()) { if (auto* alloca = dynamic_cast<const ir::Instruction*>(store.GetPtr())) {
throw std::runtime_error( if (alloca->GetOpcode() == ir::Opcode::Alloca) {
FormatError("mir", "暂不支持对非栈变量地址进行写入")); auto it = slots.find(alloca);
if (it != slots.end()) {
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(store.GetValue(), val_reg, slots, block);
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
return;
}
}
} }
EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::StoreStack, // Dynamic store
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(store.GetValue(), val_reg, slots, block);
EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block);
block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
return; return;
} }
case ir::Opcode::Load: { case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst); auto& load = static_cast<const ir::LoadInst&>(inst);
auto src = slots.find(load.GetPtr()); int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get()));
if (src == slots.end()) {
throw std::runtime_error(
FormatError("mir", "暂不支持对非栈变量地址进行读取"));
}
int dst_slot = function.CreateFrameIndex();
block.Append(Opcode::LoadStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)});
block.Append(Opcode::StoreStack,
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load.GetPtr())) {
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
auto it = slots.find(alloca);
if (it != slots.end()) {
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
return;
}
}
}
// Dynamic load
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block);
block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
return; return;
} }
case ir::Opcode::Add: { case ir::Opcode::Add:
case ir::Opcode::Sub:
case ir::Opcode::Mul:
case ir::Opcode::Div:
case ir::Opcode::Mod: {
auto& bin = static_cast<const ir::BinaryInst&>(inst); auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(); int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8),
Operand::Reg(PhysReg::W8), if (inst.GetOpcode() == ir::Opcode::Add) {
Operand::Reg(PhysReg::W9)}); block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
block.Append(Opcode::StoreStack, } else if (inst.GetOpcode() == ir::Opcode::Sub) {
{Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Mul) {
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Div) {
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
} else if (inst.GetOpcode() == ir::Opcode::Mod) {
block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
block.Append(Opcode::MSubRRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)});
}
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::FAdd:
case ir::Opcode::FSub:
case ir::Opcode::FMul:
case ir::Opcode::FDiv: {
auto& bin = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot); slots.emplace(&inst, dst_slot);
EmitValueToReg(bin.GetLhs(), PhysReg::S8, slots, block);
EmitValueToReg(bin.GetRhs(), PhysReg::S9, slots, block);
if (inst.GetOpcode() == ir::Opcode::FAdd) {
block.Append(Opcode::FAddRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
} else if (inst.GetOpcode() == ir::Opcode::FSub) {
block.Append(Opcode::FSubRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
} else if (inst.GetOpcode() == ir::Opcode::FMul) {
block.Append(Opcode::FMulRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
} else if (inst.GetOpcode() == ir::Opcode::FDiv) {
block.Append(Opcode::FDivRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
}
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::ICmpEQ:
case ir::Opcode::ICmpNE:
case ir::Opcode::ICmpLT:
case ir::Opcode::ICmpGT:
case ir::Opcode::ICmpLE:
case ir::Opcode::ICmpGE: {
auto& cmp = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block);
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
std::string cond;
switch (inst.GetOpcode()) {
case ir::Opcode::ICmpEQ: cond = "eq"; break;
case ir::Opcode::ICmpNE: cond = "ne"; break;
case ir::Opcode::ICmpLT: cond = "lt"; break;
case ir::Opcode::ICmpGT: cond = "gt"; break;
case ir::Opcode::ICmpLE: cond = "le"; break;
case ir::Opcode::ICmpGE: cond = "ge"; break;
default: break;
}
block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::FCmpEQ:
case ir::Opcode::FCmpNE:
case ir::Opcode::FCmpLT:
case ir::Opcode::FCmpGT:
case ir::Opcode::FCmpLE:
case ir::Opcode::FCmpGE: {
auto& cmp = static_cast<const ir::BinaryInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(cmp.GetLhs(), PhysReg::S8, slots, block);
EmitValueToReg(cmp.GetRhs(), PhysReg::S9, slots, block);
block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)});
std::string cond;
switch (inst.GetOpcode()) {
case ir::Opcode::FCmpEQ: cond = "eq"; break;
case ir::Opcode::FCmpNE: cond = "ne"; break;
case ir::Opcode::FCmpLT: cond = "mi"; break;
case ir::Opcode::FCmpGT: cond = "gt"; break;
case ir::Opcode::FCmpLE: cond = "ls"; break;
case ir::Opcode::FCmpGE: cond = "ge"; break;
default: break;
}
block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::ZExt: {
auto& cast = static_cast<const ir::CastInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::SIToFP: {
auto& cast = static_cast<const ir::CastInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block);
block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::W8)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::FPToSI: {
auto& cast = static_cast<const ir::CastInst&>(inst);
int dst_slot = function.CreateFrameIndex(4);
slots.emplace(&inst, dst_slot);
EmitValueToReg(cast.GetValue(), PhysReg::S8, slots, block);
block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)});
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)});
return;
}
case ir::Opcode::Br: {
auto& br = static_cast<const ir::BranchInst&>(inst);
auto emit_phi_copies = [&](const ir::BasicBlock* succ) {
if (!succ) return;
for (const auto& succ_inst : succ->GetInstructions()) {
if (succ_inst->GetOpcode() == ir::Opcode::Phi) {
auto* phi = static_cast<const ir::PhiInst*>(succ_inst.get());
const ir::Value* incoming_val = nullptr;
for (size_t i = 0; i < phi->GetNumIncoming(); ++i) {
if (phi->GetIncomingBlock(i) == inst.GetParent()) {
incoming_val = phi->GetIncomingValue(i);
break;
}
}
if (incoming_val) {
auto slot_it = slots.find(phi);
if (slot_it != slots.end()) {
int phi_slot = slot_it->second;
PhysReg val_reg = phi->GetType()->IsFloat() ? PhysReg::S8 :
(phi->GetType()->IsPtrInt32() || phi->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(incoming_val, val_reg, slots, block);
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(phi_slot)});
}
}
}
}
};
if (br.IsConditional()) {
emit_phi_copies(br.GetIfTrue());
emit_phi_copies(br.GetIfFalse());
EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block);
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(0)});
block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)});
block.Append(Opcode::BCond, {Operand::Cond("ne"), Operand::Label(br.GetIfTrue()->GetName())});
block.Append(Opcode::B, {Operand::Label(br.GetIfFalse()->GetName())});
} else {
emit_phi_copies(br.GetDest());
block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())});
}
return;
}
case ir::Opcode::Phi: {
return; return;
} }
case ir::Opcode::Ret: { case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst); auto& ret = static_cast<const ir::ReturnInst&>(inst);
EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); if (ret.GetValue()) {
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0;
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
}
block.Append(Opcode::Ret); block.Append(Opcode::Ret);
return; return;
} }
case ir::Opcode::Sub: case ir::Opcode::Call: {
case ir::Opcode::Mul: auto& call = static_cast<const ir::CallInst&>(inst);
throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); int dst_slot = -1;
if (!call.GetType()->IsVoid()) {
dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get()));
slots.emplace(&inst, dst_slot);
}
int int_idx = 0;
int float_idx = 0;
for (size_t i = 1; i < call.GetNumOperands(); ++i) {
auto* arg = call.GetOperand(i);
if (arg->GetType()->IsFloat()) {
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + float_idx);
EmitValueToReg(arg, reg, slots, block);
float_idx++;
} else {
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
EmitValueToReg(arg, reg, slots, block);
int_idx++;
}
}
block.Append(Opcode::Call, {Operand::Global(call.GetFunction()->GetName())});
if (dst_slot != -1) {
if (call.GetType()->IsFloat()) {
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0;
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
}
}
return;
}
case ir::Opcode::GEP: {
auto& gep = static_cast<const ir::GetElementPtrInst&>(inst);
int dst_slot = function.CreateFrameIndex(8);
slots.emplace(&inst, dst_slot);
// Load base pointer address into X8
if (gep.GetPtr()->IsGlobalValue()) {
EmitAddressToReg(gep.GetPtr(), PhysReg::X8, slots, block);
} else if (auto* alloca = dynamic_cast<const ir::AllocaInst*>(gep.GetPtr())) {
if (alloca->GetType()->IsArray()) {
EmitAddressToReg(gep.GetPtr(), PhysReg::X8, slots, block);
} else {
EmitValueToReg(gep.GetPtr(), PhysReg::X8, slots, block);
}
} else {
EmitValueToReg(gep.GetPtr(), PhysReg::X8, slots, block);
}
auto strides = GetGepStrides(gep);
for (size_t i = 1; i < gep.GetNumOperands(); ++i) {
auto* idx = gep.GetOperand(i);
uint32_t stride = strides.at(i - 1);
// Skip if offset index is constant 0
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(idx)) {
if (ci->GetValue() == 0) {
continue;
}
}
EmitValueToReg(idx, PhysReg::W9, slots, block);
if (stride > 1) {
block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(stride)});
block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W10)});
}
// Extend W9 to X9 and add to base address X8
block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)});
block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)});
}
// Store address into GEP's stack slot
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)});
return;
}
} }
throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令: " + std::to_string(static_cast<int>(inst.GetOpcode()))));
} }
} // namespace } // namespace
std::unique_ptr<MachineFunction> LowerToMIR(const ir::Module& module) { std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& module) {
DefaultContext(); DefaultContext();
std::vector<std::unique_ptr<MachineFunction>> mfuncs;
if (module.GetFunctions().size() != 1) { for (const auto& funcPtr : module.GetFunctions()) {
throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); const auto& func = *funcPtr;
if (func.GetBlocks().empty()) continue; // skip declarations
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
// First, create all basic blocks in MachineFunction
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map;
machine_func->GetBlocks().reserve(func.GetBlocks().size());
for (const auto& bbPtr : func.GetBlocks()) {
auto& mbb = machine_func->CreateBlock(bbPtr->GetName());
bb_map[bbPtr.get()] = &mbb;
}
// Pre-allocate stack slots for all Phi instructions in the function
for (const auto& bbPtr : func.GetBlocks()) {
for (const auto& inst : bbPtr->GetInstructions()) {
if (inst->GetOpcode() == ir::Opcode::Phi) {
int slot = machine_func->CreateFrameIndex(GetTypeSize(inst->GetType().get()));
slots.emplace(inst.get(), slot);
}
}
}
auto& entry_block = *bb_map.at(func.GetEntry());
// Lower function arguments at the start of the entry block
const auto& args = func.GetArguments();
int int_idx = 0;
int float_idx = 0;
for (const auto& arg : args) {
int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get()));
slots.emplace(arg.get(), slot);
if (arg->GetType()->IsFloat()) {
PhysReg reg = static_cast<PhysReg>(static_cast<int>(PhysReg::S0) + float_idx);
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
float_idx++;
} else {
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
int_idx++;
}
}
// Now, lower all instructions block by block
for (const auto& bbPtr : func.GetBlocks()) {
auto& mbb = *bb_map.at(bbPtr.get());
for (const auto& inst : bbPtr->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots, mbb);
}
}
mfuncs.push_back(std::move(machine_func));
} }
const auto& func = *module.GetFunctions().front(); return mfuncs;
if (func.GetName() != "main") {
throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数"));
}
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
const auto* entry = func.GetEntry();
if (!entry) {
throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块"));
}
for (const auto& inst : entry->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots);
}
return machine_func;
} }
} // namespace mir } // namespace mir
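
Review note on the argument handling in `LowerToMIR` and in the `Call` case: both keep one counter for integer and pointer arguments and a separate counter for float arguments, which matches the AAPCS64 rule that the first eight integer or pointer arguments use `x0`-`x7` (`w0`-`w7` as 32-bit views) and the first eight floating-point arguments use `v0`-`v7` (`s0`-`s7` for single precision). The sketch below restates that assignment on its own; `ArgKind` and `AssignArgRegs` are hypothetical names and not part of this PR.

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical classification of a SysY argument; not a type from this PR.
enum class ArgKind { Int32, Float32, Pointer };

// Returns the register name each argument would be passed in.
// Integer and pointer arguments share one counter (w/x views of the same
// register); float arguments use a separate counter.
std::vector<std::string> AssignArgRegs(const std::vector<ArgKind>& args) {
  constexpr int kMaxRegArgs = 8;  // x0-x7 and s0-s7
  std::vector<std::string> regs;
  int int_idx = 0;
  int float_idx = 0;
  for (ArgKind kind : args) {
    if (kind == ArgKind::Float32) {
      if (float_idx >= kMaxRegArgs) {
        throw std::runtime_error("float argument would be passed on the stack");
      }
      regs.push_back("s" + std::to_string(float_idx++));
    } else {
      if (int_idx >= kMaxRegArgs) {
        throw std::runtime_error("integer argument would be passed on the stack");
      }
      const char* prefix = (kind == ArgKind::Pointer) ? "x" : "w";
      regs.push_back(prefix + std::to_string(int_idx++));
    }
  }
  assert(regs.size() == args.size());
  return regs;
}
```

The sketch throws once a register class is exhausted. The lowering shown in this diff has no such bound, so a function with more than eight arguments of one class would presumably need stack-passed arguments that this hunk does not implement.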


@@ -8,7 +8,12 @@
namespace mir { namespace mir {
MachineFunction::MachineFunction(std::string name) MachineFunction::MachineFunction(std::string name)
: name_(std::move(name)), entry_("entry") {} : name_(std::move(name)) {}
MachineBasicBlock& MachineFunction::CreateBlock(std::string name) {
blocks_.emplace_back(std::move(name));
return blocks_.back();
}
int MachineFunction::CreateFrameIndex(int size) { int MachineFunction::CreateFrameIndex(int size) {
int index = static_cast<int>(frame_slots_.size()); int index = static_cast<int>(frame_slots_.size());


@@ -4,10 +4,12 @@
namespace mir { namespace mir {
Operand::Operand(Kind kind, PhysReg reg, int imm) Operand::Operand(Kind kind, PhysReg reg, int imm, std::string str)
: kind_(kind), reg_(reg), imm_(imm) {} : kind_(kind), reg_(reg), imm_(imm), str_(std::move(str)) {}
Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } Operand Operand::Reg(PhysReg reg) {
return Operand(Kind::Reg, reg, 0);
}
Operand Operand::Imm(int value) { Operand Operand::Imm(int value) {
return Operand(Kind::Imm, PhysReg::W0, value); return Operand(Kind::Imm, PhysReg::W0, value);
@@ -17,6 +19,18 @@ Operand Operand::FrameIndex(int index) {
return Operand(Kind::FrameIndex, PhysReg::W0, index); return Operand(Kind::FrameIndex, PhysReg::W0, index);
} }
Operand Operand::Global(std::string name) {
return Operand(Kind::Global, PhysReg::W0, 0, std::move(name));
}
Operand Operand::Label(std::string name) {
return Operand(Kind::Label, PhysReg::W0, 0, std::move(name));
}
Operand Operand::Cond(std::string cond) {
return Operand(Kind::Cond, PhysReg::W0, 0, std::move(cond));
}
MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands) MachineInstr::MachineInstr(Opcode opcode, std::vector<Operand> operands)
: opcode_(opcode), operands_(std::move(operands)) {} : opcode_(opcode), operands_(std::move(operands)) {}


@@ -8,26 +8,19 @@ namespace mir {
namespace { namespace {
bool IsAllowedReg(PhysReg reg) { bool IsAllowedReg(PhysReg reg) {
switch (reg) { return true; // We allow all defined physical registers
case PhysReg::W0:
case PhysReg::W8:
case PhysReg::W9:
case PhysReg::X29:
case PhysReg::X30:
case PhysReg::SP:
return true;
}
return false;
} }
} // namespace } // namespace
void RunRegAlloc(MachineFunction& function) { void RunRegAlloc(MachineFunction& function) {
for (const auto& inst : function.GetEntry().GetInstructions()) { for (const auto& block : function.GetBlocks()) {
for (const auto& operand : inst.GetOperands()) { for (const auto& inst : block.GetInstructions()) {
if (operand.GetKind() == Operand::Kind::Reg && for (const auto& operand : inst.GetOperands()) {
!IsAllowedReg(operand.GetReg())) { if (operand.GetKind() == Operand::Kind::Reg &&
throw std::runtime_error(FormatError("mir", "寄存器分配失败")); !IsAllowedReg(operand.GetReg())) {
throw std::runtime_error(FormatError("mir", "寄存器分配失败"));
}
} }
} }
} }


@@ -1,6 +1,7 @@
#include "mir/MIR.h" #include "mir/MIR.h"
#include <stdexcept> #include <stdexcept>
#include <string>
#include "utils/Log.h" #include "utils/Log.h"
@@ -8,18 +9,77 @@ namespace mir {
const char* PhysRegName(PhysReg reg) { const char* PhysRegName(PhysReg reg) {
switch (reg) { switch (reg) {
case PhysReg::W0: case PhysReg::W0: return "w0";
return "w0"; case PhysReg::W1: return "w1";
case PhysReg::W8: case PhysReg::W2: return "w2";
return "w8"; case PhysReg::W3: return "w3";
case PhysReg::W9: case PhysReg::W4: return "w4";
return "w9"; case PhysReg::W5: return "w5";
case PhysReg::X29: case PhysReg::W6: return "w6";
return "x29"; case PhysReg::W7: return "w7";
case PhysReg::X30: case PhysReg::W8: return "w8";
return "x30"; case PhysReg::W9: return "w9";
case PhysReg::SP: case PhysReg::W10: return "w10";
return "sp"; case PhysReg::W11: return "w11";
case PhysReg::W12: return "w12";
case PhysReg::W13: return "w13";
case PhysReg::W14: return "w14";
case PhysReg::W15: return "w15";
case PhysReg::W19: return "w19";
case PhysReg::W20: return "w20";
case PhysReg::W21: return "w21";
case PhysReg::W22: return "w22";
case PhysReg::W23: return "w23";
case PhysReg::W24: return "w24";
case PhysReg::W25: return "w25";
case PhysReg::W26: return "w26";
case PhysReg::W27: return "w27";
case PhysReg::W28: return "w28";
case PhysReg::X0: return "x0";
case PhysReg::X1: return "x1";
case PhysReg::X2: return "x2";
case PhysReg::X3: return "x3";
case PhysReg::X4: return "x4";
case PhysReg::X5: return "x5";
case PhysReg::X6: return "x6";
case PhysReg::X7: return "x7";
case PhysReg::X8: return "x8";
case PhysReg::X9: return "x9";
case PhysReg::X10: return "x10";
case PhysReg::X11: return "x11";
case PhysReg::X12: return "x12";
case PhysReg::X13: return "x13";
case PhysReg::X14: return "x14";
case PhysReg::X15: return "x15";
case PhysReg::X19: return "x19";
case PhysReg::X20: return "x20";
case PhysReg::X21: return "x21";
case PhysReg::X22: return "x22";
case PhysReg::X23: return "x23";
case PhysReg::X24: return "x24";
case PhysReg::X25: return "x25";
case PhysReg::X26: return "x26";
case PhysReg::X27: return "x27";
case PhysReg::X28: return "x28";
case PhysReg::S0: return "s0";
case PhysReg::S1: return "s1";
case PhysReg::S2: return "s2";
case PhysReg::S3: return "s3";
case PhysReg::S4: return "s4";
case PhysReg::S5: return "s5";
case PhysReg::S6: return "s6";
case PhysReg::S7: return "s7";
case PhysReg::S8: return "s8";
case PhysReg::S9: return "s9";
case PhysReg::S10: return "s10";
case PhysReg::S11: return "s11";
case PhysReg::S12: return "s12";
case PhysReg::S13: return "s13";
case PhysReg::S14: return "s14";
case PhysReg::S15: return "s15";
case PhysReg::X29: return "x29";
case PhysReg::X30: return "x30";
case PhysReg::SP: return "sp";
} }
throw std::runtime_error(FormatError("mir", "未知物理寄存器")); throw std::runtime_error(FormatError("mir", "未知物理寄存器"));
} }


@@ -1,4 +1,185 @@
// Peephole optimization #include "mir/MIR.h"
// - Remove redundant moves and merge common instruction patterns #include <unordered_map>
// - Improve final assembly quality (trimmed to the implemented scope) #include <vector>
namespace mir {
namespace {
PhysReg NormalizeReg(PhysReg reg) {
int r = static_cast<int>(reg);
// Map 64-bit X0-X28 registers to 32-bit W0-W28 registers to handle aliasing
if (r >= static_cast<int>(PhysReg::X0) && r <= static_cast<int>(PhysReg::X28)) {
return static_cast<PhysReg>(r - static_cast<int>(PhysReg::X0) + static_cast<int>(PhysReg::W0));
}
return reg;
}
PhysReg MatchRegSize(PhysReg target, PhysReg src) {
int t = static_cast<int>(target);
int s = static_cast<int>(src);
bool target_is_64 = (t >= static_cast<int>(PhysReg::X0) && t <= static_cast<int>(PhysReg::X28)) ||
t == static_cast<int>(PhysReg::X29) ||
t == static_cast<int>(PhysReg::X30) ||
t == static_cast<int>(PhysReg::SP);
bool src_is_64 = (s >= static_cast<int>(PhysReg::X0) && s <= static_cast<int>(PhysReg::X28)) ||
s == static_cast<int>(PhysReg::X29) ||
s == static_cast<int>(PhysReg::X30) ||
s == static_cast<int>(PhysReg::SP);
if (target_is_64 && !src_is_64) {
if (s >= static_cast<int>(PhysReg::W0) && s <= static_cast<int>(PhysReg::W28)) {
return static_cast<PhysReg>(s - static_cast<int>(PhysReg::W0) + static_cast<int>(PhysReg::X0));
}
} else if (!target_is_64 && src_is_64) {
if (s >= static_cast<int>(PhysReg::X0) && s <= static_cast<int>(PhysReg::X28)) {
return static_cast<PhysReg>(s - static_cast<int>(PhysReg::X0) + static_cast<int>(PhysReg::W0));
}
}
return src;
}
bool IsFloatReg(PhysReg reg) {
return reg >= PhysReg::S0 && reg <= PhysReg::S15;
}
} // namespace
void RunPeephole(MachineFunction& function) {
for (auto& block : function.GetBlocks()) {
auto& insts = block.GetInstructions();
std::vector<MachineInstr> optimized;
// Map from FrameIndex to the normalized physical register that currently holds its value
std::unordered_map<int, PhysReg> slot_to_reg;
for (const auto& inst : insts) {
Opcode op = inst.GetOpcode();
const auto& ops = inst.GetOperands();
// 1. Handle register move elimination (e.g. mov w8, w8)
if (op == Opcode::MovReg) {
if (NormalizeReg(ops.at(0).GetReg()) == NormalizeReg(ops.at(1).GetReg())) {
continue; // Delete redundant self-moves
}
}
// 2. Handle redundant Load after Store
if (op == Opcode::LoadStack) {
int fi = ops.at(1).GetFrameIndex();
auto it = slot_to_reg.find(fi);
if (it != slot_to_reg.end()) {
PhysReg source_reg = it->second;
PhysReg dest_reg = NormalizeReg(ops.at(0).GetReg());
if (source_reg == dest_reg) {
// Loading the same register that already has the value - completely redundant!
continue;
} else {
// Replace LoadStack dest_reg, fi with MovReg dest_reg, matched_source
PhysReg matched_source = MatchRegSize(ops.at(0).GetReg(), it->second);
optimized.push_back(MachineInstr(Opcode::MovReg, {Operand::Reg(ops.at(0).GetReg()), Operand::Reg(matched_source)}));
// Invalidate any other slots mapping to dest_reg because dest_reg is written
std::vector<int> to_remove;
for (const auto& pair : slot_to_reg) {
if (NormalizeReg(pair.second) == dest_reg) {
to_remove.push_back(pair.first);
}
}
for (int key : to_remove) {
slot_to_reg.erase(key);
}
// Add new mapping (normalized)
slot_to_reg[fi] = dest_reg;
continue;
}
}
}
// 3. Track stores
if (op == Opcode::StoreStack) {
PhysReg src = NormalizeReg(ops.at(0).GetReg());
int fi = ops.at(1).GetFrameIndex();
slot_to_reg[fi] = src;
}
// 4. Invalidate register mappings on writes
bool writes_reg = false;
PhysReg written_reg = PhysReg::W0; // dummy
switch (op) {
case Opcode::MovImm:
if (!ops.empty() && ops.at(0).GetKind() == Operand::Kind::Reg) {
writes_reg = true;
written_reg = NormalizeReg(ops.at(0).GetReg());
// Under the hood, MovImm to a float register implicitly writes to x8/w8
if (IsFloatReg(ops.at(0).GetReg())) {
PhysReg implicitly_written = NormalizeReg(PhysReg::X8);
std::vector<int> to_remove;
for (const auto& pair : slot_to_reg) {
if (NormalizeReg(pair.second) == implicitly_written) {
to_remove.push_back(pair.first);
}
}
for (int key : to_remove) {
slot_to_reg.erase(key);
}
}
}
break;
case Opcode::LoadStack:
case Opcode::AddRR:
case Opcode::SubRR:
case Opcode::MulRR:
case Opcode::SDivRR:
case Opcode::MSubRRRR:
case Opcode::FAddRRR:
case Opcode::FSubRRR:
case Opcode::FMulRRR:
case Opcode::FDivRRR:
case Opcode::Cset:
case Opcode::MovReg:
case Opcode::Adrp:
case Opcode::AddRegImm:
case Opcode::LdrRegReg:
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt:
if (!ops.empty() && ops.at(0).GetKind() == Operand::Kind::Reg) {
writes_reg = true;
written_reg = NormalizeReg(ops.at(0).GetReg());
}
break;
case Opcode::Call:
// A function call destroys all temporary/scratch registers.
slot_to_reg.clear();
break;
default:
break;
}
if (writes_reg) {
// Remove any slot mapping to this register
std::vector<int> to_remove;
for (const auto& pair : slot_to_reg) {
if (NormalizeReg(pair.second) == written_reg) {
to_remove.push_back(pair.first);
}
}
for (int key : to_remove) {
slot_to_reg.erase(key);
}
}
optimized.push_back(inst);
}
insts = std::move(optimized);
}
}
} // namespace mir
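
Review note on `RunPeephole`: within one basic block the pass remembers which register was last stored to each stack slot, so a later load of that slot can be dropped or turned into a register move. The mapping is invalidated whenever the register is overwritten and cleared entirely at a call. The sketch below shows the same idea on a toy instruction type, assuming integer register ids and ignoring the W/X aliasing and float handling that the real pass deals with; `Op`, `Instr` and `ForwardStores` are hypothetical names and not part of this PR.

```cpp
#include <cassert>
#include <unordered_map>
#include <vector>

// Toy instruction set used only for this sketch.
enum class Op { Store, Load, Mov, Def, Call };

struct Instr {
  Op op;
  int reg;  // Store: source register; Load/Mov/Def: destination register
  int arg;  // Store/Load: stack slot; Mov: source register; otherwise unused
};

// Forget every slot whose cached value lives in `reg`.
static void Invalidate(std::unordered_map<int, int>& slot_to_reg, int reg) {
  for (auto it = slot_to_reg.begin(); it != slot_to_reg.end();) {
    if (it->second == reg) {
      it = slot_to_reg.erase(it);
    } else {
      ++it;
    }
  }
}

// Store-to-load forwarding over a single basic block.
std::vector<Instr> ForwardStores(const std::vector<Instr>& block) {
  std::vector<Instr> out;
  std::unordered_map<int, int> slot_to_reg;  // slot -> register holding its value
  for (const Instr& in : block) {
    switch (in.op) {
      case Op::Store:
        slot_to_reg[in.arg] = in.reg;
        out.push_back(in);
        break;
      case Op::Load: {
        auto it = slot_to_reg.find(in.arg);
        if (it != slot_to_reg.end()) {
          int src = it->second;
          if (src == in.reg) break;  // value is already in place: drop the load
          out.push_back({Op::Mov, in.reg, src});  // replace the load with a move
          Invalidate(slot_to_reg, in.reg);
          slot_to_reg[in.arg] = in.reg;
          break;
        }
        Invalidate(slot_to_reg, in.reg);  // destination is overwritten
        out.push_back(in);
        break;
      }
      case Op::Mov:
      case Op::Def:
        Invalidate(slot_to_reg, in.reg);  // destination is overwritten
        out.push_back(in);
        break;
      case Op::Call:
        slot_to_reg.clear();  // treat every register as clobbered
        out.push_back(in);
        break;
    }
  }
  assert(out.size() <= block.size());
  return out;
}
```

Because the map is rebuilt per block, the sketch (like the pass above) never forwards a value across a branch, which keeps it correct without any control-flow analysis.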


@@ -10,185 +10,321 @@
namespace { namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("sema", "非法左值"));
}
return lvalue.ID()->getText();
}
class SemaVisitor final : public SysYBaseVisitor { class SemaVisitor final : public SysYBaseVisitor {
public: public:
explicit SemaVisitor() {
// Pre-register the standard library functions
AddBuiltin("getint", Symbol::Kind::Function);
AddBuiltin("getch", Symbol::Kind::Function);
AddBuiltin("getfloat", Symbol::Kind::Function);
AddBuiltin("getarray", Symbol::Kind::Function);
AddBuiltin("getfarray", Symbol::Kind::Function);
AddBuiltin("putint", Symbol::Kind::Function);
AddBuiltin("putch", Symbol::Kind::Function);
AddBuiltin("putfloat", Symbol::Kind::Function);
AddBuiltin("putarray", Symbol::Kind::Function);
AddBuiltin("putfarray", Symbol::Kind::Function);
AddBuiltin("starttime", Symbol::Kind::Function);
AddBuiltin("stoptime", Symbol::Kind::Function);
AddBuiltin("putf", Symbol::Kind::Function);
}
private:
void AddBuiltin(const std::string& name, Symbol::Kind kind) {
Symbol sym;
sym.kind = kind;
sym.def_ctx = nullptr; // built-in functions have no syntax tree node
table_.Add(name, sym);
}
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override {
if (!ctx) { if (!ctx) return {};
throw std::runtime_error(FormatError("sema", "缺少编译单元")); for (auto* child : ctx->children) {
if (auto* decl = dynamic_cast<SysYParser::DeclContext*>(child)) {
decl->accept(this);
} else if (auto* funcDef = dynamic_cast<SysYParser::FuncDefContext*>(child)) {
funcDef->accept(this);
}
} }
auto* func = ctx->funcDef(); return {};
if (!func || !func->blockStmt()) { }
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义"));
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (ctx->constDecl()) return ctx->constDecl()->accept(this);
if (ctx->varDecl()) return ctx->varDecl()->accept(this);
return {};
}
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override {
for (auto* def : ctx->constDef()) {
const std::string name = def->ID()->getText();
if (table_.IsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义常量: " + name));
}
Symbol sym;
sym.kind = Symbol::Kind::Constant;
sym.def_ctx = def;
sym.is_const = true;
sym.is_array = !def->exp().empty();
table_.Add(name, sym);
for (auto* exp : def->exp()) exp->accept(this);
def->initValue()->accept(this);
} }
if (!func->ID() || func->ID()->getText() != "main") { return {};
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); }
}
func->accept(this); std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override {
if (!seen_return_) { for (auto* def : ctx->varDef()) {
throw std::runtime_error( const std::string name = def->ID()->getText();
FormatError("sema", "main 函数必须包含 return 语句")); if (table_.IsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
Symbol sym;
sym.kind = Symbol::Kind::Variable;
sym.def_ctx = def;
sym.is_const = false;
sym.is_array = !def->exp().empty();
table_.Add(name, sym);
for (auto* exp : def->exp()) exp->accept(this);
if (def->initValue()) def->initValue()->accept(this);
} }
return {}; return {};
} }
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override {
if (!ctx || !ctx->blockStmt()) { const std::string name = ctx->ID()->getText();
throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); if (table_.IsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "重复定义函数: " + name));
} }
if (!ctx->funcType() || !ctx->funcType()->INT()) { Symbol sym;
throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); sym.kind = Symbol::Kind::Function;
sym.def_ctx = ctx;
table_.Add(name, sym);
table_.PushScope();
if (ctx->funcFParams()) {
ctx->funcFParams()->accept(this);
} }
const auto& items = ctx->blockStmt()->blockItem(); if (ctx->blockStmt()) {
if (items.empty()) { // Visit block items without pushing another scope to keep params in same scope
throw std::runtime_error( for (auto* item : ctx->blockStmt()->blockItem()) {
FormatError("sema", "main 函数不能为空,且必须以 return 结束")); item->accept(this);
}
} }
ctx->blockStmt()->accept(this); table_.PopScope();
return {};
}
std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override {
const std::string name = ctx->ID()->getText();
if (table_.IsInCurrentScope(name)) {
throw std::runtime_error(FormatError("sema", "函数参数名冲突: " + name));
}
Symbol sym;
sym.kind = Symbol::Kind::Parameter;
sym.def_ctx = ctx;
sym.is_array = !ctx->LBRACK().empty();
table_.Add(name, sym);
for (auto* exp : ctx->exp()) exp->accept(this);
return {}; return {};
} }
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override {
if (!ctx) { table_.PushScope();
throw std::runtime_error(FormatError("sema", "缺少语句块")); for (auto* item : ctx->blockItem()) {
}
const auto& items = ctx->blockItem();
for (size_t i = 0; i < items.size(); ++i) {
auto* item = items[i];
if (!item) {
continue;
}
if (seen_return_) {
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句"));
}
current_item_index_ = i;
total_items_ = items.size();
item->accept(this); item->accept(this);
} }
table_.PopScope();
return {}; return {};
} }
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override {
if (!ctx) { ctx->lValue()->accept(this);
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); const std::string name = ctx->lValue()->ID()->getText();
} Symbol* sym = table_.Lookup(name);
if (ctx->decl()) { if (sym && sym->is_const) {
ctx->decl()->accept(this); throw std::runtime_error(FormatError("sema", "试图给常量赋值: " + name));
return {};
}
if (ctx->stmt()) {
ctx->stmt()->accept(this);
return {};
}
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
std::any visitDecl(SysYParser::DeclContext* ctx) override {
if (!ctx) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def || !var_def->lValue()) {
throw std::runtime_error(FormatError("sema", "非法变量声明"));
}
const std::string name = GetLValueName(*var_def->lValue());
if (table_.Contains(name)) {
throw std::runtime_error(FormatError("sema", "重复定义变量: " + name));
}
if (auto* init = var_def->initValue()) {
if (!init->exp()) {
throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化"));
}
init->exp()->accept(this);
}
table_.Add(name, var_def);
return {};
}
std::any visitStmt(SysYParser::StmtContext* ctx) override {
if (!ctx || !ctx->returnStmt()) {
throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明"));
}
ctx->returnStmt()->accept(this);
return {};
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("sema", "return 缺少表达式"));
} }
ctx->exp()->accept(this); ctx->exp()->accept(this);
seen_return_ = true; return {};
if (current_item_index_ + 1 != total_items_) { }
throw std::runtime_error(
FormatError("sema", "return 必须是 main 函数中的最后一条语句")); std::any visitLValue(SysYParser::LValueContext* ctx) override {
const std::string name = ctx->ID()->getText();
Symbol* sym = table_.Lookup(name);
if (!sym) {
throw std::runtime_error(FormatError("sema", "使用了未定义的标识符: " + name));
}
if (sym->kind == Symbol::Kind::Function) {
throw std::runtime_error(FormatError("sema", "函数名不能作为左值: " + name));
}
sema_.BindLValue(ctx, sym->def_ctx);
for (auto* exp : ctx->exp()) exp->accept(this);
return {};
}
std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override {
const std::string name = ctx->ID()->getText();
Symbol* sym = table_.Lookup(name);
if (!sym) {
throw std::runtime_error(FormatError("sema", "调用未定义的函数: " + name));
}
if (sym->kind != Symbol::Kind::Function) {
throw std::runtime_error(FormatError("sema", "标识符不是函数: " + name));
}
sema_.BindFuncCall(ctx, dynamic_cast<SysYParser::FuncDefContext*>(sym->def_ctx));
if (ctx->funcRParams()) {
ctx->funcRParams()->accept(this);
} }
return {}; return {};
} }
// Visit expressions to ensure all sub-expressions are checked (e.g. for variable uses)
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { std::any visitParenExp(SysYParser::ParenExpContext* ctx) override {
if (!ctx || !ctx->exp()) { return ctx->exp()->accept(this);
throw std::runtime_error(FormatError("sema", "非法括号表达式"));
}
ctx->exp()->accept(this);
return {};
} }
std::any visitVarExp(SysYParser::VarExpContext* ctx) override { std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override {
if (!ctx || !ctx->var()) { return ctx->lValue()->accept(this);
throw std::runtime_error(FormatError("sema", "非法变量表达式"));
}
ctx->var()->accept(this);
return {};
} }
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { return ctx->number()->accept(this);
throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); }
}
std::any visitNumber(SysYParser::NumberContext* ctx) override {
return {}; return {};
} }
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { std::any visitNotExp(SysYParser::NotExpContext* ctx) override {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { return ctx->exp()->accept(this);
throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); }
}
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override {
return ctx->exp()->accept(this);
}
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override {
return ctx->exp()->accept(this);
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
ctx->exp(0)->accept(this); ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this); ctx->exp(1)->accept(this);
return {}; return {};
} }
std::any visitVar(SysYParser::VarContext* ctx) override { std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
if (!ctx || !ctx->ID()) { ctx->exp(0)->accept(this);
throw std::runtime_error(FormatError("sema", "非法变量引用")); ctx->exp(1)->accept(this);
}
const std::string name = ctx->ID()->getText();
auto* decl = table_.Lookup(name);
if (!decl) {
throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name));
}
sema_.BindVarUse(ctx, decl);
return {}; return {};
} }
SemanticContext TakeSemanticContext() { return std::move(sema_); } std::any visitModExp(SysYParser::ModExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
private: std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitOrExp(SysYParser::OrExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override {
if (ctx->exp()) ctx->exp()->accept(this);
return {};
}
std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override {
ctx->exp()->accept(this);
ctx->stmt(0)->accept(this);
if (ctx->stmt(1)) ctx->stmt(1)->accept(this);
return {};
}
std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override {
ctx->exp()->accept(this);
ctx->stmt()->accept(this);
return {};
}
std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override {
return {};
}
std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override {
return {};
}
std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override {
if (ctx->exp()) ctx->exp()->accept(this);
return {};
}
public:
SemanticContext TakeSemanticContext() { return std::move(sema_); }
private:
SymbolTable table_; SymbolTable table_;
SemanticContext sema_; SemanticContext sema_;
bool seen_return_ = false;
size_t current_item_index_ = 0;
size_t total_items_ = 0;
}; };
} // namespace } // namespace


@@ -1,17 +1,40 @@
// Maintains registration and lookup of local variable declarations.
#include "sem/SymbolTable.h" #include "sem/SymbolTable.h"
void SymbolTable::Add(const std::string& name, SymbolTable::SymbolTable() {
SysYParser::VarDefContext* decl) { // Push global scope
table_[name] = decl; PushScope();
} }
bool SymbolTable::Contains(const std::string& name) const { void SymbolTable::PushScope() {
return table_.find(name) != table_.end(); scopes_.emplace_back();
} }
SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { void SymbolTable::PopScope() {
auto it = table_.find(name); if (scopes_.size() > 1) {
return it == table_.end() ? nullptr : it->second; scopes_.pop_back();
}
}
bool SymbolTable::Add(const std::string& name, const Symbol& symbol) {
auto& current_scope = scopes_.back();
if (current_scope.find(name) != current_scope.end()) {
return false;
}
current_scope[name] = symbol;
return true;
}
Symbol* SymbolTable::Lookup(const std::string& name) {
for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
auto search = it->find(name);
if (search != it->end()) {
return &search->second;
}
}
return nullptr;
}
bool SymbolTable::IsInCurrentScope(const std::string& name) const {
const auto& current_scope = scopes_.back();
return current_scope.find(name) != current_scope.end();
} }
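
Review note on the new `SymbolTable`: redefinition is checked only against the innermost scope, while lookup walks from the innermost scope outwards, which is what allows an inner block to shadow an outer name. The sketch below is a reduced stand-in for that behaviour, assuming a symbol that carries only an `is_const` flag; `Sym` and `ScopedTable` are hypothetical names and not the classes from this PR.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// Reduced symbol record used only for this sketch.
struct Sym {
  bool is_const;
};

class ScopedTable {
 public:
  ScopedTable() { PushScope(); }  // the global scope always exists

  void PushScope() { scopes_.emplace_back(); }

  void PopScope() {
    if (scopes_.size() > 1) scopes_.pop_back();  // never pop the global scope
  }

  // Returns false if the name already exists in the innermost scope.
  bool Add(const std::string& name, Sym sym) {
    assert(!scopes_.empty());
    return scopes_.back().emplace(name, sym).second;
  }

  // Searches from the innermost scope outwards; nullptr if not found.
  const Sym* Lookup(const std::string& name) const {
    for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) {
      auto found = it->find(name);
      if (found != it->end()) return &found->second;
    }
    return nullptr;
  }

 private:
  std::vector<std::unordered_map<std::string, Sym>> scopes_;
};
```

In `visitFuncDef` above, parameters and the function body share one pushed scope, so a local variable that reuses a parameter name is reported as a redefinition rather than silently shadowing it.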


@@ -1,4 +1,77 @@
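// SysY runtime library implementation: #include <stdio.h>
// - Provides I/O and other functions per the lab/evaluation spec #include <sys/time.h>
// - Linked with the compiler-generated object code to support runtime behavior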
int getint() {
int x;
if (scanf("%d", &x) != 1) return 0;
return x;
}
int getch() {
return getchar();
}
float getfloat() {
double x;
if (scanf("%lf", &x) != 1) return 0.0f;
return (float)x;
}
int getarray(int a[]) {
int n;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; i++) {
if (scanf("%d", &a[i]) != 1) break;
}
return n;
}
int getfarray(float a[]) {
int n;
if (scanf("%d", &n) != 1) return 0;
for (int i = 0; i < n; i++) {
double val;
if (scanf("%lf", &val) != 1) break;
a[i] = (float)val;
}
return n;
}
void putint(int x) {
printf("%d", x);
}
void putch(int x) {
putchar(x);
}
void putfloat(float x) {
printf("%a", x);
}
void putarray(int n, int a[]) {
printf("%d:", n);
for (int i = 0; i < n; i++) {
printf(" %d", a[i]);
}
printf("\n");
}
void putfarray(int n, float a[]) {
printf("%d:", n);
for (int i = 0; i < n; i++) {
printf(" %a", a[i]);
}
printf("\n");
}
struct timeval start, stop;
void starttime() {
gettimeofday(&start, NULL);
}
void stoptime() {
gettimeofday(&stop, NULL);
long long duration = (stop.tv_sec - start.tv_sec) * 1000000LL + (stop.tv_usec - start.tv_usec);
printf("timer: %lld us\n", duration);
}