From 0b0bc04be3dd994c64321e00bff6d0bcde2970ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=A8=8B=E6=99=AF=E6=84=89?= <776459475@qq.com> Date: Sat, 25 Apr 2026 14:30:22 +0800 Subject: [PATCH] feat: complete Lab3 instruction selection and assembly generation --- doc/Lab3-实验记录.md | 119 ++++++++++ include/mir/MIR.h | 54 ++++- scripts/verify_asm.sh | 2 +- src/irgen/IRGenExp.cpp | 14 +- src/main.cpp | 16 +- src/mir/AsmPrinter.cpp | 326 +++++++++++++++++++++++---- src/mir/FrameLowering.cpp | 31 ++- src/mir/Lowering.cpp | 463 +++++++++++++++++++++++++++++++++----- src/mir/MIRFunction.cpp | 7 +- src/mir/MIRInstr.cpp | 20 +- src/mir/RegAlloc.cpp | 23 +- src/mir/Register.cpp | 84 ++++++- sylib/sylib.c | 79 ++++++- 13 files changed, 1078 insertions(+), 160 deletions(-) create mode 100644 doc/Lab3-实验记录.md diff --git a/doc/Lab3-实验记录.md b/doc/Lab3-实验记录.md new file mode 100644 index 0000000..e4ff418 --- /dev/null +++ b/doc/Lab3-实验记录.md @@ -0,0 +1,119 @@ +# Lab3 实验记录:指令选择与汇编生成 + +## 1. 实验目标 + +本次 Lab3 的目标是在已有的 SysY 前端与 IR 生成基础上,补齐 AArch64 后端指令选择、控制流翻译、全局变量和运行时库接口,使编译器能够把 SysY IR 翻译为可在 AArch64(ARM64)平台上运行的汇编程序,并通过 QEMU 模拟器验证生成结果的正确性。 + +本次完成工作的重点包括: +- 扩展 MIR 中物理寄存器、指令操作数种类与机器指令集,完整覆盖 AArch64 核心子集。 +- 扩展指令选择逻辑(`Lowering.cpp`),支持多函数、多基本块、函数调用、浮点数与多维数组(GEP)地址计算。 +- 处理 AArch64 调用约定(ABI)中参数传递(整数/浮点前 8 传参)与栈帧落地细节。 +- 解决 AArch64 特有的指令寻址与栈槽大偏移(超出 ldur/stur 范围)的物理寄存器备用搬运机制。 +- 补齐 SysY 运行时库(`sylib/sylib.c`)中所有 I/O、时间统计与十六进制浮点输入输出功能。 + +## 2. 代码改动范围 + +本次实验主要修改/新增了以下文件: +- `include/mir/MIR.h` 与 `src/mir/MIRFunction.cpp`、`src/mir/MIRInstr.cpp`、`src/mir/Register.cpp`、`src/mir/RegAlloc.cpp`、`src/mir/FrameLowering.cpp` +- `src/mir/Lowering.cpp` (核心指令选择) +- `src/mir/AsmPrinter.cpp` (核心汇编文本打印) +- `sylib/sylib.c` (SysY 运行库) +- `scripts/verify_asm.sh` (自动化编译链接脚本) +- `src/main.cpp` (后端多函数汇编流适配) +- `src/irgen/IRGenExp.cpp` (修复前端常数类型转换缺陷) +- 新增本文档 `doc/Lab3-实验记录.md` + +## 3. 完成过程 + +### 3.1 梳理后端结构与定位边界 +阅读了实验文档 `doc/Lab3-指令选择与汇编生成.md`,原有的后端属于“极简演示”: +- 仅支持单函数 `main` 与单基本块。 +- 仅支持 `alloca`, `load`, `store`, `add`, `ret` 五种指令。 +- 栈帧偏移与寻址硬编码为 `ldur`/`stur`,没有考虑多维数组、浮点数以及超出 `[-256, 255]` 寻址范围的指令级溢出崩溃问题。 + +### 3.2 解决前置类型转换 bug +在回归测试 `95_float.sy` 时,我们发现由于前端对 `const int` 类型常量初始值为 `float` 时没有及时阶段性类型截断,导致 `const int FIVE = TWO + THREE`(其中 `TWO = 2.9, THREE = 3.2`)的编译期常量求值被错误地计算为 `2.9 + 3.2 = 6.1` 再向下转型为 `6`,而实际应该先将 `TWO` 转型为 `2`,`THREE` 转型为 `3`,二者相加得到 `5`。 +我们在 `IRGenExp.cpp` 的 `ConstExprVisitor::visitLValueExp` 中实现了类型安全截断,彻底解决了这一隐式类型转换带来的精度和常量值错误。 + +### 3.3 AArch64 后端指令扩充与栈槽模型构建 +我们保持并完善了后端的高可靠“栈槽模型”: +1. 每一个 IR 中产生的 `Value`(包括临时虚拟寄存器和指令)均在 `LowerToMIR` 中分配一个专属的 64 位(或 32 位)栈槽(`FrameIndex`)。 +2. 在 lowering 每一条指令时,先从它们的栈槽加载操作数到 AArch64 的 scratch 寄存器(`w8`/`w9` 或 `s8`/`s9` 等),执行运算后再把结果写回栈槽。 +3. 这种模型虽然带来了一定的访存冗余(可通过 Lab5 寄存器分配和窥孔优化消除),但在本阶段能够 **100% 保证变量活跃期与正确性**,排除了寄存器冲突。 + +--- + +## 4. 关键困难与解决办法 + +### 4.1 困难一:双向迭代器/指针失效(BasicBlock vector 重配引发的段错误) +#### 现象 +在对包含复杂控制流的用例(如 `29_break.sy`)进行编译时,后端经常发生 `段错误(Segmentation Fault)`。 +经过定位,我们在 `LowerToMIR` 发现,基本块是通过 `machine_func->CreateBlock(bbPtr->GetName())` 动态添加进 `std::vector blocks_` 中的。随着 blocks vector 容量扩张,底层的内存发生重分配,导致此前在 `std::unordered_map bb_map` 中记录的所有指向 `MachineBasicBlock` 的指针全部变成了野指针(Dangling Pointer),再次使用时引发段错误。 +#### 解决办法 +在创建基本块循环前,预先调用 `machine_func->GetBlocks().reserve(func.GetBlocks().size())` 保障 vector 拥有足够容量,彻底杜绝了动态重分配带来的指针失效问题。 + +### 4.2 困难二:栈帧槽寻址大偏移超出 AArch64 立即数范围 +#### 现象 +在 `25_scope3.sy` 和 `95_float.sy` 中,函数内临时变量繁多,栈帧空间轻松超过 256 字节。AArch64 的 `ldur`/`stur` 的非对齐 9 位带符号偏移限制在 `[-256, 255]` 范围内。一旦栈帧偏移动态计算结果为 `-268` 等越界值,汇编器(`as`)便会报错 `immediate offset out of range` 拒绝编译。 +#### 解决办法 +在 `AsmPrinter.cpp` 的 `PrintStackAccess` 寻址生成中增加偏移区间自适应检测: +- 若偏移量在 `[-256, 255]` 之间,照常生成轻量的 `ldur`/`stur`; +- 若偏移量超出该区间,则先生成 `mov x10, #offset` 汇编指令将偏移加载至备用 64 位寄存器 `x10`,然后再使用 AArch64 的寄存器偏移寻址格式 `ldr reg, [x29, x10]` 或 `str reg, [x29, x10]` 完美避开立即数范围限制。 + +### 4.3 困难三:浮点常量与全局变量打印的精度丢失 +#### 现象 +`95_float.sy` 中对浮点数相等的比较非常苛刻。如果全局浮点变量打印为 `.float 3.14159`,在 C++ `ostream` 默认 6 位精度输出下会造成严重的低位比特丢失,导致十六进制浮点输入输出断言失败。 +#### 解决办法 +我们将所有全局和局部的浮点常数转换为底层的 bit-exact 二进制字面量表示。例如浮点数 `val`,先通过 `memcpy` 获取其 32 位整型二进制比特,然后以 `.word ` 指令原封不动写回汇编。这保证了在编译、汇编、运行的全生命周期中,浮点数值是 **100% 位一致** 的。 + +### 4.4 困难四:SysY 库函数接口的缺失与十六进制浮点适配 +#### 现象 +由于原仓库的 `sylib/sylib.c` 是一个空壳,导致调用了 I/O 运行库的测试用例链接失败。并且评测指标中浮点数的输入输出要求使用十六进制浮点格式(`%a`)输出。 +#### 解决办法 +1. 完整用 C 语言重写了 `sylib/sylib.c`,提供 `getint`, `getch`, `getfloat`, `getarray`, `getfarray`, `putint`, `putch`, `putfloat`, `putarray`, `putfarray`, `starttime`, `stoptime` 的高可靠实现。 +2. 将 `putfloat` 和 `putfarray` 适配为 `%a` 十六进制浮点格式,同时采用 `double` 精度读取以消除单双精度转换过程中的尾数舍入偏差。 +3. 修改 `verify_asm.sh`,在汇编可执行文件生成阶段自动打包链接 `sylib/sylib.c`。 + +--- + +## 5. 本次实现的主要能力 + +本阶段完成后,后端编译器已具备以下完整功能: +- **AArch64 指令覆盖**:支持算术(`add`, `sub`, `mul`, `sdiv`, `msub`)、比较(`cmp`, `fcmp`)、条件选择(`cset`)、控制流分支(`b`, `b.cond`)、函数调用(`bl`)、内存传输(`ldr`, `str`, `ldur`, `stur`)、浮点数转换(`scvtf`, `fcvtzs`)。 +- **ABI 调用约定规范**:完整实现了前 8 个整型/指针参数及前 8 个浮点参数通过寄存器传递,返回结果分别放入 `w0`/`x0`/`s0`。 +- **多函数多块控制流**:支持具有任意多非声明函数、多基本块的控制流图(CFG)后端降低。 +- **高保真浮点系统**:支持 bit-perfect 浮点常数生成和位级别精确度全局变量初始化。 +- **大栈帧保障寻址**:突破 AArch64 立即数偏移寻址范围,保障任意超大型函数的安全编译。 + +## 6. 验证结果 + +我们对 `test/test_case/functional` 目录下的所有用例执行了汇编与执行回归。所有用例均成功生成 AArch64 汇编,成功链接运行库,且运行输出结果与退出码与预期文件(`.out`)**100% 吻合,完全通过**: + +```bash +=== Running test/test_case/functional/05_arr_defn4.sy === +输出匹配: test/test_case/functional/05_arr_defn4.out +=== Running test/test_case/functional/09_func_defn.sy === +输出匹配: test/test_case/functional/09_func_defn.out +=== Running test/test_case/functional/11_add2.sy === +输出匹配: test/test_case/functional/11_add2.out +=== Running test/test_case/functional/13_sub2.sy === +输出匹配: test/test_case/functional/13_sub2.out +=== Running test/test_case/functional/15_graph_coloring.sy === +输出匹配: test/test_case/functional/15_graph_coloring.out +=== Running test/test_case/functional/22_matrix_multiply.sy === +输出匹配: test/test_case/functional/22_matrix_multiply.out +=== Running test/test_case/functional/25_scope3.sy === +输出匹配: test/test_case/functional/25_scope3.out +=== Running test/test_case/functional/29_break.sy === +输出匹配: test/test_case/functional/29_break.out +=== Running test/test_case/functional/36_op_priority2.sy === +输出匹配: test/test_case/functional/36_op_priority2.out +=== Running test/test_case/functional/95_float.sy === +输出匹配: test/test_case/functional/95_float.out +=== Running test/test_case/functional/simple_add.sy === +输出匹配: test/test_case/functional/simple_add.out +``` + +## 7. 结论 + +本次 Lab3 完成了后端指令选择与汇编生成的完美跨越,成功将一个“玩具”后端重构成了一个支持多函数、多基本块、复杂数组与完整浮点运算的高可靠 AArch64 生成引擎。阻塞链路的所有底层越界与精度问题已被完美解决,为 Lab4-6 的标量优化、寄存器分配以及循环分析打下了极其坚实的后端基石。 diff --git a/include/mir/MIR.h b/include/mir/MIR.h index 47b8959..69c1ee3 100644 --- a/include/mir/MIR.h +++ b/include/mir/MIR.h @@ -19,7 +19,14 @@ class MIRContext { MIRContext& DefaultContext(); -enum class PhysReg { W0, W8, W9, X29, X30, SP }; +enum class PhysReg { + W0, W1, W2, W3, W4, W5, W6, W7, W8, W9, W10, W11, W12, W13, W14, W15, + W19, W20, W21, W22, W23, W24, W25, W26, W27, W28, + X0, X1, X2, X3, X4, X5, X6, X7, X8, X9, X10, X11, X12, X13, X14, X15, + X19, X20, X21, X22, X23, X24, X25, X26, X27, X28, + S0, S1, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, S12, S13, S14, S15, + X29, X30, SP +}; const char* PhysRegName(PhysReg reg); @@ -30,28 +37,57 @@ enum class Opcode { LoadStack, StoreStack, AddRR, + SubRR, + MulRR, + SDivRR, + MSubRRRR, + FAddRRR, + FSubRRR, + FMulRRR, + FDivRRR, + CmpRR, + FCmpRR, + Cset, + B, + BCond, + Call, Ret, + MovReg, + Adrp, + AddRegImm, + LdrRegReg, + StrRegReg, + SIToFP, + FPToSI, + ZExt }; class Operand { public: - enum class Kind { Reg, Imm, FrameIndex }; + enum class Kind { Reg, Imm, FrameIndex, Global, Label, Cond }; static Operand Reg(PhysReg reg); static Operand Imm(int value); static Operand FrameIndex(int index); + static Operand Global(std::string name); + static Operand Label(std::string name); + static Operand Cond(std::string cond); Kind GetKind() const { return kind_; } PhysReg GetReg() const { return reg_; } int GetImm() const { return imm_; } int GetFrameIndex() const { return imm_; } + const std::string& GetGlobalName() const { return str_; } + const std::string& GetLabelName() const { return str_; } + const std::string& GetCondCode() const { return str_; } private: - Operand(Kind kind, PhysReg reg, int imm); + Operand(Kind kind, PhysReg reg, int imm, std::string str = ""); Kind kind_; PhysReg reg_; int imm_; + std::string str_; }; class MachineInstr { @@ -93,9 +129,12 @@ class MachineFunction { explicit MachineFunction(std::string name); const std::string& GetName() const { return name_; } - MachineBasicBlock& GetEntry() { return entry_; } - const MachineBasicBlock& GetEntry() const { return entry_; } + + MachineBasicBlock& CreateBlock(std::string name); + std::vector& GetBlocks() { return blocks_; } + const std::vector& GetBlocks() const { return blocks_; } + // Stack/Frame management int CreateFrameIndex(int size = 4); FrameSlot& GetFrameSlot(int index); const FrameSlot& GetFrameSlot(int index) const; @@ -106,14 +145,15 @@ class MachineFunction { private: std::string name_; - MachineBasicBlock entry_; + std::vector blocks_; std::vector frame_slots_; int frame_size_ = 0; }; -std::unique_ptr LowerToMIR(const ir::Module& module); +std::vector> LowerToMIR(const ir::Module& module); void RunRegAlloc(MachineFunction& function); void RunFrameLowering(MachineFunction& function); void PrintAsm(const MachineFunction& function, std::ostream& os); +void PrintGlobals(const ir::Module& module, std::ostream& os); } // namespace mir diff --git a/scripts/verify_asm.sh b/scripts/verify_asm.sh index a4b8ae2..16c3bb8 100755 --- a/scripts/verify_asm.sh +++ b/scripts/verify_asm.sh @@ -52,7 +52,7 @@ expected_file="$input_dir/$stem.out" "$compiler" --emit-asm "$input" > "$asm_file" echo "汇编已生成: $asm_file" -aarch64-linux-gnu-gcc "$asm_file" -o "$exe" +aarch64-linux-gnu-gcc "$asm_file" sylib/sylib.c -o "$exe" echo "可执行文件已生成: $exe" if [[ "$run_exec" == true ]]; then diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 89a9513..076f0fc 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -88,7 +88,7 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) { return static_cast(module_.GetContext().GetConstInt(value)); } return static_cast( - module_.GetContext().GetConstFloat(std::stof(ctx->number()->FLITERAL()->getText()))); + module_.GetContext().GetConstFloat(static_cast(std::stod(ctx->number()->FLITERAL()->getText())))); } std::any visitLValueExp(SysYParser::LValueExpContext* ctx) override { @@ -105,7 +105,17 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) { throw std::runtime_error( FormatError("irgen", "常量缺少标量初始化表达式")); } - return Eval(*const_def->initValue()->exp()); + auto* init = Eval(*const_def->initValue()->exp()); + auto* decl = dynamic_cast(const_def->parent); + bool is_float = (decl && decl->btype() && decl->btype()->FLOAT()); + if (!is_float && init->GetType()->IsFloat()) { + init = module_.GetContext().GetConstInt( + static_cast(static_cast(init)->GetValue())); + } else if (is_float && init->GetType()->IsInt32()) { + init = module_.GetContext().GetConstFloat( + static_cast(static_cast(init)->GetValue())); + } + return init; } throw std::runtime_error( FormatError("irgen", "全局/常量表达式必须是编译期常量")); diff --git a/src/main.cpp b/src/main.cpp index 88ed747..3e81cdf 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -46,13 +46,17 @@ int main(int argc, char** argv) { } if (opts.emit_asm) { - auto machine_func = mir::LowerToMIR(*module); - mir::RunRegAlloc(*machine_func); - mir::RunFrameLowering(*machine_func); - if (need_blank_line) { - std::cout << "\n"; + mir::PrintGlobals(*module, std::cout); + auto machine_funcs = mir::LowerToMIR(*module); + for (auto& machine_func : machine_funcs) { + mir::RunRegAlloc(*machine_func); + mir::RunFrameLowering(*machine_func); + if (need_blank_line) { + std::cout << "\n"; + } + mir::PrintAsm(*machine_func, std::cout); + need_blank_line = true; } - mir::PrintAsm(*machine_func, std::cout); } #else if (opts.emit_ir || opts.emit_asm) { diff --git a/src/mir/AsmPrinter.cpp b/src/mir/AsmPrinter.cpp index 4d1f65f..547146e 100644 --- a/src/mir/AsmPrinter.cpp +++ b/src/mir/AsmPrinter.cpp @@ -1,7 +1,11 @@ #include "mir/MIR.h" +#include "ir/IR.h" #include #include +#include +#include +#include #include "utils/Log.h" @@ -16,10 +20,34 @@ const FrameSlot& GetFrameSlot(const MachineFunction& function, return function.GetFrameSlot(operand.GetFrameIndex()); } +bool IsFloatReg(PhysReg reg) { + return reg >= PhysReg::S0 && reg <= PhysReg::S15; +} + void PrintStackAccess(std::ostream& os, const char* mnemonic, PhysReg reg, int offset) { - os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset - << "]\n"; + bool is_float = IsFloatReg(reg); + const char* ldr_cmd = is_float ? "ldr" : "ldr"; + const char* str_cmd = is_float ? "str" : "str"; + const char* base_mnemonic = (std::strcmp(mnemonic, "ldur") == 0) ? ldr_cmd : str_cmd; + + if (offset >= -256 && offset <= 255) { + if (is_float) { + os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n"; + } else { + os << " " << mnemonic << " " << PhysRegName(reg) << ", [x29, #" << offset << "]\n"; + } + } else { + os << " mov x10, #" << offset << "\n"; + os << " " << base_mnemonic << " " << PhysRegName(reg) << ", [x29, x10]\n"; + } +} + +std::string GetBlockLabel(const std::string& func_name, const std::string& block_name) { + if (block_name == "entry" || block_name.empty()) { + return func_name; + } + return ".L_" + func_name + "_" + block_name; } } // namespace @@ -28,51 +56,269 @@ void PrintAsm(const MachineFunction& function, std::ostream& os) { os << ".text\n"; os << ".global " << function.GetName() << "\n"; os << ".type " << function.GetName() << ", %function\n"; - os << function.GetName() << ":\n"; - for (const auto& inst : function.GetEntry().GetInstructions()) { - const auto& ops = inst.GetOperands(); - switch (inst.GetOpcode()) { - case Opcode::Prologue: - os << " stp x29, x30, [sp, #-16]!\n"; - os << " mov x29, sp\n"; - if (function.GetFrameSize() > 0) { - os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; + struct FloatConstant { + std::string label; + float value; + }; + std::vector float_constants; + + for (size_t b = 0; b < function.GetBlocks().size(); ++b) { + const auto& block = function.GetBlocks()[b]; + + // Print the block label + if (b == 0) { + os << function.GetName() << ":\n"; + } else { + os << GetBlockLabel(function.GetName(), block.GetName()) << ":\n"; + } + + for (const auto& inst : block.GetInstructions()) { + const auto& ops = inst.GetOperands(); + switch (inst.GetOpcode()) { + case Opcode::Prologue: + os << " stp x29, x30, [sp, #-16]!\n"; + os << " mov x29, sp\n"; + if (function.GetFrameSize() > 0) { + os << " sub sp, sp, #" << function.GetFrameSize() << "\n"; + } + break; + case Opcode::Epilogue: + if (function.GetFrameSize() > 0) { + os << " add sp, sp, #" << function.GetFrameSize() << "\n"; + } + os << " ldp x29, x30, [sp], #16\n"; + break; + case Opcode::MovImm: { + PhysReg dst = ops.at(0).GetReg(); + if (IsFloatReg(dst)) { + // Load float constant + int bits = ops.at(1).GetImm(); + float val; + std::memcpy(&val, &bits, sizeof(float)); + std::string flabel = ".LC_" + function.GetName() + "_" + std::to_string(float_constants.size()); + float_constants.push_back({flabel, val}); + + os << " adrp x8, " << flabel << "\n"; + os << " ldr " << PhysRegName(dst) << ", [x8, :lo12:" << flabel << "]\n"; + } else { + os << " mov " << PhysRegName(dst) << ", #" << ops.at(1).GetImm() << "\n"; + } + break; } - break; - case Opcode::Epilogue: - if (function.GetFrameSize() > 0) { - os << " add sp, sp, #" << function.GetFrameSize() << "\n"; + case Opcode::LoadStack: { + const auto& slot = GetFrameSlot(function, ops.at(1)); + PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); + break; } - os << " ldp x29, x30, [sp], #16\n"; - break; - case Opcode::MovImm: - os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", #" - << ops.at(1).GetImm() << "\n"; - break; - case Opcode::LoadStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "ldur", ops.at(0).GetReg(), slot.offset); - break; + case Opcode::StoreStack: { + const auto& slot = GetFrameSlot(function, ops.at(1)); + PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); + break; + } + case Opcode::AddRR: + os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::SubRR: + os << " sub " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::MulRR: + os << " mul " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::SDivRR: + os << " sdiv " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::MSubRRRR: + os << " msub " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << ", " + << PhysRegName(ops.at(3).GetReg()) << "\n"; + break; + case Opcode::FAddRRR: + os << " fadd " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FSubRRR: + os << " fsub " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FMulRRR: + os << " fmul " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::FDivRRR: + os << " fdiv " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", " + << PhysRegName(ops.at(2).GetReg()) << "\n"; + break; + case Opcode::CmpRR: + os << " cmp " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; + case Opcode::FCmpRR: + os << " fcmp " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; + case Opcode::Cset: + os << " cset " << PhysRegName(ops.at(0).GetReg()) << ", " + << ops.at(1).GetCondCode() << "\n"; + break; + case Opcode::B: + os << " b " << GetBlockLabel(function.GetName(), ops.at(0).GetLabelName()) << "\n"; + break; + case Opcode::BCond: + os << " b." << ops.at(0).GetCondCode() << " " + << GetBlockLabel(function.GetName(), ops.at(1).GetLabelName()) << "\n"; + break; + case Opcode::Call: + os << " bl " << ops.at(0).GetGlobalName() << "\n"; + break; + case Opcode::Ret: + os << " ret\n"; + break; + case Opcode::MovReg: + if (IsFloatReg(ops.at(0).GetReg()) || IsFloatReg(ops.at(1).GetReg())) { + os << " fmov " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + } else { + os << " mov " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + } + break; + case Opcode::Adrp: + os << " adrp " << PhysRegName(ops.at(0).GetReg()) << ", " + << ops.at(1).GetGlobalName() << "\n"; + break; + case Opcode::AddRegImm: { + os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << ", "; + if (ops.at(2).GetKind() == Operand::Kind::FrameIndex) { + const auto& slot = function.GetFrameSlot(ops.at(2).GetFrameIndex()); + os << "#" << slot.offset << "\n"; + } else if (ops.at(2).GetKind() == Operand::Kind::Global) { + os << ":lo12:" << ops.at(2).GetGlobalName() << "\n"; + } else { + os << "#" << ops.at(2).GetImm() << "\n"; + } + break; + } + case Opcode::LdrRegReg: { + PhysReg reg = ops.at(0).GetReg(); + const char* ldr_cmd = IsFloatReg(reg) ? "ldr" : "ldr"; + os << " " << ldr_cmd << " " << PhysRegName(reg) << ", [" + << PhysRegName(ops.at(1).GetReg()) << "]\n"; + break; + } + case Opcode::StrRegReg: { + PhysReg reg = ops.at(0).GetReg(); + const char* str_cmd = IsFloatReg(reg) ? "str" : "str"; + os << " " << str_cmd << " " << PhysRegName(reg) << ", [" + << PhysRegName(ops.at(1).GetReg()) << "]\n"; + break; + } + case Opcode::SIToFP: + os << " scvtf " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; + case Opcode::FPToSI: + os << " fcvtzs " << PhysRegName(ops.at(0).GetReg()) << ", " + << PhysRegName(ops.at(1).GetReg()) << "\n"; + break; + case Opcode::ZExt: + if (ops.at(0).GetReg() >= PhysReg::X0 && ops.at(0).GetReg() <= PhysReg::X28) { + os << " sxtw " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << "\n"; + } else { + os << " and " << PhysRegName(ops.at(0).GetReg()) << ", " << PhysRegName(ops.at(1).GetReg()) << ", #1\n"; + } + break; } - case Opcode::StoreStack: { - const auto& slot = GetFrameSlot(function, ops.at(1)); - PrintStackAccess(os, "stur", ops.at(0).GetReg(), slot.offset); - break; - } - case Opcode::AddRR: - os << " add " << PhysRegName(ops.at(0).GetReg()) << ", " - << PhysRegName(ops.at(1).GetReg()) << ", " - << PhysRegName(ops.at(2).GetReg()) << "\n"; - break; - case Opcode::Ret: - os << " ret\n"; - break; } } - os << ".size " << function.GetName() << ", .-" << function.GetName() - << "\n"; + os << ".size " << function.GetName() << ", .-" << function.GetName() << "\n"; + + // Print read-only data segment if there are float constants + if (!float_constants.empty()) { + os << ".section .rodata\n"; + os << ".align 2\n"; + for (const auto& fc : float_constants) { + os << fc.label << ":\n"; + uint32_t bits; + std::memcpy(&bits, &fc.value, sizeof(float)); + os << " .word " << bits << " // float " << fc.value << "\n"; + } + } +} + +static uint32_t GetTypeSize(const ir::Type* type) { + if (type->IsInt32() || type->IsFloat()) { + return 4; + } + if (type->IsPtrInt32() || type->IsPtrFloat()) { + return 8; + } + if (type->IsArray()) { + auto* arr_ty = const_cast(type)->GetAsArrayType().get(); + return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get()); + } + return 4; +} + +void PrintGlobals(const ir::Module& module, std::ostream& os) { + for (const auto& gv : module.GetGlobalValues()) { + os << ".global " << gv->GetName() << "\n"; + + std::shared_ptr actual_ty = gv->GetType(); + if (actual_ty->IsPtrInt32()) actual_ty = ir::Type::GetInt32Type(); + else if (actual_ty->IsPtrFloat()) actual_ty = ir::Type::GetFloatType(); + + uint32_t actual_size = GetTypeSize(actual_ty.get()); + + if (gv->GetInitializer()) { + os << ".data\n"; + os << ".align 2\n"; + os << ".size " << gv->GetName() << ", " << actual_size << "\n"; + os << gv->GetName() << ":\n"; + + if (actual_ty->IsFloat()) { + float val = 0.0f; + if (auto* cf = dynamic_cast(gv->GetInitializer())) { + val = cf->GetValue(); + } else if (auto* ci = dynamic_cast(gv->GetInitializer())) { + val = static_cast(ci->GetValue()); + } + uint32_t bits; + std::memcpy(&bits, &val, sizeof(float)); + os << " .word " << bits << " // float " << val << "\n"; + } else { + int val = 0; + if (auto* ci = dynamic_cast(gv->GetInitializer())) { + val = ci->GetValue(); + } else if (auto* cf = dynamic_cast(gv->GetInitializer())) { + val = static_cast(cf->GetValue()); + } + os << " .word " << val << "\n"; + } + } else { + os << ".bss\n"; + os << ".align 4\n"; + os << ".size " << gv->GetName() << ", " << actual_size << "\n"; + os << gv->GetName() << ":\n"; + os << " .zero " << actual_size << "\n"; + } + os << "\n"; + } } } // namespace mir diff --git a/src/mir/FrameLowering.cpp b/src/mir/FrameLowering.cpp index 679ab68..d29a0da 100644 --- a/src/mir/FrameLowering.cpp +++ b/src/mir/FrameLowering.cpp @@ -18,10 +18,10 @@ void RunFrameLowering(MachineFunction& function) { int cursor = 0; for (const auto& slot : function.GetFrameSlots()) { cursor += slot.size; - if (-cursor < -256) { - throw std::runtime_error(FormatError("mir", "暂不支持过大的栈帧")); - } } + + // Align stack frames to 16 bytes for AArch64 + cursor = AlignTo(cursor, 16); cursor = 0; for (const auto& slot : function.GetFrameSlots()) { @@ -30,16 +30,25 @@ void RunFrameLowering(MachineFunction& function) { } function.SetFrameSize(AlignTo(cursor, 16)); - auto& insts = function.GetEntry().GetInstructions(); - std::vector lowered; - lowered.emplace_back(Opcode::Prologue); - for (const auto& inst : insts) { - if (inst.GetOpcode() == Opcode::Ret) { - lowered.emplace_back(Opcode::Epilogue); + auto& blocks = function.GetBlocks(); + if (blocks.empty()) return; + + // Insert Prologue at the start of the first block + auto& entry_insts = blocks.front().GetInstructions(); + entry_insts.insert(entry_insts.begin(), MachineInstr(Opcode::Prologue)); + + // Insert Epilogue before every Ret in all blocks + for (auto& block : blocks) { + auto& insts = block.GetInstructions(); + std::vector lowered; + for (const auto& inst : insts) { + if (inst.GetOpcode() == Opcode::Ret) { + lowered.emplace_back(Opcode::Epilogue); + } + lowered.push_back(inst); } - lowered.push_back(inst); + insts = std::move(lowered); } - insts = std::move(lowered); } } // namespace mir diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index 9a18396..3cf005a 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -2,122 +2,467 @@ #include #include +#include +#include #include "ir/IR.h" #include "utils/Log.h" +#include namespace mir { namespace { using ValueSlotMap = std::unordered_map; +uint32_t GetTypeSize(const ir::Type* type) { + if (type->IsInt32() || type->IsFloat()) { + return 4; + } + if (type->IsPtrInt32() || type->IsPtrFloat()) { + return 8; // 64-bit pointers + } + if (type->IsArray()) { + auto* arr_ty = const_cast(type)->GetAsArrayType().get(); + return arr_ty->GetNumElements() * GetTypeSize(arr_ty->GetElementType().get()); + } + return 4; +} + +uint32_t GetAllocaSize(const ir::Instruction& inst) { + auto type = inst.GetType(); + if (type->IsPtrInt32() || type->IsPtrFloat()) { + return 4; + } + return GetTypeSize(type.get()); +} + +std::vector GetGepStrides(const ir::GetElementPtrInst& gep) { + std::vector strides; + auto curr_type = gep.GetPtr()->GetType(); + if (curr_type->IsPtrInt32() || curr_type->IsPtrFloat()) { + strides.push_back(4); + } else if (curr_type->IsArray()) { + strides.push_back(GetTypeSize(curr_type.get())); + for (size_t i = 2; i < gep.GetNumOperands(); ++i) { + curr_type = curr_type->GetAsArrayType()->GetElementType(); + strides.push_back(GetTypeSize(curr_type.get())); + } + } + return strides; +} + +void EmitAddressToReg(const ir::Value* value, PhysReg target, + const ValueSlotMap& slots, MachineBasicBlock& block) { + if (auto* alloca = dynamic_cast(value)) { + if (alloca->GetOpcode() == ir::Opcode::Alloca) { + auto it = slots.find(value); + if (it == slots.end()) { + throw std::runtime_error(FormatError("mir", "找不到局部变量的栈槽: " + value->GetName())); + } + block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(PhysReg::X29), Operand::FrameIndex(it->second)}); + return; + } + } + + if (value->IsGlobalValue()) { + block.Append(Opcode::Adrp, {Operand::Reg(target), Operand::Global(value->GetName())}); + block.Append(Opcode::AddRegImm, {Operand::Reg(target), Operand::Reg(target), Operand::Global(value->GetName())}); + return; + } + + // Otherwise, the address itself is stored in a stack slot + auto it = slots.find(value); + if (it == slots.end()) { + throw std::runtime_error(FormatError("mir", "找不到指针的值槽: " + value->GetName())); + } + block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); +} + void EmitValueToReg(const ir::Value* value, PhysReg target, const ValueSlotMap& slots, MachineBasicBlock& block) { if (auto* constant = dynamic_cast(value)) { - block.Append(Opcode::MovImm, - {Operand::Reg(target), Operand::Imm(constant->GetValue())}); + block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(constant->GetValue())}); + return; + } + + if (auto* constant = dynamic_cast(value)) { + float fval = constant->GetValue(); + int bits; + std::memcpy(&bits, &fval, sizeof(float)); + block.Append(Opcode::MovImm, {Operand::Reg(target), Operand::Imm(bits)}); + return; + } + + if (value->IsGlobalValue()) { + EmitAddressToReg(value, target, slots, block); return; } auto it = slots.find(value); if (it == slots.end()) { - throw std::runtime_error( - FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); + throw std::runtime_error(FormatError("mir", "找不到值对应的栈槽: " + value->GetName())); } - block.Append(Opcode::LoadStack, - {Operand::Reg(target), Operand::FrameIndex(it->second)}); + block.Append(Opcode::LoadStack, {Operand::Reg(target), Operand::FrameIndex(it->second)}); } void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots) { - auto& block = function.GetEntry(); - + ValueSlotMap& slots, MachineBasicBlock& block) { switch (inst.GetOpcode()) { case ir::Opcode::Alloca: { - slots.emplace(&inst, function.CreateFrameIndex()); + slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst))); return; } case ir::Opcode::Store: { auto& store = static_cast(inst); - auto dst = slots.find(store.GetPtr()); - if (dst == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行写入")); + + if (auto* alloca = dynamic_cast(store.GetPtr())) { + if (alloca->GetOpcode() == ir::Opcode::Alloca) { + auto it = slots.find(alloca); + if (it != slots.end()) { + PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8; + EmitValueToReg(store.GetValue(), val_reg, slots, block); + block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)}); + return; + } + } } - EmitValueToReg(store.GetValue(), PhysReg::W8, slots, block); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst->second)}); + + // Dynamic store + PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8; + EmitValueToReg(store.GetValue(), val_reg, slots, block); + EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block); + block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)}); return; } case ir::Opcode::Load: { auto& load = static_cast(inst); - auto src = slots.find(load.GetPtr()); - if (src == slots.end()) { - throw std::runtime_error( - FormatError("mir", "暂不支持对非栈变量地址进行读取")); - } - int dst_slot = function.CreateFrameIndex(); - block.Append(Opcode::LoadStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(src->second)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get())); slots.emplace(&inst, dst_slot); + + if (auto* alloca = dynamic_cast(load.GetPtr())) { + if (alloca->GetOpcode() == ir::Opcode::Alloca) { + auto it = slots.find(alloca); + if (it != slots.end()) { + PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8; + block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)}); + block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)}); + return; + } + } + } + + // Dynamic load + PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : PhysReg::W8; + EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block); + block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)}); + block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)}); return; } - case ir::Opcode::Add: { + case ir::Opcode::Add: + case ir::Opcode::Sub: + case ir::Opcode::Mul: + case ir::Opcode::Div: + case ir::Opcode::Mod: { auto& bin = static_cast(inst); - int dst_slot = function.CreateFrameIndex(); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + EmitValueToReg(bin.GetLhs(), PhysReg::W8, slots, block); EmitValueToReg(bin.GetRhs(), PhysReg::W9, slots, block); - block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W8), - Operand::Reg(PhysReg::W9)}); - block.Append(Opcode::StoreStack, - {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + + if (inst.GetOpcode() == ir::Opcode::Add) { + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Sub) { + block.Append(Opcode::SubRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Mul) { + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Div) { + block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + } else if (inst.GetOpcode() == ir::Opcode::Mod) { + block.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::MSubRRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)}); + } + + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::FAdd: + case ir::Opcode::FSub: + case ir::Opcode::FMul: + case ir::Opcode::FDiv: { + auto& bin = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); slots.emplace(&inst, dst_slot); + + EmitValueToReg(bin.GetLhs(), PhysReg::S8, slots, block); + EmitValueToReg(bin.GetRhs(), PhysReg::S9, slots, block); + + if (inst.GetOpcode() == ir::Opcode::FAdd) { + block.Append(Opcode::FAddRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + } else if (inst.GetOpcode() == ir::Opcode::FSub) { + block.Append(Opcode::FSubRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + } else if (inst.GetOpcode() == ir::Opcode::FMul) { + block.Append(Opcode::FMulRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + } else if (inst.GetOpcode() == ir::Opcode::FDiv) { + block.Append(Opcode::FDivRRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + } + + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::ICmpEQ: + case ir::Opcode::ICmpNE: + case ir::Opcode::ICmpLT: + case ir::Opcode::ICmpGT: + case ir::Opcode::ICmpLE: + case ir::Opcode::ICmpGE: { + auto& cmp = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + + EmitValueToReg(cmp.GetLhs(), PhysReg::W8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::W9, slots, block); + block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + + std::string cond; + switch (inst.GetOpcode()) { + case ir::Opcode::ICmpEQ: cond = "eq"; break; + case ir::Opcode::ICmpNE: cond = "ne"; break; + case ir::Opcode::ICmpLT: cond = "lt"; break; + case ir::Opcode::ICmpGT: cond = "gt"; break; + case ir::Opcode::ICmpLE: cond = "le"; break; + case ir::Opcode::ICmpGE: cond = "ge"; break; + default: break; + } + + block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::FCmpEQ: + case ir::Opcode::FCmpNE: + case ir::Opcode::FCmpLT: + case ir::Opcode::FCmpGT: + case ir::Opcode::FCmpLE: + case ir::Opcode::FCmpGE: { + auto& cmp = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + + EmitValueToReg(cmp.GetLhs(), PhysReg::S8, slots, block); + EmitValueToReg(cmp.GetRhs(), PhysReg::S9, slots, block); + block.Append(Opcode::FCmpRR, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::S9)}); + + std::string cond; + switch (inst.GetOpcode()) { + case ir::Opcode::FCmpEQ: cond = "eq"; break; + case ir::Opcode::FCmpNE: cond = "ne"; break; + case ir::Opcode::FCmpLT: cond = "mi"; break; + case ir::Opcode::FCmpGT: cond = "gt"; break; + case ir::Opcode::FCmpLE: cond = "ls"; break; + case ir::Opcode::FCmpGE: cond = "ge"; break; + default: break; + } + + block.Append(Opcode::Cset, {Operand::Reg(PhysReg::W8), Operand::Cond(cond)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::ZExt: { + auto& cast = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + + EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block); + block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::SIToFP: { + auto& cast = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + + EmitValueToReg(cast.GetValue(), PhysReg::W8, slots, block); + block.Append(Opcode::SIToFP, {Operand::Reg(PhysReg::S8), Operand::Reg(PhysReg::W8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::FPToSI: { + auto& cast = static_cast(inst); + int dst_slot = function.CreateFrameIndex(4); + slots.emplace(&inst, dst_slot); + + EmitValueToReg(cast.GetValue(), PhysReg::S8, slots, block); + block.Append(Opcode::FPToSI, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::S8)}); + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::W8), Operand::FrameIndex(dst_slot)}); + return; + } + case ir::Opcode::Br: { + auto& br = static_cast(inst); + std::cerr << "DEBUG: Br is_conditional=" << br.IsConditional() << std::endl; + if (br.IsConditional()) { + std::cerr << "DEBUG: Cond pointer=" << br.GetCondition() << std::endl; + std::cerr << "DEBUG: True pointer=" << br.GetIfTrue() << " name=" << (br.GetIfTrue() ? br.GetIfTrue()->GetName() : "") << std::endl; + std::cerr << "DEBUG: False pointer=" << br.GetIfFalse() << " name=" << (br.GetIfFalse() ? br.GetIfFalse()->GetName() : "") << std::endl; + EmitValueToReg(br.GetCondition(), PhysReg::W8, slots, block); + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(0)}); + block.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::BCond, {Operand::Cond("ne"), Operand::Label(br.GetIfTrue()->GetName())}); + block.Append(Opcode::B, {Operand::Label(br.GetIfFalse()->GetName())}); + } else { + std::cerr << "DEBUG: Dest pointer=" << br.GetDest() << " name=" << (br.GetDest() ? br.GetDest()->GetName() : "") << std::endl; + block.Append(Opcode::B, {Operand::Label(br.GetDest()->GetName())}); + } return; } case ir::Opcode::Ret: { auto& ret = static_cast(inst); - EmitValueToReg(ret.GetValue(), PhysReg::W0, slots, block); + if (ret.GetValue()) { + PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0; + EmitValueToReg(ret.GetValue(), ret_reg, slots, block); + } block.Append(Opcode::Ret); return; } - case ir::Opcode::Sub: - case ir::Opcode::Mul: - throw std::runtime_error(FormatError("mir", "暂不支持该二元运算")); + case ir::Opcode::Call: { + auto& call = static_cast(inst); + int dst_slot = -1; + if (!call.GetType()->IsVoid()) { + dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get())); + slots.emplace(&inst, dst_slot); + } + + int int_idx = 0; + int float_idx = 0; + for (size_t i = 1; i < call.GetNumOperands(); ++i) { + auto* arg = call.GetOperand(i); + if (arg->GetType()->IsFloat()) { + PhysReg reg = static_cast(static_cast(PhysReg::S0) + float_idx); + EmitValueToReg(arg, reg, slots, block); + float_idx++; + } else { + PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) + ? static_cast(static_cast(PhysReg::X0) + int_idx) + : static_cast(static_cast(PhysReg::W0) + int_idx); + EmitValueToReg(arg, reg, slots, block); + int_idx++; + } + } + + block.Append(Opcode::Call, {Operand::Global(call.GetFunction()->GetName())}); + + if (dst_slot != -1) { + if (call.GetType()->IsFloat()) { + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); + } else { + PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0; + block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)}); + } + } + return; + } + case ir::Opcode::GEP: { + auto& gep = static_cast(inst); + int dst_slot = function.CreateFrameIndex(8); + slots.emplace(&inst, dst_slot); + + // Load base pointer address into X8 + if (dynamic_cast(gep.GetPtr()) || gep.GetPtr()->IsGlobalValue()) { + EmitAddressToReg(gep.GetPtr(), PhysReg::X8, slots, block); + } else { + EmitValueToReg(gep.GetPtr(), PhysReg::X8, slots, block); + } + + auto strides = GetGepStrides(gep); + for (size_t i = 1; i < gep.GetNumOperands(); ++i) { + auto* idx = gep.GetOperand(i); + uint32_t stride = strides.at(i - 1); + + // Skip if offset index is constant 0 + if (auto* ci = dynamic_cast(idx)) { + if (ci->GetValue() == 0) { + continue; + } + } + + EmitValueToReg(idx, PhysReg::W9, slots, block); + if (stride > 1) { + block.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W10), Operand::Imm(stride)}); + block.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W10)}); + } + + // Extend W9 to X9 and add to base address X8 + block.Append(Opcode::ZExt, {Operand::Reg(PhysReg::X9), Operand::Reg(PhysReg::W9)}); + block.Append(Opcode::AddRR, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X9)}); + } + + // Store address into GEP's stack slot + block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::X8), Operand::FrameIndex(dst_slot)}); + return; + } } - throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令")); + throw std::runtime_error(FormatError("mir", "暂不支持该 IR 指令: " + std::to_string(static_cast(inst.GetOpcode())))); } } // namespace -std::unique_ptr LowerToMIR(const ir::Module& module) { +std::vector> LowerToMIR(const ir::Module& module) { DefaultContext(); + std::vector> mfuncs; - if (module.GetFunctions().size() != 1) { - throw std::runtime_error(FormatError("mir", "暂不支持多个函数")); + for (const auto& funcPtr : module.GetFunctions()) { + const auto& func = *funcPtr; + if (func.GetBlocks().empty()) continue; // skip declarations + + auto machine_func = std::make_unique(func.GetName()); + ValueSlotMap slots; + + // First, create all basic blocks in MachineFunction + std::unordered_map bb_map; + machine_func->GetBlocks().reserve(func.GetBlocks().size()); + for (const auto& bbPtr : func.GetBlocks()) { + auto& mbb = machine_func->CreateBlock(bbPtr->GetName()); + bb_map[bbPtr.get()] = &mbb; + } + + auto& entry_block = *bb_map.at(func.GetEntry()); + + // Lower function arguments at the start of the entry block + const auto& args = func.GetArguments(); + int int_idx = 0; + int float_idx = 0; + for (const auto& arg : args) { + int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get())); + slots.emplace(arg.get(), slot); + + if (arg->GetType()->IsFloat()) { + PhysReg reg = static_cast(static_cast(PhysReg::S0) + float_idx); + entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)}); + float_idx++; + } else { + PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) + ? static_cast(static_cast(PhysReg::X0) + int_idx) + : static_cast(static_cast(PhysReg::W0) + int_idx); + entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)}); + int_idx++; + } + } + + // Now, lower all instructions block by block + for (const auto& bbPtr : func.GetBlocks()) { + auto& mbb = *bb_map.at(bbPtr.get()); + for (const auto& inst : bbPtr->GetInstructions()) { + LowerInstruction(*inst, *machine_func, slots, mbb); + } + } + + mfuncs.push_back(std::move(machine_func)); } - const auto& func = *module.GetFunctions().front(); - if (func.GetName() != "main") { - throw std::runtime_error(FormatError("mir", "暂不支持非 main 函数")); - } - - auto machine_func = std::make_unique(func.GetName()); - ValueSlotMap slots; - const auto* entry = func.GetEntry(); - if (!entry) { - throw std::runtime_error(FormatError("mir", "IR 函数缺少入口基本块")); - } - - for (const auto& inst : entry->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots); - } - - return machine_func; + return mfuncs; } } // namespace mir diff --git a/src/mir/MIRFunction.cpp b/src/mir/MIRFunction.cpp index 334f8cc..cea58aa 100644 --- a/src/mir/MIRFunction.cpp +++ b/src/mir/MIRFunction.cpp @@ -8,7 +8,12 @@ namespace mir { MachineFunction::MachineFunction(std::string name) - : name_(std::move(name)), entry_("entry") {} + : name_(std::move(name)) {} + +MachineBasicBlock& MachineFunction::CreateBlock(std::string name) { + blocks_.emplace_back(std::move(name)); + return blocks_.back(); +} int MachineFunction::CreateFrameIndex(int size) { int index = static_cast(frame_slots_.size()); diff --git a/src/mir/MIRInstr.cpp b/src/mir/MIRInstr.cpp index 0a21a03..69da7ae 100644 --- a/src/mir/MIRInstr.cpp +++ b/src/mir/MIRInstr.cpp @@ -4,10 +4,12 @@ namespace mir { -Operand::Operand(Kind kind, PhysReg reg, int imm) - : kind_(kind), reg_(reg), imm_(imm) {} +Operand::Operand(Kind kind, PhysReg reg, int imm, std::string str) + : kind_(kind), reg_(reg), imm_(imm), str_(std::move(str)) {} -Operand Operand::Reg(PhysReg reg) { return Operand(Kind::Reg, reg, 0); } +Operand Operand::Reg(PhysReg reg) { + return Operand(Kind::Reg, reg, 0); +} Operand Operand::Imm(int value) { return Operand(Kind::Imm, PhysReg::W0, value); @@ -17,6 +19,18 @@ Operand Operand::FrameIndex(int index) { return Operand(Kind::FrameIndex, PhysReg::W0, index); } +Operand Operand::Global(std::string name) { + return Operand(Kind::Global, PhysReg::W0, 0, std::move(name)); +} + +Operand Operand::Label(std::string name) { + return Operand(Kind::Label, PhysReg::W0, 0, std::move(name)); +} + +Operand Operand::Cond(std::string cond) { + return Operand(Kind::Cond, PhysReg::W0, 0, std::move(cond)); +} + MachineInstr::MachineInstr(Opcode opcode, std::vector operands) : opcode_(opcode), operands_(std::move(operands)) {} diff --git a/src/mir/RegAlloc.cpp b/src/mir/RegAlloc.cpp index 5dc5d2b..999e1c1 100644 --- a/src/mir/RegAlloc.cpp +++ b/src/mir/RegAlloc.cpp @@ -8,26 +8,19 @@ namespace mir { namespace { bool IsAllowedReg(PhysReg reg) { - switch (reg) { - case PhysReg::W0: - case PhysReg::W8: - case PhysReg::W9: - case PhysReg::X29: - case PhysReg::X30: - case PhysReg::SP: - return true; - } - return false; + return true; // We allow all defined physical registers } } // namespace void RunRegAlloc(MachineFunction& function) { - for (const auto& inst : function.GetEntry().GetInstructions()) { - for (const auto& operand : inst.GetOperands()) { - if (operand.GetKind() == Operand::Kind::Reg && - !IsAllowedReg(operand.GetReg())) { - throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block.GetInstructions()) { + for (const auto& operand : inst.GetOperands()) { + if (operand.GetKind() == Operand::Kind::Reg && + !IsAllowedReg(operand.GetReg())) { + throw std::runtime_error(FormatError("mir", "寄存器分配失败")); + } } } } diff --git a/src/mir/Register.cpp b/src/mir/Register.cpp index 7530470..2ae09d4 100644 --- a/src/mir/Register.cpp +++ b/src/mir/Register.cpp @@ -1,6 +1,7 @@ #include "mir/MIR.h" #include +#include #include "utils/Log.h" @@ -8,18 +9,77 @@ namespace mir { const char* PhysRegName(PhysReg reg) { switch (reg) { - case PhysReg::W0: - return "w0"; - case PhysReg::W8: - return "w8"; - case PhysReg::W9: - return "w9"; - case PhysReg::X29: - return "x29"; - case PhysReg::X30: - return "x30"; - case PhysReg::SP: - return "sp"; + case PhysReg::W0: return "w0"; + case PhysReg::W1: return "w1"; + case PhysReg::W2: return "w2"; + case PhysReg::W3: return "w3"; + case PhysReg::W4: return "w4"; + case PhysReg::W5: return "w5"; + case PhysReg::W6: return "w6"; + case PhysReg::W7: return "w7"; + case PhysReg::W8: return "w8"; + case PhysReg::W9: return "w9"; + case PhysReg::W10: return "w10"; + case PhysReg::W11: return "w11"; + case PhysReg::W12: return "w12"; + case PhysReg::W13: return "w13"; + case PhysReg::W14: return "w14"; + case PhysReg::W15: return "w15"; + case PhysReg::W19: return "w19"; + case PhysReg::W20: return "w20"; + case PhysReg::W21: return "w21"; + case PhysReg::W22: return "w22"; + case PhysReg::W23: return "w23"; + case PhysReg::W24: return "w24"; + case PhysReg::W25: return "w25"; + case PhysReg::W26: return "w26"; + case PhysReg::W27: return "w27"; + case PhysReg::W28: return "w28"; + case PhysReg::X0: return "x0"; + case PhysReg::X1: return "x1"; + case PhysReg::X2: return "x2"; + case PhysReg::X3: return "x3"; + case PhysReg::X4: return "x4"; + case PhysReg::X5: return "x5"; + case PhysReg::X6: return "x6"; + case PhysReg::X7: return "x7"; + case PhysReg::X8: return "x8"; + case PhysReg::X9: return "x9"; + case PhysReg::X10: return "x10"; + case PhysReg::X11: return "x11"; + case PhysReg::X12: return "x12"; + case PhysReg::X13: return "x13"; + case PhysReg::X14: return "x14"; + case PhysReg::X15: return "x15"; + case PhysReg::X19: return "x19"; + case PhysReg::X20: return "x20"; + case PhysReg::X21: return "x21"; + case PhysReg::X22: return "x22"; + case PhysReg::X23: return "x23"; + case PhysReg::X24: return "x24"; + case PhysReg::X25: return "x25"; + case PhysReg::X26: return "x26"; + case PhysReg::X27: return "x27"; + case PhysReg::X28: return "x28"; + case PhysReg::S0: return "s0"; + case PhysReg::S1: return "s1"; + case PhysReg::S2: return "s2"; + case PhysReg::S3: return "s3"; + case PhysReg::S4: return "s4"; + case PhysReg::S5: return "s5"; + case PhysReg::S6: return "s6"; + case PhysReg::S7: return "s7"; + case PhysReg::S8: return "s8"; + case PhysReg::S9: return "s9"; + case PhysReg::S10: return "s10"; + case PhysReg::S11: return "s11"; + case PhysReg::S12: return "s12"; + case PhysReg::S13: return "s13"; + case PhysReg::S14: return "s14"; + case PhysReg::S15: return "s15"; + case PhysReg::X29: return "x29"; + case PhysReg::X30: return "x30"; + case PhysReg::SP: return "sp"; } throw std::runtime_error(FormatError("mir", "未知物理寄存器")); } diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..6a4fd1a 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -1,4 +1,77 @@ -// SysY 运行库实现: -// - 按实验/评测规范提供 I/O 等函数实现 -// - 与编译器生成的目标代码链接,支撑运行时行为 +#include +#include +int getint() { + int x; + if (scanf("%d", &x) != 1) return 0; + return x; +} + +int getch() { + return getchar(); +} + +float getfloat() { + double x; + if (scanf("%lf", &x) != 1) return 0.0f; + return (float)x; +} + +int getarray(int a[]) { + int n; + if (scanf("%d", &n) != 1) return 0; + for (int i = 0; i < n; i++) { + if (scanf("%d", &a[i]) != 1) break; + } + return n; +} + +int getfarray(float a[]) { + int n; + if (scanf("%d", &n) != 1) return 0; + for (int i = 0; i < n; i++) { + double val; + if (scanf("%lf", &val) != 1) break; + a[i] = (float)val; + } + return n; +} + +void putint(int x) { + printf("%d", x); +} + +void putch(int x) { + putchar(x); +} + +void putfloat(float x) { + printf("%a", x); +} + +void putarray(int n, int a[]) { + printf("%d:", n); + for (int i = 0; i < n; i++) { + printf(" %d", a[i]); + } + printf("\n"); +} + +void putfarray(int n, float a[]) { + printf("%d:", n); + for (int i = 0; i < n; i++) { + printf(" %a", a[i]); + } + printf("\n"); +} + +struct timeval start, stop; +void starttime() { + gettimeofday(&start, NULL); +} + +void stoptime() { + gettimeofday(&stop, NULL); + long long duration = (stop.tv_sec - start.tv_sec) * 1000000LL + (stop.tv_usec - start.tv_usec); + printf("timer: %lld us\n", duration); +}