From 979d271ebe8df28ccadc70d382e4d8ff88dd56e7 Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Thu, 16 Apr 2026 00:21:35 +0800 Subject: [PATCH] Complete Lab2 IR generation and document process --- doc/Lab2-实验记录.md | 313 ++++++++++++++++ include/ir/IR.h | 185 +++++++++- include/irgen/IRGen.h | 68 +++- include/sem/Sema.h | 38 +- include/sem/SymbolTable.h | 23 +- src/ir/BasicBlock.cpp | 2 +- src/ir/CMakeLists.txt | 1 - src/ir/Context.cpp | 14 +- src/ir/Function.cpp | 13 +- src/ir/GlobalValue.cpp | 4 +- src/ir/IRBuilder.cpp | 142 +++++++- src/ir/IRPrinter.cpp | 278 +++++++++++++- src/ir/Instruction.cpp | 122 ++++--- src/ir/Module.cpp | 17 +- src/ir/Type.cpp | 40 ++ src/ir/Value.cpp | 31 ++ src/irgen/IRGenDecl.cpp | 282 ++++++++++----- src/irgen/IRGenDriver.cpp | 18 + src/irgen/IRGenExp.cpp | 744 +++++++++++++++++++++++++++++++++++--- src/irgen/IRGenFunc.cpp | 120 +++--- src/irgen/IRGenStmt.cpp | 158 ++++++-- src/sem/Sema.cpp | 398 +++++++++++++------- src/sem/SymbolTable.cpp | 43 ++- 23 files changed, 2583 insertions(+), 471 deletions(-) create mode 100644 doc/Lab2-实验记录.md diff --git a/doc/Lab2-实验记录.md b/doc/Lab2-实验记录.md new file mode 100644 index 0000000..11be979 --- /dev/null +++ b/doc/Lab2-实验记录.md @@ -0,0 +1,313 @@ +# Lab2 实验记录:中间表示生成 + +## 1. 实验目标 + +本次 Lab2 的目标是在已有的 SysY 前端基础上,补齐语义检查与 IR 生成流程,使编译器能够把更完整的 SysY 程序翻译为 LLVM 风格 IR,并通过 `llc/clang` 验证生成结果的正确性。 + +本次完成工作的重点包括: + +- 扩展 IR 类型系统与指令系统,支持 `float`、数组、分支、函数调用、GEP、类型转换等基础能力。 +- 扩展 Sema,支持嵌套作用域、左值绑定、函数调用绑定与内建库函数预声明。 +- 完成表达式、控制流、函数、数组与全局变量的 IR 生成逻辑。 +- 修复全局初始化常量求值、短路求值、数组寻址、IR 打印格式等会直接阻塞 Lab2 验证的关键问题。 + +## 2. 代码改动范围 + +本次实验主要修改了以下模块: + +- `src/sem` 与 `include/sem` +- `src/ir` 与 `include/ir` +- `src/irgen` 与 `include/irgen` +- 新增本文档 `doc/Lab2-实验记录.md` + +其中: + +- `sem` 负责名称绑定、作用域和语义信息准备。 +- `ir` 负责 IR 基础设施、Builder 与 Printer。 +- `irgen` 负责把 ANTLR 语法树翻译成 IR。 + +## 3. 完成过程 + +### 3.1 先确认问题边界 + +开始时先阅读了实验文档 `doc/Lab2-中间表示生成.md`,然后检查了以下实现: + +- `IRGenDecl.cpp` +- `IRGenExp.cpp` +- `IRGenStmt.cpp` +- `IRGenFunc.cpp` +- `IRBuilder.cpp` +- `IRPrinter.cpp` +- `Sema.cpp` + +在初始状态下,代码已经完成了大部分 Lab2 框架,但仍存在两个会直接导致失败的问题: + +1. 全局常量初始化时,`EvalConstExpr` 实际上仍然调用了运行时的 `EvalExpr`,从而在没有插入点时进入 `builder_.CreateLoad/CreateBinary/...`,最终报错: + `IRBuilder 未设置插入点` +2. 数组相关的指针/聚合类型处理不一致,局部数组、多维数组与数组参数传递时很容易触发 `LoadInst 不支持的指针类型` 或生成错误的 GEP。 + +为了避免只靠静态阅读猜问题,随后先执行了构建与最小样例验证,确认真实失败点。 + +### 3.2 建立回归基线 + +首先重新构建项目: + +```bash +cmake -S . -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build -j 4 +``` + +然后针对典型样例做验证: + +- `simple_add.sy` +- `05_arr_defn4.sy` +- `95_float.sy` + +结果表明: + +- `95_float.sy` 会因为全局常量路径错误触发 `IRBuilder 未设置插入点` +- `05_arr_defn4.sy` 会因为数组寻址/存储类型不一致导致崩溃 + +这一步的作用是把问题从“感觉哪里有问题”缩小到“常量求值路径”和“数组存储/寻址路径”两条主线。 + +## 4. 关键困难与解决办法 + +### 4.1 困难一:全局初始化错误地走了运行时 IRBuilder 路径 + +#### 现象 + +像下面这种代码在全局或常量初始化中会崩溃: + +```c +const float PI = 3.1415926; +const int A = 1 + 2; +``` + +原因是原来的 `EvalConstExpr` 虽然名字叫“常量求值”,但内部还是直接调用了 `EvalExpr`。一旦表达式中包含需要访问变量、二元运算、短路逻辑等节点,就会落入 `builder_` 创建指令的逻辑,而此时全局作用域没有任何基本块插入点。 + +#### 解决办法 + +把编译期常量求值彻底独立出来: + +- 为 `EvalConstExpr` 单独实现一套常量求值 Visitor。 +- 常量路径只返回 `ConstantInt` / `ConstantFloat`,绝不生成 IR 指令。 +- 支持: + - 整数/浮点字面量 + - 括号表达式 + - `+`、`-`、`!` + - `* / % + -` + - 比较运算 + - `&& ||` + - 标量 `const` 的引用 +- 在全局和常量初始化中,只允许使用 `EvalConstExpr` 的结果。 + +#### 效果 + +修复后: + +- 全局初始化不再依赖插入点 +- `95_float.sy` 中的全局常量能够稳定生成 +- 短路表达式在常量上下文中只做纯编译期求值,不会试图分配 `alloca` + +### 4.2 困难二:数组变量、数组参数与标量变量的“存储语义”混乱 + +#### 现象 + +原实现里,`alloca/load/store/GEP` 对类型的理解不统一: + +- 标量变量需要的是“指向标量的槽位” +- 局部数组需要的是“聚合对象的基址” +- 数组形参在 SysY 中本质上是指针,不应按局部数组同样处理 + +如果把这些情况混在一起,就会出现: + +- `load` 试图从数组类型直接取值 +- GEP 基类型和索引序列不匹配 +- 局部数组访问、多维数组访问、数组实参传递行为错误 + +#### 解决办法 + +做了三层拆分: + +1. 标量与数组分离 + - 标量局部变量使用真正的标量槽位:`i32*` 或 `float*` + - 数组局部变量保留聚合基址 + +2. 普通数组与数组形参分离 + - 局部/全局数组通过多级 GEP 沿数组维度寻址 + - 数组形参按“指针退化”处理,访问时根据剩余维度计算偏移 + +3. 左值取址与值求值分离 + - `GetLValuePtr` 只负责拿地址 + - `visitLValueExp` 根据左值是否仍是数组来决定是 `load` 还是数组退化传参 + +#### 效果 + +修复后: + +- `simple_add.sy` 恢复正常 +- `05_arr_defn4.sy` 可以生成并运行 +- 多维数组和数组形参的寻址逻辑更加稳定 + +### 4.3 困难三:局部数组花括号初始化语义不正确 + +#### 现象 + +`05_arr_defn4.sy` 虽然在中期已经不再崩溃,但运行结果仍然错误,退出码从预期的 `21` 变成了 `13`。这说明不是寻址崩了,而是初始化布局错了。 + +问题根源在于: + +- 一部分初始化按“子数组递进”处理 +- 一部分初始化又按“标量扁平展开”处理 + +两套逻辑混用后,多维数组初始化次序就会乱掉。 + +#### 解决办法 + +把局部数组初始化统一改成“聚合初始化 + 标量游标”方案: + +- 先统一做零初始化 +- 再对花括号初始化维护一个标量游标 +- 标量初始化时按当前扁平偏移定位到实际元素 +- 子聚合初始化时按当前对齐边界进入对应子数组 + +这套逻辑与 SysY/LLVM 前端常见的聚合初始化处理方式更接近。 + +#### 效果 + +修复后 `05_arr_defn4.sy` 的 IR 可以通过 `verify_ir.sh --run`,输出与预期一致。 + +### 4.4 困难四:IR 文本虽然能打印,但 LLVM 后端不一定接受 + +#### 现象 + +在进入 `verify_ir.sh` 阶段后,又暴露出一批“IR 生成没崩,但 LLVM 不认”的问题: + +- 内建函数被打印成了空定义,而不是声明 +- 浮点常量打印格式不符合 LLVM 期望 +- `icmp/fcmp` 的结果在打印和后续使用中对 `i1/i32` 处理不一致 +- 自动临时名使用纯数字,打乱后会违反 LLVM 的编号要求 +- 基本块名重复 +- `getelementptr` 打印时的基类型信息不正确 + +#### 解决办法 + +对 IR 基础设施做了系统修正: + +- `Function` 不再默认创建入口块,只有真正定义函数时才建 `entry` +- `IRPrinter` 对没有基本块的函数输出 `declare` +- 自动临时名改成 `t0/t1/...`,避免 LLVM 对纯数字 SSA 名称的严格顺序约束 +- 比较结果按布尔值打印和消费 +- `if/while/and/or` 生成的块名追加唯一后缀 +- 修复 float 常量、GEP、Cast、Call 等打印格式 + +#### 效果 + +修复后: + +- `simple_add.sy` +- `13_sub2.sy` +- `29_break.sy` +- `36_op_priority2.sy` +- `05_arr_defn4.sy` + +都已经可以通过 `verify_ir.sh --run`。 + +### 4.5 困难五:`95_float.sy` 的最终运行验证仍受运行库缺失影响 + +#### 现象 + +在修完 IR 生成与打印问题后,`95_float.sy` 已经可以: + +- 成功生成 IR +- 通过 `llc` 生成目标文件 + +但在最终链接阶段仍会失败,原因不是 IR 错误,而是仓库中的 `sylib/sylib.c` 当前只是空壳,没有提供: + +- `getfloat` +- `putfloat` +- `getfarray` +- `putfarray` +- `putch` +- `putint` + +等符号的真实实现。 + +#### 解决办法 + +本次提交中没有擅自扩展运行库,而是把问题明确定位为“Lab2 IR 生成正确,但运行时依赖未补齐”。这样可以把 Lab2 编译器部分与后续运行库实现清晰分开。 + +#### 影响 + +`95_float.sy` 当前的状态是: + +- IR 生成正确 +- LLVM 后端接受 +- 最终运行依赖运行库补全 + +## 5. 本次实现的主要能力 + +本次实验结束后,编译器已经具备以下 Lab2 关键能力: + +- 全局变量/常量 IR 生成 +- 局部变量 IR 生成 +- `int/float` 常量与表达式生成 +- 基本算术与比较运算 +- 类型转换:`sitofp`、`fptosi`、`zext` +- `if-else` +- `while` +- `break/continue` +- 函数定义与函数调用 +- 标量参数与数组参数 +- 多维数组寻址 +- 局部数组零初始化与花括号初始化 +- 短路求值 +- LLVM 可接受的 IR 文本打印 + +## 6. 验证结果 + +本次已完成的回归包括: + +```bash +./scripts/verify_ir.sh test/test_case/functional/simple_add.sy /tmp/ir_simple --run +./scripts/verify_ir.sh test/test_case/functional/13_sub2.sy /tmp/ir_sub2 --run +./scripts/verify_ir.sh test/test_case/functional/29_break.sy /tmp/ir_break --run +./scripts/verify_ir.sh test/test_case/functional/36_op_priority2.sy /tmp/ir_op --run +./scripts/verify_ir.sh test/test_case/functional/05_arr_defn4.sy /tmp/ir_arr --run +``` + +这些样例均已通过。 + +另外: + +```bash +./build/bin/compiler --emit-ir test/test_case/functional/95_float.sy +``` + +可以成功生成 IR,且 IR 能通过 `llc`,说明浮点常量、浮点表达式、浮点比较、类型转换与数组传参路径已经基本打通。 + +## 7. 本次实验中的经验总结 + +本次 Lab2 最核心的经验有三点: + +1. 编译期常量求值和运行时 IR 生成必须严格分离。 + 只要两条路径混在一起,全局初始化和常量表达式一定会出错。 + +2. 数组不能按“只是更大的标量”处理。 + 数组对象、数组形参、数组元素地址、数组退化指针这几个概念必须明确区分。 + +3. “能打印 IR”不等于“LLVM 能接受 IR”。 + 最后一定要走一遍 `llc/clang`,否则很多类型和格式问题会被掩盖。 + +## 8. 后续可继续完善的方向 + +虽然本次已经完成了 Lab2 的主体工作,但还可以继续完善: + +- 为 `sylib` 补齐实际运行库实现,打通 `95_float` 等 I/O 样例的最终运行 +- 为全局数组初始化补完整的常量聚合表示,而不是目前以标量初始化为主 +- 进一步统一 IR 中布尔类型的内部表示,减少 `i1/i32` 的兼容分支 +- 继续批量回归 `test/test_case` 下更多样例,补齐剩余边界情况 + +## 9. 结论 + +本次 Lab2 已经从“完成约 90%,但被全局初始化与数组/短路问题卡住”的状态,推进到“核心 IR 生成链路可用、典型功能样例可运行验证”的状态。阻塞实验验收的主问题已经被定位并解决,代码结构也比原来更清晰,后续继续做运行库、优化与更大规模回归时会更稳。 diff --git a/include/ir/IR.h b/include/ir/IR.h index b961192..1c7f9a0 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -37,6 +37,7 @@ #include #include #include +#include namespace ir { @@ -45,6 +46,7 @@ class Value; class User; class ConstantValue; class ConstantInt; +class ConstantFloat; class GlobalValue; class Instruction; class BasicBlock; @@ -83,17 +85,20 @@ class Context { ~Context(); // 去重创建 i32 常量。 ConstantInt* GetConstInt(int v); + // 去重创建 float 常量。 + ConstantFloat* GetConstFloat(float v); std::string NextTemp(); private: std::unordered_map> const_ints_; + std::unordered_map> const_floats_; int temp_index_ = -1; }; -class Type { +class Type : public std::enable_shared_from_this { public: - enum class Kind { Void, Int32, PtrInt32 }; + enum class Kind { Void, Int32, PtrInt32, Float, PtrFloat, Label, Array }; explicit Type(Kind k); // 使用静态共享对象获取类型。 // 同一类型可直接比较返回值是否相等,例如: @@ -101,15 +106,36 @@ class Type { static const std::shared_ptr& GetVoidType(); static const std::shared_ptr& GetInt32Type(); static const std::shared_ptr& GetPtrInt32Type(); + static const std::shared_ptr& GetFloatType(); + static const std::shared_ptr& GetPtrFloatType(); + static const std::shared_ptr& GetLabelType(); Kind GetKind() const; bool IsVoid() const; bool IsInt32() const; bool IsPtrInt32() const; + bool IsFloat() const; + bool IsPtrFloat() const; + bool IsLabel() const; + bool IsArray() const; + std::shared_ptr GetAsArrayType(); private: Kind kind_; }; +class ArrayType : public Type { + public: + ArrayType(std::shared_ptr element_type, uint32_t num_elements); + static std::shared_ptr Get(std::shared_ptr element_type, + uint32_t num_elements); + std::shared_ptr GetElementType() const { return element_type_; } + uint32_t GetNumElements() const { return num_elements_; } + + private: + std::shared_ptr element_type_; + uint32_t num_elements_; +}; + class Value { public: Value(std::shared_ptr ty, std::string name); @@ -120,10 +146,15 @@ class Value { bool IsVoid() const; bool IsInt32() const; bool IsPtrInt32() const; + bool IsFloat() const; + bool IsPtrFloat() const; + bool IsLabel() const; bool IsConstant() const; bool IsInstruction() const; bool IsUser() const; bool IsFunction() const; + bool IsGlobalValue() const; + bool IsArgument() const; void AddUse(User* user, size_t operand_index); void RemoveUse(User* user, size_t operand_index); const std::vector& GetUses() const; @@ -135,6 +166,19 @@ class Value { std::vector uses_; }; +// Argument represents a function parameter. +class Argument : public Value { + public: + Argument(std::shared_ptr ty, std::string name, Function* parent, + unsigned arg_no); + Function* GetParent() const { return parent_; } + unsigned GetArgNo() const { return arg_no_; } + + private: + Function* parent_; + unsigned arg_no_; +}; + // ConstantValue 是常量体系的基类。 // 当前只实现了 ConstantInt,后续可继续扩展更多常量种类。 class ConstantValue : public Value { @@ -151,8 +195,49 @@ class ConstantInt : public ConstantValue { int value_{}; }; +class ConstantFloat : public ConstantValue { + public: + ConstantFloat(std::shared_ptr ty, float v); + float GetValue() const { return value_; } + + private: + float value_{}; +}; + // 后续还需要扩展更多指令类型。 -enum class Opcode { Add, Sub, Mul, Alloca, Load, Store, Ret }; +enum class Opcode { + Add, + Sub, + Mul, + Div, + Mod, + FAdd, + FSub, + FMul, + FDiv, + ICmpEQ, + ICmpNE, + ICmpLT, + ICmpGT, + ICmpLE, + ICmpGE, + FCmpEQ, + FCmpNE, + FCmpLT, + FCmpGT, + FCmpLE, + FCmpGE, + Alloca, + Load, + Store, + Ret, + Br, + Call, + GEP, + ZExt, + SIToFP, + FPToSI +}; // User 是所有“会使用其他 Value 作为输入”的 IR 对象的抽象基类。 // 当前实现中只有 Instruction 继承自 User。 @@ -171,11 +256,15 @@ class User : public Value { std::vector operands_; }; -// GlobalValue 是全局值/全局变量体系的空壳占位类。 -// 当前只补齐类层次,具体初始化器、打印和链接语义后续再补。 +// GlobalValue 是全局值/全局变量体系的类。 class GlobalValue : public User { public: - GlobalValue(std::shared_ptr ty, std::string name); + GlobalValue(std::shared_ptr ty, std::string name, ConstantValue* init = nullptr); + ConstantValue* GetInitializer() const { return init_; } + void SetInitializer(ConstantValue* init) { init_ = init; } + + private: + ConstantValue* init_ = nullptr; }; class Instruction : public User { @@ -196,7 +285,40 @@ class BinaryInst : public Instruction { BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name); Value* GetLhs() const; - Value* GetRhs() const; + Value* GetRhs() const; +}; + +class BranchInst : public Instruction { + public: + // Unconditional branch + explicit BranchInst(BasicBlock* dest); + // Conditional branch + BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false); + + bool IsConditional() const; + Value* GetCondition() const; + BasicBlock* GetIfTrue() const; + BasicBlock* GetIfFalse() const; + BasicBlock* GetDest() const; +}; + +class CallInst : public Instruction { + public: + CallInst(Function* func, const std::vector& args, std::string name = ""); + Function* GetFunction() const; +}; + +class GetElementPtrInst : public Instruction { + public: + GetElementPtrInst(std::shared_ptr ptr_ty, Value* ptr, + const std::vector& indices, std::string name = ""); + Value* GetPtr() const; +}; + +class CastInst : public Instruction { + public: + CastInst(Opcode op, std::shared_ptr ty, Value* val, std::string name = ""); + Value* GetValue() const; }; class ReturnInst : public Instruction { @@ -255,38 +377,41 @@ class BasicBlock : public Value { }; // Function 当前也采用了最小实现。 -// 需要特别注意:由于项目里还没有单独的 FunctionType, -// Function 继承自 Value 后,其 type_ 目前只保存“返回类型”, -// 并不能完整表达“返回类型 + 形参列表”这一整套函数签名。 -// 这对当前只支持 int main() 的最小 IR 足够,但后续若补普通函数、 -// 形参和调用,通常需要引入专门的函数类型表示。 class Function : public Value { public: - // 当前构造函数接收的也是返回类型,而不是完整函数类型。 - Function(std::string name, std::shared_ptr ret_type); + Function(std::string name, std::shared_ptr ret_type, + std::vector> param_types); BasicBlock* CreateBlock(const std::string& name); BasicBlock* GetEntry(); const BasicBlock* GetEntry() const; const std::vector>& GetBlocks() const; + const std::vector>& GetArguments() const; private: BasicBlock* entry_ = nullptr; std::vector> blocks_; + std::vector> arguments_; }; + class Module { public: Module() = default; Context& GetContext(); const Context& GetContext() const; - // 创建函数时当前只显式传入返回类型,尚未接入完整的 FunctionType。 Function* CreateFunction(const std::string& name, - std::shared_ptr ret_type); + std::shared_ptr ret_type, + std::vector> param_types = {}); const std::vector>& GetFunctions() const; + GlobalValue* CreateGlobalValue(const std::string& name, + std::shared_ptr ty, + ConstantValue* init = nullptr); + const std::vector>& GetGlobalValues() const; private: Context context_; std::vector> functions_; + std::vector> global_values_; }; class IRBuilder { @@ -297,13 +422,41 @@ class IRBuilder { // 构造常量、二元运算、返回指令的最小集合。 ConstantInt* CreateConstInt(int v); + ConstantFloat* CreateConstFloat(float v); BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name); BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateDiv(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateMod(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFAdd(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFSub(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFMul(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateFDiv(Value* lhs, Value* rhs, const std::string& name); + BinaryInst* CreateICmp(Opcode op, Value* lhs, Value* rhs, + const std::string& name); + BinaryInst* CreateFCmp(Opcode op, Value* lhs, Value* rhs, + const std::string& name); + AllocaInst* CreateAlloca(std::shared_ptr ty, const std::string& name); AllocaInst* CreateAllocaI32(const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); ReturnInst* CreateRet(Value* v); + BranchInst* CreateBr(BasicBlock* dest); + BranchInst* CreateCondBr(Value* cond, BasicBlock* if_true, + BasicBlock* if_false); + CallInst* CreateCall(Function* func, const std::vector& args, + const std::string& name = ""); + GetElementPtrInst* CreateGEP(std::shared_ptr ptr_ty, Value* ptr, + const std::vector& indices, + const std::string& name = ""); + CastInst* CreateZExt(Value* val, std::shared_ptr ty, + const std::string& name = ""); + CastInst* CreateSIToFP(Value* val, std::shared_ptr ty, + const std::string& name = ""); + CastInst* CreateFPToSI(Value* val, std::shared_ptr ty, + const std::string& name = ""); private: Context& ctx_; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 231ba90..6dab7ea 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -5,8 +5,10 @@ #include #include +#include #include #include +#include #include "SysYBaseVisitor.h" #include "SysYParser.h" @@ -18,24 +20,56 @@ class Module; class Function; class IRBuilder; class Value; +class BasicBlock; } class IRGenImpl final : public SysYBaseVisitor { public: IRGenImpl(ir::Module& module, const SemanticContext& sema); + // Top-level rules std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + std::any visitDecl(SysYParser::DeclContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; + std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override; + + // Statement rules std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override; - std::any visitDecl(SysYParser::DeclContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; - std::any visitVarDef(SysYParser::VarDefContext* ctx) override; + std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override; std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; + std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override; + std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override; + std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override; + std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override; + std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override; + + // Expression rules std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; + std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override; std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override; - std::any visitVarExp(SysYParser::VarExpContext* ctx) override; - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override; + std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override; + std::any visitNotExp(SysYParser::NotExpContext* ctx) override; + std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override; + std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitDivExp(SysYParser::DivExpContext* ctx) override; + std::any visitModExp(SysYParser::ModExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitSubExp(SysYParser::SubExpContext* ctx) override; + std::any visitLtExp(SysYParser::LtExpContext* ctx) override; + std::any visitLeExp(SysYParser::LeExpContext* ctx) override; + std::any visitGtExp(SysYParser::GtExpContext* ctx) override; + std::any visitGeExp(SysYParser::GeExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitNeExp(SysYParser::NeExpContext* ctx) override; + std::any visitAndExp(SysYParser::AndExpContext* ctx) override; + std::any visitOrExp(SysYParser::OrExpContext* ctx) override; private: enum class BlockFlow { @@ -43,15 +77,35 @@ class IRGenImpl final : public SysYBaseVisitor { Terminated, }; - BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item); ir::Value* EvalExpr(SysYParser::ExpContext& expr); + ir::ConstantValue* EvalConstExpr(SysYParser::ExpContext& expr); + ir::Value* GetLValuePtr(SysYParser::LValueContext* ctx); + ir::Value* DecayArrayPtr(SysYParser::LValueContext* ctx); + bool IsArrayLikeDef(antlr4::ParserRuleContext* def) const; + size_t GetArrayRank(antlr4::ParserRuleContext* def) const; + std::shared_ptr GetDefType(antlr4::ParserRuleContext* def) const; + void ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr ty); + void EmitLocalInitValue(ir::Value* ptr, std::shared_ptr ty, + SysYParser::InitValueContext* init); ir::Module& module_; const SemanticContext& sema_; ir::Function* func_; ir::IRBuilder builder_; - // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + + // Maps a definition (VarDef, ConstDef, FuncFParam) to its IR value (Alloca or GlobalValue) + std::unordered_map storage_map_; + + // For global scope tracking + bool is_global_scope_ = true; + + // For loop control + std::stack break_stack_; + std::stack continue_stack_; + + // Helper to handle short-circuiting and comparison results + ir::Value* ToI1(ir::Value* v); + ir::Value* ToI32(ir::Value* v); }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/include/sem/Sema.h b/include/sem/Sema.h index 9ac057b..88baacd 100644 --- a/include/sem/Sema.h +++ b/include/sem/Sema.h @@ -1,30 +1,40 @@ -// 基于语法树的语义检查与名称绑定。 #pragma once #include +#include #include "SysYParser.h" class SemanticContext { public: - void BindVarUse(SysYParser::VarContext* use, - SysYParser::VarDefContext* decl) { - var_uses_[use] = decl; + void BindLValue(SysYParser::LValueContext* use, + antlr4::ParserRuleContext* def) { + lvalue_defs_[use] = def; } - SysYParser::VarDefContext* ResolveVarUse( - const SysYParser::VarContext* use) const { - auto it = var_uses_.find(use); - return it == var_uses_.end() ? nullptr : it->second; + void BindFuncCall(SysYParser::FuncCallExpContext* use, + SysYParser::FuncDefContext* def) { + funccall_defs_[use] = def; + } + + antlr4::ParserRuleContext* ResolveLValue( + const SysYParser::LValueContext* use) const { + auto it = lvalue_defs_.find(const_cast(use)); + return it == lvalue_defs_.end() ? nullptr : it->second; + } + + SysYParser::FuncDefContext* ResolveFuncCall( + const SysYParser::FuncCallExpContext* use) const { + auto it = funccall_defs_.find(const_cast(use)); + return it == funccall_defs_.end() ? nullptr : it->second; } private: - std::unordered_map - var_uses_; + std::unordered_map + lvalue_defs_; + std::unordered_map + funccall_defs_; }; -// 目前仅检查: -// - 变量先声明后使用 -// - 局部变量不允许重复定义 SemanticContext RunSema(SysYParser::CompUnitContext& comp_unit); diff --git a/include/sem/SymbolTable.h b/include/sem/SymbolTable.h index c9396dd..218f61e 100644 --- a/include/sem/SymbolTable.h +++ b/include/sem/SymbolTable.h @@ -1,17 +1,30 @@ -// 极简符号表:记录局部变量定义点。 #pragma once #include #include +#include #include "SysYParser.h" +struct Symbol { + enum class Kind { Variable, Constant, Function, Parameter }; + Kind kind; + antlr4::ParserRuleContext* def_ctx; + bool is_const = false; + bool is_array = false; + // For functions, we can store pointers to their parameter types or just the + // FuncDefContext* +}; + class SymbolTable { public: - void Add(const std::string& name, SysYParser::VarDefContext* decl); - bool Contains(const std::string& name) const; - SysYParser::VarDefContext* Lookup(const std::string& name) const; + SymbolTable(); + void PushScope(); + void PopScope(); + bool Add(const std::string& name, const Symbol& symbol); + Symbol* Lookup(const std::string& name); + bool IsInCurrentScope(const std::string& name) const; private: - std::unordered_map table_; + std::vector> scopes_; }; diff --git a/src/ir/BasicBlock.cpp b/src/ir/BasicBlock.cpp index b18502c..1950f71 100644 --- a/src/ir/BasicBlock.cpp +++ b/src/ir/BasicBlock.cpp @@ -15,7 +15,7 @@ namespace ir { // 当前 BasicBlock 还没有专门的 label type,因此先用 void 作为占位类型。 BasicBlock::BasicBlock(std::string name) - : Value(Type::GetVoidType(), std::move(name)) {} + : Value(Type::GetLabelType(), std::move(name)) {} Function* BasicBlock::GetParent() const { return parent_; } diff --git a/src/ir/CMakeLists.txt b/src/ir/CMakeLists.txt index 99987ed..c3b6e7b 100644 --- a/src/ir/CMakeLists.txt +++ b/src/ir/CMakeLists.txt @@ -3,7 +3,6 @@ add_library(ir_core STATIC Module.cpp Function.cpp BasicBlock.cpp - GlobalValue.cpp Type.cpp Value.cpp Instruction.cpp diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 16c982c..e7a81dc 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -15,10 +15,18 @@ ConstantInt* Context::GetConstInt(int v) { return inserted->second.get(); } +ConstantFloat* Context::GetConstFloat(float v) { + auto it = const_floats_.find(v); + if (it != const_floats_.end()) return it->second.get(); + auto inserted = + const_floats_ + .emplace(v, std::make_unique(Type::GetFloatType(), v)) + .first; + return inserted->second.get(); +} + std::string Context::NextTemp() { - std::ostringstream oss; - oss << "%" << ++temp_index_; - return oss.str(); + return "t" + std::to_string(++temp_index_); } } // namespace ir diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index cf14d48..8ccc084 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -5,9 +5,14 @@ namespace ir { -Function::Function(std::string name, std::shared_ptr ret_type) +Function::Function(std::string name, std::shared_ptr ret_type, + std::vector> param_types) : Value(std::move(ret_type), std::move(name)) { - entry_ = CreateBlock("entry"); + for (size_t i = 0; i < param_types.size(); ++i) { + arguments_.push_back(std::make_unique( + param_types[i], "a" + std::to_string(i), this, + static_cast(i))); + } } BasicBlock* Function::CreateBlock(const std::string& name) { @@ -29,4 +34,8 @@ const std::vector>& Function::GetBlocks() const { return blocks_; } +const std::vector>& Function::GetArguments() const { + return arguments_; +} + } // namespace ir diff --git a/src/ir/GlobalValue.cpp b/src/ir/GlobalValue.cpp index 7c2abe1..b8952f2 100644 --- a/src/ir/GlobalValue.cpp +++ b/src/ir/GlobalValue.cpp @@ -5,7 +5,7 @@ namespace ir { -GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) - : User(std::move(ty), std::move(name)) {} +GlobalValue::GlobalValue(std::shared_ptr ty, std::string name, ConstantValue* init) + : User(std::move(ty), std::move(name)), init_(init) {} } // namespace ir diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 90f03c4..6214560 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -21,6 +21,11 @@ ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); } +ConstantFloat* IRBuilder::CreateConstFloat(float v) { + // 常量不需要挂在基本块里,由 Context 负责去重与生命周期。 + return ctx_.GetConstFloat(v); +} + BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name) { if (!insert_block_) { @@ -42,11 +47,74 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } -AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { +BinaryInst* IRBuilder::CreateSub(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Sub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMul(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Mul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateDiv(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Div, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateMod(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::Mod, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFAdd(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FAdd, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFSub(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FSub, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFMul(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FMul, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateFDiv(Value* lhs, Value* rhs, + const std::string& name) { + return CreateBinary(Opcode::FDiv, lhs, rhs, name); +} + +BinaryInst* IRBuilder::CreateICmp(Opcode op, Value* lhs, Value* rhs, + const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - return insert_block_->Append(Type::GetPtrInt32Type(), name); + return insert_block_->Append(op, Type::GetInt32Type(), lhs, rhs, + name); +} + +BinaryInst* IRBuilder::CreateFCmp(Opcode op, Value* lhs, Value* rhs, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(op, Type::GetInt32Type(), lhs, rhs, + name); +} + +AllocaInst* IRBuilder::CreateAlloca(std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(ty, name); +} + +AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { + return CreateAlloca(Type::GetPtrInt32Type(), name); } LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { @@ -57,7 +125,15 @@ LoadInst* IRBuilder::CreateLoad(Value* ptr, const std::string& name) { throw std::runtime_error( FormatError("ir", "IRBuilder::CreateLoad 缺少 ptr")); } - return insert_block_->Append(Type::GetInt32Type(), ptr, name); + std::shared_ptr val_ty; + if (ptr->GetType()->IsPtrInt32()) { + val_ty = Type::GetInt32Type(); + } else if (ptr->GetType()->IsPtrFloat()) { + val_ty = Type::GetFloatType(); + } else { + throw std::runtime_error(FormatError("ir", "LoadInst 不支持的指针类型")); + } + return insert_block_->Append(val_ty, ptr, name); } StoreInst* IRBuilder::CreateStore(Value* val, Value* ptr) { @@ -79,11 +155,63 @@ ReturnInst* IRBuilder::CreateRet(Value* v) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - if (!v) { - throw std::runtime_error( - FormatError("ir", "IRBuilder::CreateRet 缺少返回值")); - } return insert_block_->Append(Type::GetVoidType(), v); } +BranchInst* IRBuilder::CreateBr(BasicBlock* dest) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(dest); +} + +BranchInst* IRBuilder::CreateCondBr(Value* cond, BasicBlock* if_true, + BasicBlock* if_false) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(cond, if_true, if_false); +} + +CallInst* IRBuilder::CreateCall(Function* func, const std::vector& args, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(func, args, name); +} + +GetElementPtrInst* IRBuilder::CreateGEP(std::shared_ptr ptr_ty, Value* ptr, + const std::vector& indices, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(ptr_ty, ptr, indices, name); +} + +CastInst* IRBuilder::CreateZExt(Value* val, std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::ZExt, ty, val, name); +} + +CastInst* IRBuilder::CreateSIToFP(Value* val, std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::SIToFP, ty, val, name); +} + +CastInst* IRBuilder::CreateFPToSI(Value* val, std::shared_ptr ty, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Opcode::FPToSI, ty, val, name); +} + } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 30efbb6..219a59c 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -4,7 +4,11 @@ #include "ir/IR.h" +#include +#include +#include #include +#include #include #include @@ -12,7 +16,7 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; @@ -20,11 +24,22 @@ static const char* TypeToString(const Type& ty) { return "i32"; case Type::Kind::PtrInt32: return "i32*"; + case Type::Kind::Float: + return "float"; + case Type::Kind::PtrFloat: + return "float*"; + case Type::Kind::Label: + return "label"; + case Type::Kind::Array: { + const auto* arr_ty = static_cast(&ty); + return "[" + std::to_string(arr_ty->GetNumElements()) + " x " + + TypeToString(*arr_ty->GetElementType()) + "]"; + } } - throw std::runtime_error(FormatError("ir", "未知类型")); + return "unknown"; } -static const char* OpcodeToString(Opcode op) { +static std::string OpcodeToString(Opcode op) { switch (op) { case Opcode::Add: return "add"; @@ -32,6 +47,42 @@ static const char* OpcodeToString(Opcode op) { return "sub"; case Opcode::Mul: return "mul"; + case Opcode::Div: + return "sdiv"; + case Opcode::Mod: + return "srem"; + case Opcode::FAdd: + return "fadd"; + case Opcode::FSub: + return "fsub"; + case Opcode::FMul: + return "fmul"; + case Opcode::FDiv: + return "fdiv"; + case Opcode::ICmpEQ: + return "icmp eq"; + case Opcode::ICmpNE: + return "icmp ne"; + case Opcode::ICmpLT: + return "icmp slt"; + case Opcode::ICmpGT: + return "icmp sgt"; + case Opcode::ICmpLE: + return "icmp sle"; + case Opcode::ICmpGE: + return "icmp sge"; + case Opcode::FCmpEQ: + return "fcmp oeq"; + case Opcode::FCmpNE: + return "fcmp une"; + case Opcode::FCmpLT: + return "fcmp olt"; + case Opcode::FCmpGT: + return "fcmp ogt"; + case Opcode::FCmpLE: + return "fcmp ole"; + case Opcode::FCmpGE: + return "fcmp oge"; case Opcode::Alloca: return "alloca"; case Opcode::Load: @@ -40,21 +91,114 @@ static const char* OpcodeToString(Opcode op) { return "store"; case Opcode::Ret: return "ret"; + case Opcode::Br: + return "br"; + case Opcode::Call: + return "call"; + case Opcode::GEP: + return "getelementptr"; + case Opcode::ZExt: + return "zext"; + case Opcode::SIToFP: + return "sitofp"; + case Opcode::FPToSI: + return "fptosi"; } return "?"; } static std::string ValueToString(const Value* v) { + if (!v) return ""; if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - return v ? v->GetName() : ""; + if (auto* cf = dynamic_cast(v)) { + const double as_double = static_cast(cf->GetValue()); + uint64_t bits = 0; + std::memcpy(&bits, &as_double, sizeof(bits)); + std::ostringstream oss; + oss << "0x" << std::hex << std::uppercase << std::setw(16) + << std::setfill('0') << bits; + return oss.str(); + } + if (v->IsGlobalValue() || v->IsFunction()) { + return "@" + v->GetName(); + } + if (v->IsInstruction() || v->IsArgument() || v->GetType()->IsLabel()) { + return "%" + v->GetName(); + } + return v->GetName(); +} + +static bool IsBoolLikeValue(const Value* v) { + if (auto* inst = dynamic_cast(v)) { + switch (inst->GetOpcode()) { + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: + return true; + default: + break; + } + } + return false; +} + +static std::string PrintedValueType(const Value* v) { + if (IsBoolLikeValue(v)) return "i1"; + return TypeToString(*v->GetType()); } void IRPrinter::Print(const Module& module, std::ostream& os) { + // Print global variables + for (const auto& gv : module.GetGlobalValues()) { + os << "@" << gv->GetName() << " = global "; + if (gv->GetType()->IsPtrInt32()) { + os << "i32"; + } else if (gv->GetType()->IsPtrFloat()) { + os << "float"; + } else { + os << TypeToString(*gv->GetType()); + } + if (gv->GetInitializer()) { + os << " " << ValueToString(gv->GetInitializer()); + } else { + os << " zeroinitializer"; + } + os << "\n"; + } + if (!module.GetGlobalValues().empty()) os << "\n"; + for (const auto& func : module.GetFunctions()) { + if (func->GetBlocks().empty()) { + os << "declare " << TypeToString(*func->GetType()) << " @" << func->GetName() + << "("; + // For declarations, we just need types. But Argument objects might exist. + const auto& args = func->GetArguments(); + for (size_t i = 0; i < args.size(); ++i) { + os << TypeToString(*args[i]->GetType()); + if (i + 1 < args.size()) os << ", "; + } + os << ")\n\n"; + continue; + } os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + << "("; + const auto& args = func->GetArguments(); + for (size_t i = 0; i < args.size(); ++i) { + os << TypeToString(*args[i]->GetType()) << " %" << args[i]->GetName(); + if (i + 1 < args.size()) os << ", "; + } + os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; @@ -65,36 +209,142 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { switch (inst->GetOpcode()) { case Opcode::Add: case Opcode::Sub: - case Opcode::Mul: { + case Opcode::Mul: + case Opcode::Div: + case Opcode::Mod: + case Opcode::FAdd: + case Opcode::FSub: + case Opcode::FMul: + case Opcode::FDiv: { auto* bin = static_cast(inst); - os << " " << bin->GetName() << " = " + os << " %" << bin->GetName() << " = " << OpcodeToString(bin->GetOpcode()) << " " - << TypeToString(*bin->GetLhs()->GetType()) << " " + << PrintedValueType(bin->GetLhs()) << " " + << ValueToString(bin->GetLhs()) << ", " + << ValueToString(bin->GetRhs()) << "\n"; + break; + } + case Opcode::ICmpEQ: + case Opcode::ICmpNE: + case Opcode::ICmpLT: + case Opcode::ICmpGT: + case Opcode::ICmpLE: + case Opcode::ICmpGE: + case Opcode::FCmpEQ: + case Opcode::FCmpNE: + case Opcode::FCmpLT: + case Opcode::FCmpGT: + case Opcode::FCmpLE: + case Opcode::FCmpGE: { + auto* bin = static_cast(inst); + os << " %" << bin->GetName() << " = " + << OpcodeToString(bin->GetOpcode()) << " " + << PrintedValueType(bin->GetLhs()) << " " << ValueToString(bin->GetLhs()) << ", " << ValueToString(bin->GetRhs()) << "\n"; break; } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca i32\n"; + os << " %" << alloca->GetName() << " = alloca "; + if (alloca->GetType()->IsPtrInt32()) + os << "i32"; + else if (alloca->GetType()->IsPtrFloat()) + os << "float"; + else + os << TypeToString(*alloca->GetType()); + os << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load i32, i32* " + os << " %" << load->GetName() << " = load " + << TypeToString(*load->GetType()) << ", " + << TypeToString(*load->GetPtr()->GetType()) << " " << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { auto* store = static_cast(inst); - os << " store i32 " << ValueToString(store->GetValue()) - << ", i32* " << ValueToString(store->GetPtr()) << "\n"; + os << " store " << TypeToString(*store->GetValue()->GetType()) + << " " << ValueToString(store->GetValue()) << ", " + << TypeToString(*store->GetPtr()->GetType()) << " " + << ValueToString(store->GetPtr()) << "\n"; break; } case Opcode::Ret: { auto* ret = static_cast(inst); - os << " ret " << TypeToString(*ret->GetValue()->GetType()) << " " - << ValueToString(ret->GetValue()) << "\n"; + os << " ret "; + if (ret->GetValue()) { + os << TypeToString(*ret->GetValue()->GetType()) << " " + << ValueToString(ret->GetValue()); + } else { + os << "void"; + } + os << "\n"; + break; + } + case Opcode::Br: { + auto* br = static_cast(inst); + if (br->IsConditional()) { + os << " br i1 " << ValueToString(br->GetCondition()) + << ", label " << ValueToString(br->GetIfTrue()) << ", label " + << ValueToString(br->GetIfFalse()) << "\n"; + } else { + os << " br label " << ValueToString(br->GetDest()) << "\n"; + } + break; + } + case Opcode::Call: { + auto* call = static_cast(inst); + auto* func = call->GetFunction(); + if (!call->GetType()->IsVoid()) { + os << " %" << call->GetName() << " = "; + } else { + os << " "; + } + os << "call " << TypeToString(*call->GetType()) << " " + << ValueToString(func) << "("; + for (size_t i = 1; i < call->GetNumOperands(); ++i) { + auto* arg = call->GetOperand(i); + os << PrintedValueType(arg) << " " << ValueToString(arg); + if (i + 1 < call->GetNumOperands()) os << ", "; + } + os << ")\n"; + break; + } + case Opcode::GEP: { + auto* gep = static_cast(inst); + os << " %" << gep->GetName() << " = getelementptr "; + if (gep->GetPtr()->GetType()->IsPtrInt32()) + os << "i32"; + else if (gep->GetPtr()->GetType()->IsPtrFloat()) + os << "float"; + else + os << TypeToString(*gep->GetPtr()->GetType()); + os << ", "; + if (gep->GetPtr()->GetType()->IsArray()) { + os << TypeToString(*gep->GetPtr()->GetType()) << "* "; + } else { + os << TypeToString(*gep->GetPtr()->GetType()) << " "; + } + os << ValueToString(gep->GetPtr()); + for (size_t i = 1; i < gep->GetNumOperands(); ++i) { + os << ", " << TypeToString(*gep->GetOperand(i)->GetType()) << " " + << ValueToString(gep->GetOperand(i)); + } + os << "\n"; + break; + } + case Opcode::ZExt: + case Opcode::SIToFP: + case Opcode::FPToSI: { + auto* cast = static_cast(inst); + os << " %" << cast->GetName() << " = " + << OpcodeToString(cast->GetOpcode()) << " " + << PrintedValueType(cast->GetValue()) << " " + << ValueToString(cast->GetValue()) << " to " + << TypeToString(*cast->GetType()) << "\n"; break; } } diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7928716..8c6e569 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -52,7 +52,9 @@ Instruction::Instruction(Opcode op, std::shared_ptr ty, std::string name) Opcode Instruction::GetOpcode() const { return opcode_; } -bool Instruction::IsTerminator() const { return opcode_ == Opcode::Ret; } +bool Instruction::IsTerminator() const { + return opcode_ == Opcode::Ret || opcode_ == Opcode::Br; +} BasicBlock* Instruction::GetParent() const { return parent_; } @@ -61,22 +63,9 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); - } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); } - if (!type_ || !lhs->GetType() || !rhs->GetType()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 缺少类型信息")); - } - if (lhs->GetType()->GetKind() != rhs->GetType()->GetKind() || - type_->GetKind() != lhs->GetType()->GetKind()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); - } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); - } AddOperand(lhs); AddOperand(rhs); } @@ -85,38 +74,85 @@ Value* BinaryInst::GetLhs() const { return GetOperand(0); } Value* BinaryInst::GetRhs() const { return GetOperand(1); } -ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) - : Instruction(Opcode::Ret, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "ReturnInst 缺少返回值")); +BranchInst::BranchInst(BasicBlock* dest) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + AddOperand(dest); +} + +BranchInst::BranchInst(Value* cond, BasicBlock* if_true, BasicBlock* if_false) + : Instruction(Opcode::Br, Type::GetVoidType(), "") { + AddOperand(cond); + AddOperand(if_true); + AddOperand(if_false); +} + +bool BranchInst::IsConditional() const { return GetNumOperands() == 3; } + +Value* BranchInst::GetCondition() const { + return IsConditional() ? GetOperand(0) : nullptr; +} + +BasicBlock* BranchInst::GetIfTrue() const { + return IsConditional() ? static_cast(GetOperand(1)) : nullptr; +} + +BasicBlock* BranchInst::GetIfFalse() const { + return IsConditional() ? static_cast(GetOperand(2)) : nullptr; +} + +BasicBlock* BranchInst::GetDest() const { + return !IsConditional() ? static_cast(GetOperand(0)) : nullptr; +} + +CallInst::CallInst(Function* func, const std::vector& args, + std::string name) + : Instruction(Opcode::Call, func->GetType(), std::move(name)) { + AddOperand(func); + for (auto* arg : args) { + AddOperand(arg); } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "ReturnInst 返回类型必须为 void")); +} + +Function* CallInst::GetFunction() const { + return static_cast(GetOperand(0)); +} + +GetElementPtrInst::GetElementPtrInst(std::shared_ptr ptr_ty, Value* ptr, + const std::vector& indices, + std::string name) + : Instruction(Opcode::GEP, std::move(ptr_ty), std::move(name)) { + AddOperand(ptr); + for (auto* idx : indices) { + AddOperand(idx); } +} + +Value* GetElementPtrInst::GetPtr() const { return GetOperand(0); } + +CastInst::CastInst(Opcode op, std::shared_ptr ty, Value* val, + std::string name) + : Instruction(op, std::move(ty), std::move(name)) { AddOperand(val); } -Value* ReturnInst::GetValue() const { return GetOperand(0); } +Value* CastInst::GetValue() const { return GetOperand(0); } -AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) - : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) { - if (!type_ || !type_->IsPtrInt32()) { - throw std::runtime_error(FormatError("ir", "AllocaInst 当前只支持 i32*")); +ReturnInst::ReturnInst(std::shared_ptr void_ty, Value* val) + : Instruction(Opcode::Ret, std::move(void_ty), "") { + if (val) { + AddOperand(val); } } +Value* ReturnInst::GetValue() const { + return GetNumOperands() > 0 ? GetOperand(0) : nullptr; +} + +AllocaInst::AllocaInst(std::shared_ptr ptr_ty, std::string name) + : Instruction(Opcode::Alloca, std::move(ptr_ty), std::move(name)) {} + LoadInst::LoadInst(std::shared_ptr val_ty, Value* ptr, std::string name) : Instruction(Opcode::Load, std::move(val_ty), std::move(name)) { - if (!ptr) { - throw std::runtime_error(FormatError("ir", "LoadInst 缺少 ptr")); - } - if (!type_ || !type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "LoadInst 当前只支持加载 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "LoadInst 当前只支持从 i32* 加载")); - } AddOperand(ptr); } @@ -124,22 +160,6 @@ Value* LoadInst::GetPtr() const { return GetOperand(0); } StoreInst::StoreInst(std::shared_ptr void_ty, Value* val, Value* ptr) : Instruction(Opcode::Store, std::move(void_ty), "") { - if (!val) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 value")); - } - if (!ptr) { - throw std::runtime_error(FormatError("ir", "StoreInst 缺少 ptr")); - } - if (!type_ || !type_->IsVoid()) { - throw std::runtime_error(FormatError("ir", "StoreInst 返回类型必须为 void")); - } - if (!val->GetType() || !val->GetType()->IsInt32()) { - throw std::runtime_error(FormatError("ir", "StoreInst 当前只支持存储 i32")); - } - if (!ptr->GetType() || !ptr->GetType()->IsPtrInt32()) { - throw std::runtime_error( - FormatError("ir", "StoreInst 当前只支持写入 i32*")); - } AddOperand(val); AddOperand(ptr); } diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 928efdc..dadbf92 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -9,8 +9,10 @@ Context& Module::GetContext() { return context_; } const Context& Module::GetContext() const { return context_; } Function* Module::CreateFunction(const std::string& name, - std::shared_ptr ret_type) { - functions_.push_back(std::make_unique(name, std::move(ret_type))); + std::shared_ptr ret_type, + std::vector> param_types) { + functions_.push_back(std::make_unique(name, std::move(ret_type), + std::move(param_types))); return functions_.back().get(); } @@ -18,4 +20,15 @@ const std::vector>& Module::GetFunctions() const { return functions_; } +GlobalValue* Module::CreateGlobalValue(const std::string& name, + std::shared_ptr ty, + ConstantValue* init) { + global_values_.push_back(std::make_unique(std::move(ty), name, init)); + return global_values_.back().get(); +} + +const std::vector>& Module::GetGlobalValues() const { + return global_values_; +} + } // namespace ir diff --git a/src/ir/Type.cpp b/src/ir/Type.cpp index 3e1684d..e6d302c 100644 --- a/src/ir/Type.cpp +++ b/src/ir/Type.cpp @@ -20,6 +20,21 @@ const std::shared_ptr& Type::GetPtrInt32Type() { return type; } +const std::shared_ptr& Type::GetFloatType() { + static const std::shared_ptr type = std::make_shared(Kind::Float); + return type; +} + +const std::shared_ptr& Type::GetPtrFloatType() { + static const std::shared_ptr type = std::make_shared(Kind::PtrFloat); + return type; +} + +const std::shared_ptr& Type::GetLabelType() { + static const std::shared_ptr type = std::make_shared(Kind::Label); + return type; +} + Type::Kind Type::GetKind() const { return kind_; } bool Type::IsVoid() const { return kind_ == Kind::Void; } @@ -28,4 +43,29 @@ bool Type::IsInt32() const { return kind_ == Kind::Int32; } bool Type::IsPtrInt32() const { return kind_ == Kind::PtrInt32; } +bool Type::IsFloat() const { return kind_ == Kind::Float; } + +bool Type::IsPtrFloat() const { return kind_ == Kind::PtrFloat; } + +bool Type::IsLabel() const { return kind_ == Kind::Label; } + +bool Type::IsArray() const { return kind_ == Kind::Array; } + +std::shared_ptr Type::GetAsArrayType() { + if (IsArray()) { + return std::static_pointer_cast(shared_from_this()); + } + return nullptr; +} + +ArrayType::ArrayType(std::shared_ptr element_type, uint32_t num_elements) + : Type(Kind::Array), + element_type_(std::move(element_type)), + num_elements_(num_elements) {} + +std::shared_ptr ArrayType::Get(std::shared_ptr element_type, + uint32_t num_elements) { + return std::make_shared(std::move(element_type), num_elements); +} + } // namespace ir diff --git a/src/ir/Value.cpp b/src/ir/Value.cpp index 2e9f4c1..32082be 100644 --- a/src/ir/Value.cpp +++ b/src/ir/Value.cpp @@ -22,6 +22,12 @@ bool Value::IsInt32() const { return type_ && type_->IsInt32(); } bool Value::IsPtrInt32() const { return type_ && type_->IsPtrInt32(); } +bool Value::IsFloat() const { return type_ && type_->IsFloat(); } + +bool Value::IsPtrFloat() const { return type_ && type_->IsPtrFloat(); } + +bool Value::IsLabel() const { return type_ && type_->IsLabel(); } + bool Value::IsConstant() const { return dynamic_cast(this) != nullptr; } @@ -38,6 +44,14 @@ bool Value::IsFunction() const { return dynamic_cast(this) != nullptr; } +bool Value::IsGlobalValue() const { + return dynamic_cast(this) != nullptr; +} + +bool Value::IsArgument() const { + return dynamic_cast(this) != nullptr; +} + void Value::AddUse(User* user, size_t operand_index) { if (!user) return; uses_.push_back(Use(this, user, operand_index)); @@ -74,10 +88,27 @@ void Value::ReplaceAllUsesWith(Value* new_value) { } } +Argument::Argument(std::shared_ptr ty, std::string name, Function* parent, + unsigned arg_no) + : Value(std::move(ty), std::move(name)), + parent_(parent), + arg_no_(arg_no) {} + ConstantValue::ConstantValue(std::shared_ptr ty, std::string name) : Value(std::move(ty), std::move(name)) {} ConstantInt::ConstantInt(std::shared_ptr ty, int v) : ConstantValue(std::move(ty), ""), value_(v) {} +ConstantFloat::ConstantFloat(std::shared_ptr ty, float v) + : ConstantValue(std::move(ty), ""), value_(v) {} + +GlobalValue::GlobalValue(std::shared_ptr ty, std::string name, + ConstantValue* init) + : User(std::move(ty), std::move(name)), init_(init) { + if (init_) { + AddOperand(init_); + } +} + } // namespace ir diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index 0eb62ae..0438372 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -1,6 +1,7 @@ #include "irgen/IRGen.h" #include +#include #include "SysYParser.h" #include "ir/IR.h" @@ -8,100 +9,209 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("irgen", "非法左值")); - } - return lvalue.ID()->getText(); +std::shared_ptr BaseTypeFromDecl(SysYParser::BtypeContext* btype) { + return (btype && btype->FLOAT()) ? ir::Type::GetFloatType() : ir::Type::GetInt32Type(); +} + +std::shared_ptr StorageType(std::shared_ptr ty) { + if (ty->IsInt32()) return ir::Type::GetPtrInt32Type(); + if (ty->IsFloat()) return ir::Type::GetPtrFloatType(); + return ty; +} + +size_t CountScalars(const std::shared_ptr& ty) { + if (!ty->IsArray()) return 1; + auto arr_ty = ty->GetAsArrayType(); + return arr_ty->GetNumElements() * CountScalars(arr_ty->GetElementType()); } } // namespace -std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句块")); +void IRGenImpl::ZeroInitializeLocal(ir::Value* ptr, std::shared_ptr ty) { + if (ty->IsArray()) { + auto arr_ty = ty->GetAsArrayType(); + for (uint32_t i = 0; i < arr_ty->GetNumElements(); ++i) { + auto* elem_ptr = builder_.CreateGEP(StorageType(arr_ty->GetElementType()), ptr, + {builder_.CreateConstInt(0), + builder_.CreateConstInt(static_cast(i))}, + module_.GetContext().NextTemp()); + ZeroInitializeLocal(elem_ptr, arr_ty->GetElementType()); + } + return; } - for (auto* item : ctx->blockItem()) { - if (item) { - if (VisitBlockItemResult(*item) == BlockFlow::Terminated) { - // 当前语法要求 return 为块内最后一条语句;命中后可停止生成。 - break; + + ir::Value* zero = ty->IsFloat() ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); + builder_.CreateStore(zero, ptr); +} + +void IRGenImpl::EmitLocalInitValue(ir::Value* ptr, std::shared_ptr ty, + SysYParser::InitValueContext* init) { + if (!init) return; + + auto build_flat_scalar_ptr = [&](ir::Value* base_ptr, + const std::shared_ptr& base_ty, + size_t flat_index) -> ir::Value* { + if (!base_ty->IsArray()) return base_ptr; + + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + auto cur_ty = base_ty; + size_t offset = flat_index; + while (cur_ty->IsArray()) { + auto arr_ty = cur_ty->GetAsArrayType(); + const size_t step = CountScalars(arr_ty->GetElementType()); + const size_t idx = step == 0 ? 0 : offset / step; + offset = step == 0 ? 0 : offset % step; + indices.push_back(builder_.CreateConstInt(static_cast(idx))); + cur_ty = arr_ty->GetElementType(); + } + return builder_.CreateGEP(StorageType(cur_ty), base_ptr, indices, + module_.GetContext().NextTemp()); + }; + + if (ty->IsArray()) { + auto arr_ty = ty->GetAsArrayType(); + const auto elem_ty = arr_ty->GetElementType(); + const size_t elem_step = CountScalars(elem_ty); + + if (init->exp()) { + auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, 0); + EmitLocalInitValue(elem_ptr, elem_ty, init); + return; + } + + const auto& children = init->initValue(); + size_t scalar_cursor = 0; + for (auto* child : children) { + if (scalar_cursor >= CountScalars(ty)) break; + if (child->exp()) { + auto* elem_ptr = build_flat_scalar_ptr(ptr, ty, scalar_cursor); + auto scalar_ty = elem_ty; + while (scalar_ty->IsArray()) { + scalar_ty = scalar_ty->GetAsArrayType()->GetElementType(); + } + EmitLocalInitValue(elem_ptr, scalar_ty, child); + ++scalar_cursor; + continue; + } + + const size_t elem_index = elem_step == 0 ? 0 : scalar_cursor / elem_step; + if (elem_index >= arr_ty->GetNumElements()) break; + auto* elem_ptr = builder_.CreateGEP(StorageType(elem_ty), ptr, + {builder_.CreateConstInt(0), + builder_.CreateConstInt(static_cast(elem_index))}, + module_.GetContext().NextTemp()); + EmitLocalInitValue(elem_ptr, elem_ty, child); + scalar_cursor = (elem_index + 1) * elem_step; + } + return; + } + + if (!init->exp()) { + return; + } + ir::Value* value = EvalExpr(*init->exp()); + if (ty->IsFloat() && value->GetType()->IsInt32()) { + value = builder_.CreateSIToFP(value, ty, module_.GetContext().NextTemp()); + } else if (ty->IsInt32() && value->GetType()->IsFloat()) { + value = builder_.CreateFPToSI(value, ty, module_.GetContext().NextTemp()); + } + builder_.CreateStore(value, ptr); +} + +std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { + if (ctx->constDecl()) return ctx->constDecl()->accept(this); + if (ctx->varDecl()) return ctx->varDecl()->accept(this); + return {}; +} + +std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + for (auto* def : ctx->constDef()) { + def->accept(this); + } + return {}; +} + +std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) { + for (auto* def : ctx->varDef()) { + def->accept(this); + } + return {}; +} + +std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + const std::string name = ctx->ID()->getText(); + auto ty = BaseTypeFromDecl( + dynamic_cast(ctx->parent)->btype()); + const auto dims = ctx->exp(); + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + auto* dim = EvalConstExpr(**it); + if (auto* ci = dynamic_cast(dim)) { + ty = ir::ArrayType::Get(ty, ci->GetValue()); + continue; + } + throw std::runtime_error(FormatError("irgen", "数组维度必须是整型常量")); + } + ir::Value* slot = nullptr; + + if (is_global_scope_) { + ir::ConstantValue* init = nullptr; + if (ctx->initValue() && ctx->initValue()->exp()) { + init = EvalConstExpr(*ctx->initValue()->exp()); + if (ty->IsInt32() && init->GetType()->IsFloat()) { + init = module_.GetContext().GetConstInt( + static_cast(static_cast(init)->GetValue())); + } else if (ty->IsFloat() && init->GetType()->IsInt32()) { + init = module_.GetContext().GetConstFloat( + static_cast(static_cast(init)->GetValue())); } } - } - return {}; -} - -IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( - SysYParser::BlockItemContext& item) { - return std::any_cast(item.accept(this)); -} - -std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少块内项")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return BlockFlow::Continue; - } - if (ctx->stmt()) { - return ctx->stmt()->accept(this); - } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句或声明")); -} - -// 变量声明的 IR 生成目前也是最小实现: -// - 先检查声明的基础类型,当前仅支持局部 int; -// - 再把 Decl 中的变量定义交给 visitVarDef 继续处理。 -// -// 和更完整的版本相比,这里还没有: -// - 一个 Decl 中多个变量定义的顺序处理; -// - const、数组、全局变量等不同声明形态; -// - 更丰富的类型系统。 -std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def) { - throw std::runtime_error(FormatError("irgen", "非法变量声明")); - } - var_def->accept(this); - return {}; -} - - -// 当前仍是教学用的最小版本,因此这里只支持: -// - 局部 int 变量; -// - 标量初始化; -// - 一个 VarDef 对应一个槽位。 -std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少变量定义")); - } - if (!ctx->lValue()) { - throw std::runtime_error(FormatError("irgen", "变量声明缺少名称")); - } - GetLValueName(*ctx->lValue()); - if (storage_map_.find(ctx) != storage_map_.end()) { - throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); - } - auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); - storage_map_[ctx] = slot; - - ir::Value* init = nullptr; - if (auto* init_value = ctx->initValue()) { - if (!init_value->exp()) { - throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化")); - } - init = EvalExpr(*init_value->exp()); + slot = module_.CreateGlobalValue(name, StorageType(ty), init); } else { - init = builder_.CreateConstInt(0); + slot = builder_.CreateAlloca(StorageType(ty), name); + ZeroInitializeLocal(slot, ty); + EmitLocalInitValue(slot, ty, ctx->initValue()); } - builder_.CreateStore(init, slot); + + storage_map_[ctx] = slot; + return {}; +} + +std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { + const std::string name = ctx->ID()->getText(); + auto ty = BaseTypeFromDecl( + dynamic_cast(ctx->parent)->btype()); + const auto dims = ctx->exp(); + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + auto* dim = EvalConstExpr(**it); + if (auto* ci = dynamic_cast(dim)) { + ty = ir::ArrayType::Get(ty, ci->GetValue()); + continue; + } + throw std::runtime_error(FormatError("irgen", "数组维度必须是整型常量")); + } + ir::Value* slot = nullptr; + + if (is_global_scope_) { + ir::ConstantValue* init = nullptr; + if (ctx->initValue() && ctx->initValue()->exp()) { + init = EvalConstExpr(*ctx->initValue()->exp()); + if (ty->IsInt32() && init->GetType()->IsFloat()) { + init = module_.GetContext().GetConstInt( + static_cast(static_cast(init)->GetValue())); + } else if (ty->IsFloat() && init->GetType()->IsInt32()) { + init = module_.GetContext().GetConstFloat( + static_cast(static_cast(init)->GetValue())); + } + } + slot = module_.CreateGlobalValue(name, StorageType(ty), init); + } else { + slot = builder_.CreateAlloca(StorageType(ty), name); + ZeroInitializeLocal(slot, ty); + if (ctx->initValue()) EmitLocalInitValue(slot, ty, ctx->initValue()); + } + + storage_map_[ctx] = slot; return {}; } diff --git a/src/irgen/IRGenDriver.cpp b/src/irgen/IRGenDriver.cpp index ff94412..3fc2ea6 100644 --- a/src/irgen/IRGenDriver.cpp +++ b/src/irgen/IRGenDriver.cpp @@ -6,9 +6,27 @@ #include "ir/IR.h" #include "utils/Log.h" +static void PredeclareLibraryFunctions(ir::Module& module) { + module.CreateFunction("getint", ir::Type::GetInt32Type(), {}); + module.CreateFunction("getch", ir::Type::GetInt32Type(), {}); + module.CreateFunction("getfloat", ir::Type::GetFloatType(), {}); + module.CreateFunction("getarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrInt32Type()}); + module.CreateFunction("getfarray", ir::Type::GetInt32Type(), {ir::Type::GetPtrFloatType()}); + module.CreateFunction("putint", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); + module.CreateFunction("putch", ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); + module.CreateFunction("putfloat", ir::Type::GetVoidType(), {ir::Type::GetFloatType()}); + module.CreateFunction("putarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrInt32Type()}); + module.CreateFunction("putfarray", ir::Type::GetVoidType(), {ir::Type::GetInt32Type(), ir::Type::GetPtrFloatType()}); + module.CreateFunction("starttime", ir::Type::GetVoidType(), {}); + module.CreateFunction("stoptime", ir::Type::GetVoidType(), {}); + // putf is special, but for now we might not support it fully or just declare it simply + // module.CreateFunction("putf", ...); +} + std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, const SemanticContext& sema) { auto module = std::make_unique(); + PredeclareLibraryFunctions(*module); IRGenImpl gen(*module, sema); tree.accept(&gen); return module; diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index cf4797c..89a9513 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -1,80 +1,724 @@ #include "irgen/IRGen.h" +#include #include +#include +#include #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" -// 表达式生成当前也只实现了很小的一个子集。 -// 目前支持: -// - 整数字面量 -// - 普通局部变量读取 -// - 括号表达式 -// - 二元加法 -// -// 还未支持: -// - 减乘除与一元运算 -// - 赋值表达式 -// - 函数调用 -// - 数组、指针、下标访问 -// - 条件与比较表达式 -// - ... +namespace { + +bool IsZero(const ir::ConstantValue* value) { + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue() == 0; + } + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue() == 0.0f; + } + return false; +} + +bool IsTruthy(const ir::ConstantValue* value) { + return !IsZero(value); +} + +int AsInt(const ir::ConstantValue* value) { + if (auto* ci = dynamic_cast(value)) { + return ci->GetValue(); + } + if (auto* cf = dynamic_cast(value)) { + return static_cast(cf->GetValue()); + } + throw std::runtime_error(FormatError("irgen", "无法将常量转换为 int")); +} + +float AsFloat(const ir::ConstantValue* value) { + if (auto* cf = dynamic_cast(value)) { + return cf->GetValue(); + } + if (auto* ci = dynamic_cast(value)) { + return static_cast(ci->GetValue()); + } + throw std::runtime_error(FormatError("irgen", "无法将常量转换为 float")); +} + +std::shared_ptr ScalarPointerType(std::shared_ptr ty) { + if (ty->IsInt32()) return ir::Type::GetPtrInt32Type(); + if (ty->IsFloat()) return ir::Type::GetPtrFloatType(); + return ty; +} + +std::shared_ptr CommonArithType(ir::Value* lhs, ir::Value* rhs) { + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + return ir::Type::GetFloatType(); + } + return ir::Type::GetInt32Type(); +} + +} // namespace + ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) { return std::any_cast(expr.accept(this)); } +ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) { + class ConstExprVisitor final : public SysYBaseVisitor { + public: + ConstExprVisitor(ir::Module& module, const SemanticContext& sema) + : module_(module), sema_(sema) {} + + std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { + return Eval(*ctx->exp()); + } + + std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { + if (ctx->number()->ILITERAL()) { + const std::string text = ctx->number()->ILITERAL()->getText(); + int value = 0; + if (text.size() > 2 && (text[1] == 'x' || text[1] == 'X')) { + value = std::stoi(text, nullptr, 16); + } else if (text.size() > 1 && text[0] == '0') { + value = std::stoi(text, nullptr, 8); + } else { + value = std::stoi(text, nullptr, 10); + } + return static_cast(module_.GetContext().GetConstInt(value)); + } + return static_cast( + module_.GetContext().GetConstFloat(std::stof(ctx->number()->FLITERAL()->getText()))); + } + + std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override { + auto* def = sema_.ResolveLValue(ctx->lValue()); + if (!def) { + throw std::runtime_error(FormatError("irgen", "常量表达式引用了未绑定左值")); + } + if (!ctx->lValue()->exp().empty()) { + throw std::runtime_error( + FormatError("irgen", "暂不支持在常量表达式中访问数组元素")); + } + if (auto* const_def = dynamic_cast(def)) { + if (!const_def->initValue() || !const_def->initValue()->exp()) { + throw std::runtime_error( + FormatError("irgen", "常量缺少标量初始化表达式")); + } + return Eval(*const_def->initValue()->exp()); + } + throw std::runtime_error( + FormatError("irgen", "全局/常量表达式必须是编译期常量")); + } + + std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override { + return Eval(*ctx->exp()); + } + + std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override { + auto* value = Eval(*ctx->exp()); + if (value->GetType()->IsFloat()) { + return static_cast( + module_.GetContext().GetConstFloat(-AsFloat(value))); + } + return static_cast( + module_.GetContext().GetConstInt(-AsInt(value))); + } + + std::any visitNotExp(SysYParser::NotExpContext* ctx) override { + return static_cast( + module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1)); + } + + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + auto* rhs = Eval(*ctx->exp(1)); + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + return static_cast( + module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs))); + } + return static_cast( + module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs))); + } + + std::any visitSubExp(SysYParser::SubExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + auto* rhs = Eval(*ctx->exp(1)); + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + return static_cast( + module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs))); + } + return static_cast( + module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs))); + } + + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + auto* rhs = Eval(*ctx->exp(1)); + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + return static_cast( + module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs))); + } + return static_cast( + module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs))); + } + + std::any visitDivExp(SysYParser::DivExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + auto* rhs = Eval(*ctx->exp(1)); + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + const float rv = AsFloat(rhs); + return static_cast(module_.GetContext().GetConstFloat( + rv == 0.0f ? 0.0f : AsFloat(lhs) / rv)); + } + const int rv = AsInt(rhs); + return static_cast( + module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv)); + } + + std::any visitModExp(SysYParser::ModExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + auto* rhs = Eval(*ctx->exp(1)); + return static_cast(module_.GetContext().GetConstInt( + AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs))); + } + + std::any visitLtExp(SysYParser::LtExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT); + } + std::any visitLeExp(SysYParser::LeExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE); + } + std::any visitGtExp(SysYParser::GtExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT); + } + std::any visitGeExp(SysYParser::GeExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE); + } + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ); + } + std::any visitNeExp(SysYParser::NeExpContext* ctx) override { + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE); + } + + std::any visitAndExp(SysYParser::AndExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + if (!IsTruthy(lhs)) { + return static_cast(module_.GetContext().GetConstInt(0)); + } + return static_cast( + module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0)); + } + + std::any visitOrExp(SysYParser::OrExpContext* ctx) override { + auto* lhs = Eval(*ctx->exp(0)); + if (IsTruthy(lhs)) { + return static_cast(module_.GetContext().GetConstInt(1)); + } + return static_cast( + module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp(1))) ? 1 : 0)); + } + + ir::ConstantValue* Eval(SysYParser::ExpContext& ctx) { + return std::any_cast(ctx.accept(this)); + } + + private: + ir::ConstantValue* EvalCmpImpl(SysYParser::ExpContext& lhs_ctx, + SysYParser::ExpContext& rhs_ctx, + ir::Opcode op) { + auto* lhs = Eval(lhs_ctx); + auto* rhs = Eval(rhs_ctx); + bool result = false; + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + const float a = AsFloat(lhs); + const float b = AsFloat(rhs); + switch (op) { + case ir::Opcode::ICmpLT: result = a < b; break; + case ir::Opcode::ICmpLE: result = a <= b; break; + case ir::Opcode::ICmpGT: result = a > b; break; + case ir::Opcode::ICmpGE: result = a >= b; break; + case ir::Opcode::ICmpEQ: result = a == b; break; + case ir::Opcode::ICmpNE: result = a != b; break; + default: break; + } + } else { + const int a = AsInt(lhs); + const int b = AsInt(rhs); + switch (op) { + case ir::Opcode::ICmpLT: result = a < b; break; + case ir::Opcode::ICmpLE: result = a <= b; break; + case ir::Opcode::ICmpGT: result = a > b; break; + case ir::Opcode::ICmpGE: result = a >= b; break; + case ir::Opcode::ICmpEQ: result = a == b; break; + case ir::Opcode::ICmpNE: result = a != b; break; + default: break; + } + } + return module_.GetContext().GetConstInt(result ? 1 : 0); + } + + ir::Module& module_; + const SemanticContext& sema_; + }; + + ConstExprVisitor visitor(module_, sema_); + return visitor.Eval(expr); +} + +static ir::Value* CastValue(IRGenImpl& gen, ir::IRBuilder& builder, ir::Module& module, + ir::Value* value, std::shared_ptr target_ty) { + if (value->GetType() == target_ty) return value; + if (target_ty->IsFloat() && value->GetType()->IsInt32()) { + if (auto* ci = dynamic_cast(value)) { + return module.GetContext().GetConstFloat(static_cast(ci->GetValue())); + } + return builder.CreateSIToFP(value, target_ty, module.GetContext().NextTemp()); + } + if (target_ty->IsInt32() && value->GetType()->IsFloat()) { + if (auto* cf = dynamic_cast(value)) { + return module.GetContext().GetConstInt(static_cast(cf->GetValue())); + } + return builder.CreateFPToSI(value, target_ty, module.GetContext().NextTemp()); + } + return value; +} std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "非法括号表达式")); - } return EvalExpr(*ctx->exp()); } +std::any IRGenImpl::visitLValueExp(SysYParser::LValueExpContext* ctx) { + auto* def = sema_.ResolveLValue(ctx->lValue()); + if (def && IsArrayLikeDef(def) && ctx->lValue()->exp().size() < GetArrayRank(def)) { + return DecayArrayPtr(ctx->lValue()); + } + ir::Value* ptr = GetLValuePtr(ctx->lValue()); + return static_cast( + builder_.CreateLoad(ptr, module_.GetContext().NextTemp())); +} std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); - } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->number()->getText()))); + return static_cast(EvalConstExpr(*ctx)); } -// 变量使用的处理流程: -// 1. 先通过语义分析结果把变量使用绑定回声明; -// 2. 再通过 storage_map_ 找到该声明对应的栈槽位; -// 3. 最后生成 load,把内存中的值读出来。 -// -// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。 -std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) { - if (!ctx || !ctx->var() || !ctx->var()->ID()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量")); +std::any IRGenImpl::visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) { + ir::Function* target_func = nullptr; + if (auto* def = sema_.ResolveFuncCall(ctx)) { + const std::string func_name = def->ID()->getText(); + for (const auto& f : module_.GetFunctions()) { + if (f->GetName() == func_name) { + target_func = f.get(); + break; + } + } + } else { + const std::string func_name = ctx->ID()->getText(); + for (const auto& f : module_.GetFunctions()) { + if (f->GetName() == func_name) { + target_func = f.get(); + break; + } + } } - auto* decl = sema_.ResolveVarUse(ctx->var()); - if (!decl) { - throw std::runtime_error( - FormatError("irgen", - "变量使用缺少语义绑定: " + ctx->var()->ID()->getText())); + if (!target_func) { + throw std::runtime_error(FormatError("irgen", "找不到函数: " + ctx->ID()->getText())); } - auto it = storage_map_.find(decl); + + std::vector args; + const auto& arg_types = target_func->GetArguments(); + if (ctx->funcRParams()) { + const auto& exps = ctx->funcRParams()->exp(); + for (size_t i = 0; i < exps.size(); ++i) { + ir::Value* arg = EvalExpr(*exps[i]); + if (i < arg_types.size()) { + arg = CastValue(*this, builder_, module_, arg, arg_types[i]->GetType()); + } + args.push_back(arg); + } + } + + return static_cast( + builder_.CreateCall(target_func, args, + target_func->GetType()->IsVoid() + ? "" + : module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitNotExp(SysYParser::NotExpContext* ctx) { + ir::Value* val = EvalExpr(*ctx->exp()); + if (auto* cv = dynamic_cast(val)) { + return static_cast( + module_.GetContext().GetConstInt(IsTruthy(cv) ? 0 : 1)); + } + ir::Value* cond = ToI1(val); + ir::Value* res = + builder_.CreateICmp(ir::Opcode::ICmpEQ, cond, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + return ToI32(res); +} + +std::any IRGenImpl::visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) { + return EvalExpr(*ctx->exp()); +} + +std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) { + ir::Value* val = EvalExpr(*ctx->exp()); + if (auto* ci = dynamic_cast(val)) { + return static_cast(module_.GetContext().GetConstInt(-ci->GetValue())); + } + if (auto* cf = dynamic_cast(val)) { + return static_cast(module_.GetContext().GetConstFloat(-cf->GetValue())); + } + if (val->GetType()->IsFloat()) { + return static_cast(builder_.CreateFSub( + builder_.CreateConstFloat(0.0f), val, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateSub( + builder_.CreateConstInt(0), val, module_.GetContext().NextTemp())); +} + +#define DEFINE_ARITH_VISITOR(name, int_opcode, float_opcode) \ + std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \ + ir::Value* lhs = EvalExpr(*ctx->exp(0)); \ + ir::Value* rhs = EvalExpr(*ctx->exp(1)); \ + const auto common_ty = CommonArithType(lhs, rhs); \ + lhs = CastValue(*this, builder_, module_, lhs, common_ty); \ + rhs = CastValue(*this, builder_, module_, rhs, common_ty); \ + if (auto* lconst = dynamic_cast(lhs)) { \ + if (auto* rconst = dynamic_cast(rhs)) { \ + if (common_ty->IsFloat()) { \ + const float lv = AsFloat(lconst); \ + const float rv = AsFloat(rconst); \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \ + return static_cast(module_.GetContext().GetConstFloat(lv + rv)); \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \ + return static_cast(module_.GetContext().GetConstFloat(lv - rv)); \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \ + return static_cast(module_.GetContext().GetConstFloat(lv * rv)); \ + return static_cast(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv)); \ + } \ + const int lv = AsInt(lconst); \ + const int rv = AsInt(rconst); \ + if constexpr (ir::Opcode::int_opcode == ir::Opcode::Add) \ + return static_cast(module_.GetContext().GetConstInt(lv + rv)); \ + if constexpr (ir::Opcode::int_opcode == ir::Opcode::Sub) \ + return static_cast(module_.GetContext().GetConstInt(lv - rv)); \ + if constexpr (ir::Opcode::int_opcode == ir::Opcode::Mul) \ + return static_cast(module_.GetContext().GetConstInt(lv * rv)); \ + return static_cast(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv)); \ + } \ + } \ + if (common_ty->IsFloat()) { \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FAdd) \ + return static_cast(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp())); \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FSub) \ + return static_cast(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp())); \ + if constexpr (ir::Opcode::float_opcode == ir::Opcode::FMul) \ + return static_cast(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp())); \ + return static_cast(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp())); \ + } \ + return static_cast(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ + } + +DEFINE_ARITH_VISITOR(Add, Add, FAdd) +DEFINE_ARITH_VISITOR(Sub, Sub, FSub) +DEFINE_ARITH_VISITOR(Mul, Mul, FMul) +DEFINE_ARITH_VISITOR(Div, Div, FDiv) + +std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) { + ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)), + ir::Type::GetInt32Type()); + ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)), + ir::Type::GetInt32Type()); + if (auto* lconst = dynamic_cast(lhs)) { + if (auto* rconst = dynamic_cast(rhs)) { + const int rv = AsInt(rconst); + return static_cast( + module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv)); + } + } + return static_cast( + builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); +} + +#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \ + std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \ + ir::Value* lhs = EvalExpr(*ctx->exp(0)); \ + ir::Value* rhs = EvalExpr(*ctx->exp(1)); \ + const auto common_ty = CommonArithType(lhs, rhs); \ + lhs = CastValue(*this, builder_, module_, lhs, common_ty); \ + rhs = CastValue(*this, builder_, module_, rhs, common_ty); \ + if (auto* lconst = dynamic_cast(lhs)) { \ + if (auto* rconst = dynamic_cast(rhs)) { \ + const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \ + : (AsInt(lconst) cmp_op AsInt(rconst)); \ + return static_cast(module_.GetContext().GetConstInt(result ? 1 : 0)); \ + } \ + } \ + if (common_ty->IsFloat()) { \ + return static_cast(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ + } \ + return static_cast(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ + } + +DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <) +DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=) +DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >) +DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=) +DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==) +DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=) + +std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) { + if (!builder_.GetInsertBlock()) { + return static_cast(EvalConstExpr(*ctx)); + } + + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + if (auto* c = dynamic_cast(lhs); c && !IsTruthy(c)) { + return static_cast(module_.GetContext().GetConstInt(0)); + } + + const std::string suffix = module_.GetContext().NextTemp(); + ir::BasicBlock* rhs_bb = func_->CreateBlock("and.rhs." + suffix); + ir::BasicBlock* merge_bb = func_->CreateBlock("and.merge." + suffix); + auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + + builder_.CreateStore(ToI32(ToI1(lhs)), res_ptr); + builder_.CreateCondBr(ToI1(lhs), rhs_bb, merge_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr); + builder_.CreateBr(merge_bb); + + builder_.SetInsertPoint(merge_bb); + return static_cast( + builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitOrExp(SysYParser::OrExpContext* ctx) { + if (!builder_.GetInsertBlock()) { + return static_cast(EvalConstExpr(*ctx)); + } + + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + if (auto* c = dynamic_cast(lhs); c && IsTruthy(c)) { + return static_cast(module_.GetContext().GetConstInt(1)); + } + + const std::string suffix = module_.GetContext().NextTemp(); + ir::BasicBlock* rhs_bb = func_->CreateBlock("or.rhs." + suffix); + ir::BasicBlock* merge_bb = func_->CreateBlock("or.merge." + suffix); + auto* res_ptr = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); + + builder_.CreateStore(ToI32(ToI1(lhs)), res_ptr); + builder_.CreateCondBr(ToI1(lhs), merge_bb, rhs_bb); + + builder_.SetInsertPoint(rhs_bb); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + builder_.CreateStore(ToI32(ToI1(rhs)), res_ptr); + builder_.CreateBr(merge_bb); + + builder_.SetInsertPoint(merge_bb); + return static_cast( + builder_.CreateLoad(res_ptr, module_.GetContext().NextTemp())); +} + +bool IRGenImpl::IsArrayLikeDef(antlr4::ParserRuleContext* def) const { + if (auto* const_def = dynamic_cast(def)) { + return !const_def->exp().empty(); + } + if (auto* var_def = dynamic_cast(def)) { + return !var_def->exp().empty(); + } + if (auto* param = dynamic_cast(def)) { + return !param->LBRACK().empty(); + } + return false; +} + +size_t IRGenImpl::GetArrayRank(antlr4::ParserRuleContext* def) const { + if (auto* const_def = dynamic_cast(def)) { + return const_def->exp().size(); + } + if (auto* var_def = dynamic_cast(def)) { + return var_def->exp().size(); + } + if (auto* param = dynamic_cast(def)) { + return param->LBRACK().size(); + } + return 0; +} + +std::shared_ptr IRGenImpl::GetDefType(antlr4::ParserRuleContext* def) const { + std::shared_ptr ty = ir::Type::GetInt32Type(); + auto apply_dims = [&](const auto& dims) { + auto result = ty; + for (auto it = dims.rbegin(); it != dims.rend(); ++it) { + if constexpr (std::is_pointer_v>) { + auto* dim_val = const_cast(this)->EvalConstExpr(**it); + result = ir::ArrayType::Get(result, AsInt(dim_val)); + } else { + auto* dim_val = const_cast(this)->EvalConstExpr(*it); + result = ir::ArrayType::Get(result, AsInt(dim_val)); + } + } + return result; + }; + + if (auto* const_def = dynamic_cast(def)) { + auto* decl = dynamic_cast(const_def->parent); + ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType() + : ir::Type::GetInt32Type(); + return apply_dims(const_def->exp()); + } + if (auto* var_def = dynamic_cast(def)) { + auto* decl = dynamic_cast(var_def->parent); + ty = (decl && decl->btype() && decl->btype()->FLOAT()) ? ir::Type::GetFloatType() + : ir::Type::GetInt32Type(); + return apply_dims(var_def->exp()); + } + if (auto* param = dynamic_cast(def)) { + ty = param->btype()->FLOAT() ? ir::Type::GetFloatType() : ir::Type::GetInt32Type(); + if (param->LBRACK().empty()) return ty; + for (int i = static_cast(param->exp().size()) - 1; i >= 0; --i) { + auto* dim_val = const_cast(this)->EvalConstExpr(*param->exp(i)); + ty = ir::ArrayType::Get(ty, AsInt(dim_val)); + } + return ty; + } + return ty; +} + +ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) { + auto* def = sema_.ResolveLValue(ctx); + if (!def) { + throw std::runtime_error(FormatError("irgen", "数组退化失败: 未绑定定义")); + } + ir::Value* base_ptr = storage_map_.at(def); + const auto base_ty = GetDefType(def); + + if (dynamic_cast(def)) { + if (ctx->exp().empty()) return base_ptr; + + ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)), + ir::Type::GetInt32Type()); + auto cur_ty = base_ty; + for (size_t i = 1; i < ctx->exp().size(); ++i) { + if (!cur_ty->IsArray()) break; + auto arr_ty = cur_ty->GetAsArrayType(); + ir::Value* stride = builder_.CreateConstInt(static_cast(arr_ty->GetNumElements())); + offset = builder_.CreateMul(offset, stride, module_.GetContext().NextTemp()); + offset = builder_.CreateAdd( + offset, + CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(i)), + ir::Type::GetInt32Type()), + module_.GetContext().NextTemp()); + cur_ty = arr_ty->GetElementType(); + } + return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset}, + module_.GetContext().NextTemp()); + } + + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* exp : ctx->exp()) { + indices.push_back( + CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type())); + } + + auto cur_ty = base_ty; + while (cur_ty->IsArray()) { + cur_ty = cur_ty->GetAsArrayType()->GetElementType(); + } + if (!ctx->exp().empty()) { + return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices, + module_.GetContext().NextTemp()); + } + indices.push_back(builder_.CreateConstInt(0)); + return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices, + module_.GetContext().NextTemp()); +} + +ir::Value* IRGenImpl::GetLValuePtr(SysYParser::LValueContext* ctx) { + auto* def = sema_.ResolveLValue(ctx); + if (!def) { + throw std::runtime_error(FormatError("irgen", "未定义的左值: " + ctx->ID()->getText())); + } + auto it = storage_map_.find(def); if (it == storage_map_.end()) { throw std::runtime_error( - FormatError("irgen", - "变量声明缺少存储槽位: " + ctx->var()->ID()->getText())); + FormatError("irgen", "左值缺少存储槽位: " + ctx->ID()->getText())); } - return static_cast( - builder_.CreateLoad(it->second, module_.GetContext().NextTemp())); + + ir::Value* base_ptr = it->second; + if (ctx->exp().empty()) return base_ptr; + + if (dynamic_cast(def)) { + return DecayArrayPtr(ctx); + } + + const auto base_ty = GetDefType(def); + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* exp : ctx->exp()) { + indices.push_back( + CastValue(*this, builder_, module_, EvalExpr(*exp), ir::Type::GetInt32Type())); + } + + auto cur_ty = base_ty; + for (size_t i = 0; i < ctx->exp().size(); ++i) { + if (!cur_ty->IsArray()) { + throw std::runtime_error(FormatError("irgen", "数组下标层数超出定义")); + } + cur_ty = cur_ty->GetAsArrayType()->GetElementType(); + } + return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, indices, + module_.GetContext().NextTemp()); } - -std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("irgen", "非法加法表达式")); +ir::Value* IRGenImpl::ToI1(ir::Value* v) { + if (auto* cv = dynamic_cast(v)) { + return module_.GetContext().GetConstInt(IsTruthy(cv) ? 1 : 0); } - ir::Value* lhs = EvalExpr(*ctx->exp(0)); - ir::Value* rhs = EvalExpr(*ctx->exp(1)); - return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + if (auto* inst = dynamic_cast(v)) { + switch (inst->GetOpcode()) { + case ir::Opcode::ICmpEQ: + case ir::Opcode::ICmpNE: + case ir::Opcode::ICmpLT: + case ir::Opcode::ICmpGT: + case ir::Opcode::ICmpLE: + case ir::Opcode::ICmpGE: + case ir::Opcode::FCmpEQ: + case ir::Opcode::FCmpNE: + case ir::Opcode::FCmpLT: + case ir::Opcode::FCmpGT: + case ir::Opcode::FCmpLE: + case ir::Opcode::FCmpGE: + return v; + default: + break; + } + } + if (v->GetType()->IsFloat()) { + return builder_.CreateFCmp(ir::Opcode::FCmpNE, v, builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } + if (v->GetType()->IsInt32()) { + return builder_.CreateICmp(ir::Opcode::ICmpNE, v, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + } + return v; +} + +ir::Value* IRGenImpl::ToI32(ir::Value* v) { + if (v->GetType()->IsInt32()) return v; + if (v->GetType()->IsFloat()) { + return builder_.CreateFPToSI(v, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + return builder_.CreateZExt(v, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); } diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 4912d03..f054fb1 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -8,10 +8,18 @@ namespace { +std::shared_ptr StorageType(std::shared_ptr ty) { + if (ty->IsInt32()) return ir::Type::GetPtrInt32Type(); + if (ty->IsFloat()) return ir::Type::GetPtrFloatType(); + return ty; +} + void VerifyFunctionStructure(const ir::Function& func) { - // 当前 IRGen 仍是单入口、顺序生成;这里在生成结束后补一层块终结校验。 for (const auto& bb : func.GetBlocks()) { if (!bb || !bb->HasTerminator()) { + // If a block doesn't have a terminator, it might be an empty function or + // missing a return in a path. For SysY, we should at least have a default return for void. + // But IRGen should have handled it. throw std::runtime_error( FormatError("irgen", "基本块未正确终结: " + (bb ? bb->GetName() : std::string("")))); @@ -25,63 +33,83 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema) : module_(module), sema_(sema), func_(nullptr), - builder_(module.GetContext(), nullptr) {} + builder_(module.GetContext(), nullptr), + is_global_scope_(true) {} -// 编译单元的 IR 生成当前只实现了最小功能: -// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容; -// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR; -// -// 当前还没有实现: -// - 多个函数定义的遍历与生成; -// - 全局变量、全局常量的 IR 生成。 std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少编译单元")); + if (!ctx) return {}; + + is_global_scope_ = true; + for (auto* decl : ctx->decl()) { + decl->accept(this); } - auto* func = ctx->funcDef(); - if (!func) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); + for (auto* funcDef : ctx->funcDef()) { + funcDef->accept(this); } - func->accept(this); return {}; } -// 函数 IR 生成当前实现了: -// 1. 获取函数名; -// 2. 检查函数返回类型; -// 3. 在 Module 中创建 Function; -// 4. 将 builder 插入点设置到入口基本块; -// 5. 继续生成函数体。 -// -// 当前还没有实现: -// - 通用函数返回类型处理; -// - 形参列表遍历与参数类型收集; -// - FunctionType 这样的函数类型对象; -// - Argument/形式参数 IR 对象; -// - 入口块中的参数初始化逻辑。 -// ... - -// 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少函数定义")); - } - if (!ctx->blockStmt()) { - throw std::runtime_error(FormatError("irgen", "函数体为空")); - } - if (!ctx->ID()) { - throw std::runtime_error(FormatError("irgen", "缺少函数名")); - } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + is_global_scope_ = false; + + std::shared_ptr ret_ty; + if (ctx->funcType()->INT()) { + ret_ty = ir::Type::GetInt32Type(); + } else if (ctx->funcType()->FLOAT()) { + ret_ty = ir::Type::GetFloatType(); + } else { + ret_ty = ir::Type::GetVoidType(); } - func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type()); - builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); + std::string func_name = ctx->ID()->getText(); + + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* fparam : ctx->funcFParams()->funcFParam()) { + if (fparam->LBRACK().empty()) { + param_types.push_back(fparam->btype()->INT() ? ir::Type::GetInt32Type() : ir::Type::GetFloatType()); + } else { + // Array param is a pointer + param_types.push_back(fparam->btype()->INT() ? ir::Type::GetPtrInt32Type() : ir::Type::GetPtrFloatType()); + } + } + } + + func_ = module_.CreateFunction(func_name, ret_ty, param_types); + builder_.SetInsertPoint(func_->CreateBlock("entry")); + + // Handle parameters: alloca and store + if (ctx->funcFParams()) { + const auto& fparams = ctx->funcFParams()->funcFParam(); + const auto& args = func_->GetArguments(); + for (size_t i = 0; i < fparams.size(); ++i) { + auto* fparam = fparams[i]; + auto* arg = args[i].get(); + auto* slot = builder_.CreateAlloca(StorageType(arg->GetType()), fparam->ID()->getText()); + builder_.CreateStore(arg, slot); + storage_map_[fparam] = slot; + } + } ctx->blockStmt()->accept(this); - // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 + + // Default return for void functions if not terminated + if (!builder_.GetInsertBlock()->HasTerminator()) { + if (ret_ty->IsVoid()) { + builder_.CreateRet(nullptr); + } else if (ret_ty->IsInt32()) { + builder_.CreateRet(builder_.CreateConstInt(0)); + } else if (ret_ty->IsFloat()) { + builder_.CreateRet(builder_.CreateConstFloat(0.0f)); + } + } + VerifyFunctionStructure(*func_); + is_global_scope_ = true; + return {}; +} + +std::any IRGenImpl::visitFuncFParam(SysYParser::FuncFParamContext* ctx) { + // We handle fparams in visitFuncDef directly. return {}; } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 751550c..36115e8 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -6,34 +6,146 @@ #include "ir/IR.h" #include "utils/Log.h" -// 语句生成当前只实现了最小子集。 -// 目前支持: -// - return ; -// -// 还未支持: -// - 赋值语句 -// - if / while 等控制流 -// - 空语句、块语句嵌套分发之外的更多语句形态 - +// 语句生成 std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少语句")); - } - if (ctx->returnStmt()) { - return ctx->returnStmt()->accept(this); - } - throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); + if (!ctx) return BlockFlow::Continue; + + if (ctx->assignStmt()) return ctx->assignStmt()->accept(this); + if (ctx->returnStmt()) return ctx->returnStmt()->accept(this); + if (ctx->blockStmt()) return ctx->blockStmt()->accept(this); + if (ctx->ifStmt()) return ctx->ifStmt()->accept(this); + if (ctx->whileStmt()) return ctx->whileStmt()->accept(this); + if (ctx->breakStmt()) return ctx->breakStmt()->accept(this); + if (ctx->continueStmt()) return ctx->continueStmt()->accept(this); + if (ctx->expStmt()) return ctx->expStmt()->accept(this); + + return BlockFlow::Continue; } +std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + for (auto* item : ctx->blockItem()) { + if (std::any_cast(item->accept(this)) == BlockFlow::Terminated) { + return BlockFlow::Terminated; + } + } + return BlockFlow::Continue; +} + +std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) { + if (ctx->decl()) { + ctx->decl()->accept(this); + return BlockFlow::Continue; + } + if (ctx->stmt()) { + return ctx->stmt()->accept(this); + } + return BlockFlow::Continue; +} + +std::any IRGenImpl::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { + ir::Value* ptr = GetLValuePtr(ctx->lValue()); + ir::Value* val = EvalExpr(*ctx->exp()); + if (ptr->GetType()->IsPtrFloat() && val->GetType()->IsInt32()) { + val = builder_.CreateSIToFP(val, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (ptr->GetType()->IsPtrInt32() && val->GetType()->IsFloat()) { + val = builder_.CreateFPToSI(val, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateStore(val, ptr); + return BlockFlow::Continue; +} std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - if (!ctx) { - throw std::runtime_error(FormatError("irgen", "缺少 return 语句")); + if (ctx->exp()) { + ir::Value* v = EvalExpr(*ctx->exp()); + if (func_->GetType()->IsFloat() && v->GetType()->IsInt32()) { + v = builder_.CreateSIToFP(v, ir::Type::GetFloatType(), + module_.GetContext().NextTemp()); + } else if (func_->GetType()->IsInt32() && v->GetType()->IsFloat()) { + v = builder_.CreateFPToSI(v, ir::Type::GetInt32Type(), + module_.GetContext().NextTemp()); + } + builder_.CreateRet(v); + } else { + builder_.CreateRet(nullptr); } - if (!ctx->exp()) { - throw std::runtime_error(FormatError("irgen", "return 缺少表达式")); - } - ir::Value* v = EvalExpr(*ctx->exp()); - builder_.CreateRet(v); return BlockFlow::Terminated; } + +std::any IRGenImpl::visitIfStmt(SysYParser::IfStmtContext* ctx) { + const std::string suffix = module_.GetContext().NextTemp(); + ir::BasicBlock* true_bb = func_->CreateBlock("if.true." + suffix); + ir::BasicBlock* false_bb = + ctx->ELSE() ? func_->CreateBlock("if.false." + suffix) : nullptr; + ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge." + suffix); + + ir::Value* cond = EvalExpr(*ctx->exp()); + builder_.CreateCondBr(ToI1(cond), true_bb, false_bb ? false_bb : merge_bb); + + // True block + builder_.SetInsertPoint(true_bb); + if (std::any_cast(ctx->stmt(0)->accept(this)) == BlockFlow::Continue) { + builder_.CreateBr(merge_bb); + } + + // False block + if (false_bb) { + builder_.SetInsertPoint(false_bb); + if (std::any_cast(ctx->stmt(1)->accept(this)) == BlockFlow::Continue) { + builder_.CreateBr(merge_bb); + } + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; +} + +std::any IRGenImpl::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { + const std::string suffix = module_.GetContext().NextTemp(); + ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond." + suffix); + ir::BasicBlock* body_bb = func_->CreateBlock("while.body." + suffix); + ir::BasicBlock* end_bb = func_->CreateBlock("while.end." + suffix); + + builder_.CreateBr(cond_bb); + builder_.SetInsertPoint(cond_bb); + ir::Value* cond = EvalExpr(*ctx->exp()); + builder_.CreateCondBr(ToI1(cond), body_bb, end_bb); + + break_stack_.push(end_bb); + continue_stack_.push(cond_bb); + + builder_.SetInsertPoint(body_bb); + if (std::any_cast(ctx->stmt()->accept(this)) == BlockFlow::Continue) { + builder_.CreateBr(cond_bb); + } + + break_stack_.pop(); + continue_stack_.pop(); + + builder_.SetInsertPoint(end_bb); + return BlockFlow::Continue; +} + +std::any IRGenImpl::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { + if (break_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "break 语句不在循环内")); + } + builder_.CreateBr(break_stack_.top()); + return BlockFlow::Terminated; +} + +std::any IRGenImpl::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { + if (continue_stack_.empty()) { + throw std::runtime_error(FormatError("irgen", "continue 语句不在循环内")); + } + builder_.CreateBr(continue_stack_.top()); + return BlockFlow::Terminated; +} + +std::any IRGenImpl::visitExpStmt(SysYParser::ExpStmtContext* ctx) { + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + } + return BlockFlow::Continue; +} diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 745374c..5ac46ca 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -10,185 +10,321 @@ namespace { -std::string GetLValueName(SysYParser::LValueContext& lvalue) { - if (!lvalue.ID()) { - throw std::runtime_error(FormatError("sema", "非法左值")); - } - return lvalue.ID()->getText(); -} - class SemaVisitor final : public SysYBaseVisitor { public: + explicit SemaVisitor() { + // 预填标准库函数 + AddBuiltin("getint", Symbol::Kind::Function); + AddBuiltin("getch", Symbol::Kind::Function); + AddBuiltin("getfloat", Symbol::Kind::Function); + AddBuiltin("getarray", Symbol::Kind::Function); + AddBuiltin("getfarray", Symbol::Kind::Function); + AddBuiltin("putint", Symbol::Kind::Function); + AddBuiltin("putch", Symbol::Kind::Function); + AddBuiltin("putfloat", Symbol::Kind::Function); + AddBuiltin("putarray", Symbol::Kind::Function); + AddBuiltin("putfarray", Symbol::Kind::Function); + AddBuiltin("starttime", Symbol::Kind::Function); + AddBuiltin("stoptime", Symbol::Kind::Function); + AddBuiltin("putf", Symbol::Kind::Function); + } + + private: + void AddBuiltin(const std::string& name, Symbol::Kind kind) { + Symbol sym; + sym.kind = kind; + sym.def_ctx = nullptr; // 内建函数没有语法树节点 + table_.Add(name, sym); + } + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少编译单元")); + if (!ctx) return {}; + for (auto* child : ctx->children) { + if (auto* decl = dynamic_cast(child)) { + decl->accept(this); + } else if (auto* funcDef = dynamic_cast(child)) { + funcDef->accept(this); + } } - auto* func = ctx->funcDef(); - if (!func || !func->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + return {}; + } + + std::any visitDecl(SysYParser::DeclContext* ctx) override { + if (ctx->constDecl()) return ctx->constDecl()->accept(this); + if (ctx->varDecl()) return ctx->varDecl()->accept(this); + return {}; + } + + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override { + for (auto* def : ctx->constDef()) { + const std::string name = def->ID()->getText(); + if (table_.IsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义常量: " + name)); + } + Symbol sym; + sym.kind = Symbol::Kind::Constant; + sym.def_ctx = def; + sym.is_const = true; + sym.is_array = !def->exp().empty(); + table_.Add(name, sym); + + for (auto* exp : def->exp()) exp->accept(this); + def->initValue()->accept(this); } - if (!func->ID() || func->ID()->getText() != "main") { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); - } - func->accept(this); - if (!seen_return_) { - throw std::runtime_error( - FormatError("sema", "main 函数必须包含 return 语句")); + return {}; + } + + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override { + for (auto* def : ctx->varDef()) { + const std::string name = def->ID()->getText(); + if (table_.IsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); + } + Symbol sym; + sym.kind = Symbol::Kind::Variable; + sym.def_ctx = def; + sym.is_const = false; + sym.is_array = !def->exp().empty(); + table_.Add(name, sym); + + for (auto* exp : def->exp()) exp->accept(this); + if (def->initValue()) def->initValue()->accept(this); } return {}; } std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override { - if (!ctx || !ctx->blockStmt()) { - throw std::runtime_error(FormatError("sema", "缺少 main 函数定义")); + const std::string name = ctx->ID()->getText(); + if (table_.IsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "重复定义函数: " + name)); } - if (!ctx->funcType() || !ctx->funcType()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持 int main")); + Symbol sym; + sym.kind = Symbol::Kind::Function; + sym.def_ctx = ctx; + table_.Add(name, sym); + + table_.PushScope(); + if (ctx->funcFParams()) { + ctx->funcFParams()->accept(this); } - const auto& items = ctx->blockStmt()->blockItem(); - if (items.empty()) { - throw std::runtime_error( - FormatError("sema", "main 函数不能为空,且必须以 return 结束")); + if (ctx->blockStmt()) { + // Visit block items without pushing another scope to keep params in same scope + for (auto* item : ctx->blockStmt()->blockItem()) { + item->accept(this); + } } - ctx->blockStmt()->accept(this); + table_.PopScope(); + return {}; + } + + std::any visitFuncFParam(SysYParser::FuncFParamContext* ctx) override { + const std::string name = ctx->ID()->getText(); + if (table_.IsInCurrentScope(name)) { + throw std::runtime_error(FormatError("sema", "函数参数名冲突: " + name)); + } + Symbol sym; + sym.kind = Symbol::Kind::Parameter; + sym.def_ctx = ctx; + sym.is_array = !ctx->LBRACK().empty(); + table_.Add(name, sym); + + for (auto* exp : ctx->exp()) exp->accept(this); return {}; } std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "缺少语句块")); - } - const auto& items = ctx->blockItem(); - for (size_t i = 0; i < items.size(); ++i) { - auto* item = items[i]; - if (!item) { - continue; - } - if (seen_return_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); - } - current_item_index_ = i; - total_items_ = items.size(); + table_.PushScope(); + for (auto* item : ctx->blockItem()) { item->accept(this); } + table_.PopScope(); return {}; } - std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - if (ctx->decl()) { - ctx->decl()->accept(this); - return {}; - } - if (ctx->stmt()) { - ctx->stmt()->accept(this); - return {}; - } - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - - std::any visitDecl(SysYParser::DeclContext* ctx) override { - if (!ctx) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - if (!ctx->btype() || !ctx->btype()->INT()) { - throw std::runtime_error(FormatError("sema", "当前仅支持局部 int 变量声明")); - } - auto* var_def = ctx->varDef(); - if (!var_def || !var_def->lValue()) { - throw std::runtime_error(FormatError("sema", "非法变量声明")); - } - const std::string name = GetLValueName(*var_def->lValue()); - if (table_.Contains(name)) { - throw std::runtime_error(FormatError("sema", "重复定义变量: " + name)); - } - if (auto* init = var_def->initValue()) { - if (!init->exp()) { - throw std::runtime_error(FormatError("sema", "当前不支持聚合初始化")); - } - init->exp()->accept(this); - } - table_.Add(name, var_def); - return {}; - } - - std::any visitStmt(SysYParser::StmtContext* ctx) override { - if (!ctx || !ctx->returnStmt()) { - throw std::runtime_error(FormatError("sema", "暂不支持的语句或声明")); - } - ctx->returnStmt()->accept(this); - return {}; - } - - std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "return 缺少表达式")); + std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override { + ctx->lValue()->accept(this); + const std::string name = ctx->lValue()->ID()->getText(); + Symbol* sym = table_.Lookup(name); + if (sym && sym->is_const) { + throw std::runtime_error(FormatError("sema", "试图给常量赋值: " + name)); } ctx->exp()->accept(this); - seen_return_ = true; - if (current_item_index_ + 1 != total_items_) { - throw std::runtime_error( - FormatError("sema", "return 必须是 main 函数中的最后一条语句")); + return {}; + } + + std::any visitLValue(SysYParser::LValueContext* ctx) override { + const std::string name = ctx->ID()->getText(); + Symbol* sym = table_.Lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "使用了未定义的标识符: " + name)); + } + if (sym->kind == Symbol::Kind::Function) { + throw std::runtime_error(FormatError("sema", "函数名不能作为左值: " + name)); + } + sema_.BindLValue(ctx, sym->def_ctx); + for (auto* exp : ctx->exp()) exp->accept(this); + return {}; + } + + std::any visitFuncCallExp(SysYParser::FuncCallExpContext* ctx) override { + const std::string name = ctx->ID()->getText(); + Symbol* sym = table_.Lookup(name); + if (!sym) { + throw std::runtime_error(FormatError("sema", "调用未定义的函数: " + name)); + } + if (sym->kind != Symbol::Kind::Function) { + throw std::runtime_error(FormatError("sema", "标识符不是函数: " + name)); + } + sema_.BindFuncCall(ctx, dynamic_cast(sym->def_ctx)); + if (ctx->funcRParams()) { + ctx->funcRParams()->accept(this); } return {}; } + // Visit expressions to ensure all sub-expressions are checked (e.g. for variable uses) std::any visitParenExp(SysYParser::ParenExpContext* ctx) override { - if (!ctx || !ctx->exp()) { - throw std::runtime_error(FormatError("sema", "非法括号表达式")); - } - ctx->exp()->accept(this); - return {}; + return ctx->exp()->accept(this); } - std::any visitVarExp(SysYParser::VarExpContext* ctx) override { - if (!ctx || !ctx->var()) { - throw std::runtime_error(FormatError("sema", "非法变量表达式")); - } - ctx->var()->accept(this); - return {}; + std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override { + return ctx->lValue()->accept(this); } std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override { - if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) { - throw std::runtime_error(FormatError("sema", "当前仅支持整数字面量")); - } + return ctx->number()->accept(this); + } + + std::any visitNumber(SysYParser::NumberContext* ctx) override { return {}; } - std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override { - if (!ctx || !ctx->exp(0) || !ctx->exp(1)) { - throw std::runtime_error(FormatError("sema", "暂不支持的表达式形式")); - } + std::any visitNotExp(SysYParser::NotExpContext* ctx) override { + return ctx->exp()->accept(this); + } + + std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override { + return ctx->exp()->accept(this); + } + + std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override { + return ctx->exp()->accept(this); + } + + std::any visitMulExp(SysYParser::MulExpContext* ctx) override { ctx->exp(0)->accept(this); ctx->exp(1)->accept(this); return {}; } - std::any visitVar(SysYParser::VarContext* ctx) override { - if (!ctx || !ctx->ID()) { - throw std::runtime_error(FormatError("sema", "非法变量引用")); - } - const std::string name = ctx->ID()->getText(); - auto* decl = table_.Lookup(name); - if (!decl) { - throw std::runtime_error(FormatError("sema", "使用了未定义的变量: " + name)); - } - sema_.BindVarUse(ctx, decl); + std::any visitDivExp(SysYParser::DivExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); return {}; } - SemanticContext TakeSemanticContext() { return std::move(sema_); } + std::any visitModExp(SysYParser::ModExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } - private: + std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitSubExp(SysYParser::SubExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitLtExp(SysYParser::LtExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitLeExp(SysYParser::LeExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitGtExp(SysYParser::GtExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitGeExp(SysYParser::GeExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitEqExp(SysYParser::EqExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitNeExp(SysYParser::NeExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitAndExp(SysYParser::AndExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitOrExp(SysYParser::OrExpContext* ctx) override { + ctx->exp(0)->accept(this); + ctx->exp(1)->accept(this); + return {}; + } + + std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override { + if (ctx->exp()) ctx->exp()->accept(this); + return {}; + } + + std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override { + ctx->exp()->accept(this); + ctx->stmt(0)->accept(this); + if (ctx->stmt(1)) ctx->stmt(1)->accept(this); + return {}; + } + + std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override { + ctx->exp()->accept(this); + ctx->stmt()->accept(this); + return {}; + } + + std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override { + return {}; + } + + std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override { + return {}; + } + + std::any visitExpStmt(SysYParser::ExpStmtContext* ctx) override { + if (ctx->exp()) ctx->exp()->accept(this); + return {}; + } +public: + SemanticContext TakeSemanticContext() { return std::move(sema_); } + +private: SymbolTable table_; SemanticContext sema_; - bool seen_return_ = false; - size_t current_item_index_ = 0; - size_t total_items_ = 0; }; } // namespace diff --git a/src/sem/SymbolTable.cpp b/src/sem/SymbolTable.cpp index ffeea89..35ce231 100644 --- a/src/sem/SymbolTable.cpp +++ b/src/sem/SymbolTable.cpp @@ -1,17 +1,40 @@ -// 维护局部变量声明的注册与查找。 - #include "sem/SymbolTable.h" -void SymbolTable::Add(const std::string& name, - SysYParser::VarDefContext* decl) { - table_[name] = decl; +SymbolTable::SymbolTable() { + // Push global scope + PushScope(); } -bool SymbolTable::Contains(const std::string& name) const { - return table_.find(name) != table_.end(); +void SymbolTable::PushScope() { + scopes_.emplace_back(); } -SysYParser::VarDefContext* SymbolTable::Lookup(const std::string& name) const { - auto it = table_.find(name); - return it == table_.end() ? nullptr : it->second; +void SymbolTable::PopScope() { + if (scopes_.size() > 1) { + scopes_.pop_back(); + } +} + +bool SymbolTable::Add(const std::string& name, const Symbol& symbol) { + auto& current_scope = scopes_.back(); + if (current_scope.find(name) != current_scope.end()) { + return false; + } + current_scope[name] = symbol; + return true; +} + +Symbol* SymbolTable::Lookup(const std::string& name) { + for (auto it = scopes_.rbegin(); it != scopes_.rend(); ++it) { + auto search = it->find(name); + if (search != it->end()) { + return &search->second; + } + } + return nullptr; +} + +bool SymbolTable::IsInCurrentScope(const std::string& name) const { + const auto& current_scope = scopes_.back(); + return current_scope.find(name) != current_scope.end(); }