diff --git a/.gitignore b/.gitignore index 435dca3..774f5a6 100644 --- a/.gitignore +++ b/.gitignore @@ -52,4 +52,5 @@ __init__.py .DS_* -antlr/ \ No newline at end of file +antlr/ +.clang-format diff --git a/Pass_ID_List.md b/Pass_ID_List.md new file mode 100644 index 0000000..14f3379 --- /dev/null +++ b/Pass_ID_List.md @@ -0,0 +1,6 @@ +# 记录中端遍的开发进度 + +| 名称 | 优化级别 | 开发进度 | +| ------------ | ------------ | ---------- | +| CFG优化 | 函数级 | 已完成 | +| DCE | 函数级 | 待测试 | \ No newline at end of file diff --git a/README.md b/README.md index 65660b9..085b55b 100644 --- a/README.md +++ b/README.md @@ -37,4 +37,13 @@ mysysy/ $ bash setup.sh ``` ### 配套脚本 - (TODO: 需要完善) \ No newline at end of file + (TODO: 需要完善) + + +### TODO_list: + +除开注释中的TODO后续时间充足可以考虑的TODO: + +- store load指令由于gep指令的引入, 维度信息的记录是非必须的, 考虑删除 + +- use def关系经过mem2reg和phi函数明确转换为ssa形式, 以及函数参数通过value数组明确定义, 使得基本块的args参数信息记录非必须, 考虑删除 \ No newline at end of file diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1c17eb1..01f54d8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -21,9 +21,11 @@ add_executable(sysyc IR.cpp SysYIRGenerator.cpp SysYIRPrinter.cpp - SysYIROptPre.cpp - SysYIRAnalyser.cpp - # DeadCodeElimination.cpp + SysYIRCFGOpt.cpp + Pass.cpp + Dom.cpp + Liveness.cpp + DCE.cpp AddressCalculationExpansion.cpp # Mem2Reg.cpp # Reg2Mem.cpp diff --git a/src/DCE.cpp b/src/DCE.cpp new file mode 100644 index 0000000..db5d966 --- /dev/null +++ b/src/DCE.cpp @@ -0,0 +1,140 @@ +#include "DCE.h" // 包含DCE遍的头文件 +#include "IR.h" // 包含IR相关的定义 +#include "SysYIROptUtils.h" // 包含SysY IR优化工具类的定义 +#include // 用于断言 +#include // 用于调试输出 +#include // 包含set,虽然DCEContext内部用unordered_set,但这里保留 + +namespace sysy { + +// DCE 遍的静态 ID +void *DCE::ID = (void *)&DCE::ID; + +// ====================================================================== +// DCEContext 类的实现 +// 封装了 DCE 遍的核心逻辑和状态,确保每次函数优化运行时状态独立 +// ====================================================================== + +// DCEContext 的 run 方法实现 +void 
DCEContext::run(Function *func, AnalysisManager *AM, bool &changed) { + // 清空活跃指令集合,确保每次运行都是新的状态 + alive_insts.clear(); + + // 第一次遍历:扫描所有指令,识别“天然活跃”的指令并将其及其依赖标记为活跃 + // 使用 func->getBasicBlocks() 获取基本块列表,保留用户风格 + auto basicBlocks = func->getBasicBlocks(); + for (auto &basicBlock : basicBlocks) { + // 确保基本块有效 + if (!basicBlock) + continue; + // 使用 basicBlock->getInstructions() 获取指令列表,保留用户风格 + for (auto &inst : basicBlock->getInstructions()) { + // 确保指令有效 + if (!inst) + continue; + // 调用 DCEContext 自身的 isAlive 和 addAlive 方法 + if (isAlive(inst.get())) { + addAlive(inst.get()); + } + } + } + + // 第二次遍历:删除所有未被标记为活跃的指令。 + for (auto &basicBlock : basicBlocks) { + if (!basicBlock) + continue; + // 使用传统的迭代器循环,并手动管理迭代器, + // 以便在删除元素后正确前进。保留用户风格 + for (auto instIter = basicBlock->getInstructions().begin(); instIter != basicBlock->getInstructions().end();) { + auto &inst = *instIter; + Instruction *currentInst = inst.get(); + // 如果指令不在活跃集合中,则删除它。 + // 分支和返回指令由 isAlive 处理,并会被保留。 + if (alive_insts.count(currentInst) == 0) { + // 删除指令,保留用户风格的 SysYIROptUtils::usedelete 和 erase + changed = true; // 标记 IR 已被修改 + SysYIROptUtils::usedelete(currentInst); + instIter = basicBlock->getInstructions().erase(instIter); // 删除后返回下一个迭代器 + } else { + ++instIter; // 指令活跃,移动到下一个 + } + } + } +} + +// 判断指令是否是“天然活跃”的实现 +// 只有具有副作用的指令(如存储、函数调用、原子操作) +// 和控制流指令(如分支、返回)是天然活跃的。 +bool DCEContext::isAlive(Instruction *inst) { + // TODO: 后续程序并发考虑原子操作 + // 其结果不被其他指令使用的指令(例如 StoreInst, BranchInst, ReturnInst)。 + // dynamic_cast(inst) 检查是否是函数调用指令, + // 函数调用通常有副作用。 + // 终止指令 (BranchInst, ReturnInst) 必须是活跃的,因为它控制了程序的执行流程。 + // 保留用户提供的 isAlive 逻辑 + bool isBranchOrReturn = inst->isBranch() || inst->isReturn(); + bool isCall = inst->isCall(); + bool isStoreOrMemset = inst->isStore() || inst->isMemset(); + return isBranchOrReturn || isCall || isStoreOrMemset; +} + +// 递归地将活跃指令及其依赖加入到 alive_insts 集合中 +void DCEContext::addAlive(Instruction *inst) { + // 如果指令已经存在于活跃集合中,则无需重复处理 + if (alive_insts.count(inst) > 0) { + 
return; + } + // 将当前指令标记为活跃 + alive_insts.insert(inst); + // 遍历当前指令的所有操作数 + // 保留用户提供的 getOperands() 和 getValue() + for (auto operand : inst->getOperands()) { + // 如果操作数是一个指令(即它是一个值的定义), + // 并且它还没有被标记为活跃 + if (auto opInst = dynamic_cast(operand->getValue())) { + addAlive(opInst); // 递归地将操作数指令标记为活跃 + } + } +} + +// ====================================================================== +// DCE Pass 类的实现 +// 主要负责与 PassManager 交互,创建 DCEContext 实例并运行优化 +// ====================================================================== + +// DCE 遍的 runOnFunction 方法实现 +bool DCE::runOnFunction(Function *func, AnalysisManager &AM) { + + DCEContext ctx; + bool changed = false; + ctx.run(func, &AM, changed); // 运行 DCE 优化 + + // 如果 IR 被修改,则使相关的分析结果失效 + if (changed) { + // DCE 会删除指令,这会影响数据流分析,尤其是活跃性分析。 + // 如果删除导致基本块变空,也可能间接影响 CFG 和支配树。 + // AM.invalidateAnalysis(&LivenessAnalysisPass::ID, func); // 活跃性分析失效 + // AM.invalidateAnalysis(&DominatorTreeAnalysisPass::ID, func); // 支配树分析可能失效 + // 其他所有依赖于数据流或 IR 结构的分析都可能失效。 + } + return changed; +} + +// 声明DCE遍的分析依赖和失效信息 +void DCE::getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // DCE不依赖特定的分析结果,它通过遍历和副作用判断来工作。 + + // DCE会删除指令,这会影响许多分析结果。 + // 至少,它会影响活跃性分析、支配树、控制流图(如果删除导致基本块为空并被合并)。 + // 假设存在LivenessAnalysisPass和DominatorTreeAnalysisPass + // analysisInvalidations.insert(&LivenessAnalysisPass::ID); + // analysisInvalidations.insert(&DominatorTreeAnalysisPass::ID); + // 任何改变IR结构的优化,都可能导致通用分析(如活跃性、支配树、循环信息)失效。 + // 最保守的做法是使所有函数粒度的分析失效,或者只声明你明确知道会受影响的分析。 + // 考虑到这个DCE仅删除指令,如果它不删除基本块,CFG可能不变,但数据流分析会失效。 + // 对于更激进的DCE(如ADCE),CFG也会改变。 + // 这里我们假设它主要影响数据流分析,并且可能间接影响CFG相关分析。 + // 如果有SideEffectInfo,它也可能被修改,但通常SideEffectInfo是静态的,不因DCE而变。 +} + +} // namespace sysy diff --git a/src/DeadCodeElimination.cpp b/src/DeadCodeElimination.cpp deleted file mode 100644 index ffe6022..0000000 --- a/src/DeadCodeElimination.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "DeadCodeElimination.h" -#include - -extern int DEBUG; 
-namespace sysy { - -void DeadCodeElimination::runDCEPipeline() { - const auto& functions = pModule->getFunctions(); - for (const auto& function : functions) { - const auto& func = function.second; - bool changed = true; - while (changed) { - changed = false; - eliminateDeadStores(func.get(), changed); - eliminateDeadLoads(func.get(), changed); - eliminateDeadAllocas(func.get(), changed); - eliminateDeadRedundantLoadStore(func.get(), changed); - eliminateDeadGlobals(changed); - } - } -} - -// 消除无用存储 消除条件: -// 存储的目标指针(pointer)不是全局变量(!isGlobal(pointer))。 -// 存储的目标指针不是数组参数(!isArr(pointer) 或不在函数参数列表里)。 -// 该指针的所有使用者(uses)仅限 alloca 或 store(即没有 load 或其他指令使用它)。 -void DeadCodeElimination::eliminateDeadStores(Function* func, bool& changed) { - for (const auto& block : func->getBasicBlocks()) { - auto& instrs = block->getInstructions(); - for (auto iter = instrs.begin(); iter != instrs.end();) { - auto inst = iter->get(); - if (!inst->isStore()) { - ++iter; - continue; - } - - auto storeInst = dynamic_cast(inst); - auto pointer = storeInst->getPointer(); - // 如果是全局变量或者是函数的数组参数 - if (isGlobal(pointer) || (isArr(pointer) && - std::find(func->getEntryBlock()->getArguments().begin(), - func->getEntryBlock()->getArguments().end(), - pointer) != func->getEntryBlock()->getArguments().end())) { - ++iter; - continue; - } - - bool changetag = true; - for (auto& use : pointer->getUses()) { - // 依次判断store的指针是否被其他指令使用 - auto user = use->getUser(); - auto userInst = dynamic_cast(user); - // 如果使用store的指针的指令不是Alloca或Store,则不删除 - if (userInst != nullptr && !userInst->isAlloca() && !userInst->isStore()) { - changetag = false; - break; - } - } - - if (changetag) { - changed = true; - if(DEBUG){ - std::cout << "=== Dead Store Found ===\n"; - SysYPrinter::printInst(storeInst); - } - usedelete(storeInst); - iter = instrs.erase(iter); - } else { - ++iter; - } - } - } -} -// 消除无用加载 消除条件: -// 该指令的结果未被使用(inst->getUses().empty())。 -void DeadCodeElimination::eliminateDeadLoads(Function* func, bool& 
changed) { - for (const auto& block : func->getBasicBlocks()) { - auto& instrs = block->getInstructions(); - for (auto iter = instrs.begin(); iter != instrs.end();) { - auto inst = iter->get(); - if (inst->isBinary() || inst->isUnary() || inst->isLoad()) { - if (inst->getUses().empty()) { - changed = true; - if(DEBUG){ - std::cout << "=== Dead Load Binary Unary Found ===\n"; - SysYPrinter::printInst(inst); - } - usedelete(inst); - iter = instrs.erase(iter); - continue; - } - } - ++iter; - } - } -} - -// 消除无用加载 消除条件: -// 该 alloca 未被任何指令使用(allocaInst->getUses().empty())。 -// 该 alloca 不是函数的参数(不在 entry 块的参数列表里)。 -void DeadCodeElimination::eliminateDeadAllocas(Function* func, bool& changed) { - for (const auto& block : func->getBasicBlocks()) { - auto& instrs = block->getInstructions(); - for (auto iter = instrs.begin(); iter != instrs.end();) { - auto inst = iter->get(); - if (inst->isAlloca()) { - auto allocaInst = dynamic_cast(inst); - if (allocaInst->getUses().empty() && - std::find(func->getEntryBlock()->getArguments().begin(), - func->getEntryBlock()->getArguments().end(), - allocaInst) == func->getEntryBlock()->getArguments().end()) { - changed = true; - if(DEBUG){ - std::cout << "=== Dead Alloca Found ===\n"; - SysYPrinter::printInst(inst); - } - usedelete(inst); - iter = instrs.erase(iter); - continue; - } - } - ++iter; - } - } -} - -void DeadCodeElimination::eliminateDeadIndirectiveAllocas(Function* func, bool& changed) { - // 删除mem2reg时引入的且现在已经没有value使用了的隐式alloca - FunctionAnalysisInfo* funcInfo = pCFA->getFunctionAnalysisInfo(func); - for (auto it = funcInfo->getIndirectAllocas().begin(); it != funcInfo->getIndirectAllocas().end();) { - auto &allocaInst = *it; - if (allocaInst->getUses().empty()) { - changed = true; - if(DEBUG){ - std::cout << "=== Dead Indirect Alloca Found ===\n"; - SysYPrinter::printInst(allocaInst.get()); - } - it = funcInfo->getIndirectAllocas().erase(it); - } else { - ++it; - } - } -} - -// 该全局变量未被任何指令使用(global->getUses().empty())。 
-void DeadCodeElimination::eliminateDeadGlobals(bool& changed) { - auto& globals = pModule->getGlobals(); - for (auto it = globals.begin(); it != globals.end();) { - auto& global = *it; - if (global->getUses().empty()) { - changed = true; - if(DEBUG){ - std::cout << "=== Dead Global Found ===\n"; - SysYPrinter::printValue(global.get()); - } - it = globals.erase(it); - } else { - ++it; - } - } -} - -// 消除冗余加载和存储 消除条件: -// phi 指令的目标指针仅被该 phi 使用(无其他 store/load 使用)。 -// memset 指令的目标指针未被使用(pointer->getUses().empty()) -// store -> load -> store 模式 -void DeadCodeElimination::eliminateDeadRedundantLoadStore(Function* func, bool& changed) { - for (const auto& block : func->getBasicBlocks()) { - auto& instrs = block->getInstructions(); - for (auto iter = instrs.begin(); iter != instrs.end();) { - auto inst = iter->get(); - if (inst->isPhi()) { - auto phiInst = dynamic_cast(inst); - auto pointer = phiInst->getPointer(); - bool tag = true; - for (const auto& use : pointer->getUses()) { - auto user = use->getUser(); - if (user != inst) { - tag = false; - break; - } - } - /// 如果 pointer 仅被该 phi 使用,可以删除 ph - if (tag) { - changed = true; - usedelete(inst); - iter = instrs.erase(iter); - continue; - } - // 数组指令还不完善,不保证memset优化效果 - } else if (inst->isMemset()) { - auto memsetInst = dynamic_cast(inst); - auto pointer = memsetInst->getPointer(); - if (pointer->getUses().empty()) { - changed = true; - usedelete(inst); - iter = instrs.erase(iter); - continue; - } - }else if(inst->isLoad()) { - if (iter != instrs.begin()) { - auto loadInst = dynamic_cast(inst); - auto loadPointer = loadInst->getPointer(); - // TODO:store -> load -> store 模式 - auto prevIter = std::prev(iter); - auto prevInst = prevIter->get(); - if (prevInst->isStore()) { - auto prevStore = dynamic_cast(prevInst); - auto prevStorePointer = prevStore->getPointer(); - auto prevStoreValue = prevStore->getOperand(0); - // 确保前一个 store 不是数组操作 - if (prevStore->getIndices().empty()) { - // 检查后一条指令是否是 store 同一个值 - auto nextIter = 
std::next(iter); - if (nextIter != instrs.end()) { - auto nextInst = nextIter->get(); - if (nextInst->isStore()) { - auto nextStore = dynamic_cast(nextInst); - auto nextStorePointer = nextStore->getPointer(); - auto nextStoreValue = nextStore->getOperand(0); - // 确保后一个 store 不是数组操作 - if (nextStore->getIndices().empty()) { - // 判断优化条件: - // 1. prevStore 的指针操作数 == load 的指针操作数 - // 2. nextStore 的值操作数 == load 指令本身 - if (prevStorePointer == loadPointer && - nextStoreValue == loadInst) { - // 可以优化直接把prevStorePointer的值存到nextStorePointer - changed = true; - nextStore->setOperand(0, prevStoreValue); - if(DEBUG){ - std::cout << "=== Dead Store Load Store Found(now only del Load) ===\n"; - SysYPrinter::printInst(prevStore); - SysYPrinter::printInst(loadInst); - SysYPrinter::printInst(nextStore); - } - usedelete(loadInst); - iter = instrs.erase(iter); - // 删除 prevStore 这里是不是可以留给删除无用store处理? - // if (prevStore->getUses().empty()) { - // usedelete(prevStore); - // instrs.erase(prevIter); // 删除 prevStore - // } - continue; // 跳过 ++iter,因为已经移动迭代器 - } - } - } - } - } - } - } - } - ++iter; - } - } -} - - -bool DeadCodeElimination::isGlobal(Value *val){ - auto gval = dynamic_cast(val); - return gval != nullptr; -} - -bool DeadCodeElimination::isArr(Value *val){ - auto aval = dynamic_cast(val); - return aval != nullptr && aval->getNumDims() != 0; -} - -void DeadCodeElimination::usedelete(Instruction *instr){ - for (auto &use1 : instr->getOperands()) { - auto val1 = use1->getValue(); - val1->removeUse(use1); - } -} - -} // namespace sysy \ No newline at end of file diff --git a/src/Dom.cpp b/src/Dom.cpp new file mode 100644 index 0000000..d476c49 --- /dev/null +++ b/src/Dom.cpp @@ -0,0 +1,180 @@ +#include "Dom.h" +#include // for std::numeric_limits +#include + +namespace sysy { + +// 初始化 支配树静态 ID +void *DominatorTreeAnalysisPass::ID = (void *)&DominatorTreeAnalysisPass::ID; +// ============================================================== +// DominatorTree 结果类的实现 +// 
============================================================== + +DominatorTree::DominatorTree(Function *F) : AssociatedFunction(F) { + // 构造时可以不计算,在分析遍运行里计算并填充 +} + +const std::set *DominatorTree::getDominators(BasicBlock *BB) const { + auto it = Dominators.find(BB); + if (it != Dominators.end()) { + return &(it->second); + } + return nullptr; +} + +BasicBlock *DominatorTree::getImmediateDominator(BasicBlock *BB) const { + auto it = IDoms.find(BB); + if (it != IDoms.end()) { + return it->second; + } + return nullptr; +} + +const std::set *DominatorTree::getDominanceFrontier(BasicBlock *BB) const { + auto it = DominanceFrontiers.find(BB); + if (it != DominanceFrontiers.end()) { + return &(it->second); + } + return nullptr; +} + +void DominatorTree::computeDominators(Function *F) { + // 经典的迭代算法计算支配者集合 + // TODO: 可以替换为更高效的算法,如 Lengauer-Tarjan 算法 + BasicBlock *entryBlock = F->getEntryBlock(); + + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) { + Dominators[bb].insert(bb); + } else { + for (const auto &all_bb_ptr : F->getBasicBlocks()) { + Dominators[bb].insert(all_bb_ptr.get()); + } + } + } + + bool changed = true; + while (changed) { + changed = false; + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) + continue; + + std::set newDom; + bool firstPred = true; + for (BasicBlock *pred : bb->getPredecessors()) { + if (Dominators.count(pred)) { + if (firstPred) { + newDom = Dominators[pred]; + firstPred = false; + } else { + std::set intersection; + std::set_intersection(newDom.begin(), newDom.end(), Dominators[pred].begin(), Dominators[pred].end(), + std::inserter(intersection, intersection.begin())); + newDom = intersection; + } + } + } + newDom.insert(bb); + + if (newDom != Dominators[bb]) { + Dominators[bb] = newDom; + changed = true; + } + } + } +} + +void DominatorTree::computeIDoms(Function *F) { + // 采用与之前类似的简化实现。TODO:Lengauer-Tarjan等算法。 + 
BasicBlock *entryBlock = F->getEntryBlock(); + IDoms[entryBlock] = nullptr; + + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + if (bb == entryBlock) + continue; + + BasicBlock *currentIDom = nullptr; + const std::set *domsOfBB = getDominators(bb); + if (!domsOfBB) + continue; + + for (BasicBlock *D : *domsOfBB) { + if (D == bb) + continue; + + bool isCandidateIDom = true; + for (BasicBlock *candidate : *domsOfBB) { + if (candidate == bb || candidate == D) + continue; + const std::set *domsOfCandidate = getDominators(candidate); + if (domsOfCandidate && domsOfCandidate->count(D) == 0 && domsOfBB->count(candidate)) { + isCandidateIDom = false; + break; + } + } + if (isCandidateIDom) { + currentIDom = D; + break; + } + } + IDoms[bb] = currentIDom; + } +} + +void DominatorTree::computeDominanceFrontiers(Function *F) { + // 经典的支配边界计算算法 + for (const auto &bb_ptr_X : F->getBasicBlocks()) { + BasicBlock *X = bb_ptr_X.get(); + DominanceFrontiers[X].clear(); + + for (BasicBlock *Y : X->getSuccessors()) { + const std::set *domsOfY = getDominators(Y); + if (domsOfY && domsOfY->find(X) == domsOfY->end()) { + DominanceFrontiers[X].insert(Y); + } + } + + const std::set *domsOfX = getDominators(X); + if (!domsOfX) + continue; + for (const auto &bb_ptr_Z : F->getBasicBlocks()) { + BasicBlock *Z = bb_ptr_Z.get(); + if (Z == X) + continue; + const std::set *domsOfZ = getDominators(Z); + if (domsOfZ && domsOfZ->count(X) && Z != X) { + + for (BasicBlock *Y : Z->getSuccessors()) { + const std::set *domsOfY = getDominators(Y); + if (domsOfY && domsOfY->find(X) == domsOfY->end()) { + DominanceFrontiers[X].insert(Y); + } + } + } + } + } +} + +// ============================================================== +// DominatorTreeAnalysisPass 的实现 +// ============================================================== + + +bool DominatorTreeAnalysisPass::runOnFunction(Function* F, AnalysisManager &AM) { + CurrentDominatorTree = std::make_unique(F); + 
CurrentDominatorTree->computeDominators(F); + CurrentDominatorTree->computeIDoms(F); + CurrentDominatorTree->computeDominanceFrontiers(F); + return false; +} + +std::unique_ptr DominatorTreeAnalysisPass::getResult() { + // 返回计算好的 DominatorTree 实例,所有权转移给 AnalysisManager + return std::move(CurrentDominatorTree); +} + +} // namespace sysy \ No newline at end of file diff --git a/src/IR.cpp b/src/IR.cpp index 5f4e0c5..faa9aed 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -49,6 +49,11 @@ auto Type::getFunctionType(Type *returnType, const std::vector ¶mTyp return FunctionType::get(returnType, paramTypes); } +auto Type::getArrayType(Type *elementType, unsigned numElements) -> Type * { + // forward to ArrayType + return ArrayType::get(elementType, numElements); +} + auto Type::getSize() const -> unsigned { switch (kind) { case kInt: @@ -58,6 +63,10 @@ auto Type::getSize() const -> unsigned { case kPointer: case kFunction: return 8; + case Kind::kArray: { + const ArrayType* arrType = static_cast(this); + return arrType->getElementType()->getSize() * arrType->getNumElements(); + } case kVoid: return 0; } @@ -95,6 +104,20 @@ FunctionType*FunctionType::get(Type *returnType, const std::vector ¶ return result.first->get(); } +ArrayType *ArrayType::get(Type *elementType, unsigned numElements) { + static std::set> arrayTypes; + auto iter = std::find_if(arrayTypes.begin(), arrayTypes.end(), [&](const std::unique_ptr &type) -> bool { + return elementType == type->getElementType() && numElements == type->getNumElements(); + }); + if (iter != arrayTypes.end()) { + return iter->get(); + } + auto type = new ArrayType(elementType, numElements); + assert(type); + auto result = arrayTypes.emplace(type); + return result.first->get(); +} + void Value::replaceAllUsesWith(Value *value) { for (auto &use : uses) { use->getUser()->setOperand(use->getIndex(), value); @@ -465,44 +488,7 @@ Function * Function::clone(const std::string &suffix) const { break; } - case Instruction::kLa: { - auto 
oldLaInst = dynamic_cast(inst); - auto oldPointer = oldLaInst->getPointer(); - Value *newPointer; - std::vector newIndices; - newPointer = oldNewValueMap.at(oldPointer); - - for (const auto &index : oldLaInst->getIndices()) { - newIndices.emplace_back(oldNewValueMap.at(index->getValue())); - } - ss << oldLaInst->getName() << suffix; - auto newLaInst = new LaInst(newPointer, newIndices, oldNewBlockMap.at(oldLaInst->getParent()), ss.str()); - ss.str(""); - oldNewValueMap.emplace(oldLaInst, newLaInst); - break; - } - - case Instruction::kGetSubArray: { - auto oldGetSubArrayInst = dynamic_cast(inst); - auto oldFather = oldGetSubArrayInst->getFatherArray(); - auto oldChild = oldGetSubArrayInst->getChildArray(); - Value *newFather; - Value *newChild; - std::vector newIndices; - newFather = oldNewValueMap.at(oldFather); - newChild = oldNewValueMap.at(oldChild); - - for (const auto &index : oldGetSubArrayInst->getIndices()) { - newIndices.emplace_back(oldNewValueMap.at(index->getValue())); - } - ss << oldGetSubArrayInst->getName() << suffix; - auto newGetSubArrayInst = - new GetSubArrayInst(dynamic_cast(newFather), dynamic_cast(newChild), newIndices, - oldNewBlockMap.at(oldGetSubArrayInst->getParent()), ss.str()); - ss.str(""); - oldNewValueMap.emplace(oldGetSubArrayInst, newGetSubArrayInst); - break; - } + // TODO:复制GEP指令 case Instruction::kMemset: { auto oldMemsetInst = dynamic_cast(inst); @@ -661,7 +647,7 @@ Function * CallInst::getCallee() const { return dynamic_cast(getOper /** * 获取变量指针 */ -auto SymbolTable::getVariable(const std::string &name) const -> User * { +auto SymbolTable::getVariable(const std::string &name) const -> Value * { auto node = curNode; while (node != nullptr) { auto iter = node->varList.find(name); @@ -676,8 +662,8 @@ auto SymbolTable::getVariable(const std::string &name) const -> User * { /** * 添加变量到符号表 */ -auto SymbolTable::addVariable(const std::string &name, User *variable) -> User * { - User *result = nullptr; +auto 
SymbolTable::addVariable(const std::string &name, Value *variable) -> Value * { + Value *result = nullptr; if (curNode != nullptr) { std::stringstream ss; auto iter = variableIndex.find(name); diff --git a/src/Liveness.cpp b/src/Liveness.cpp new file mode 100644 index 0000000..11c0f71 --- /dev/null +++ b/src/Liveness.cpp @@ -0,0 +1,160 @@ +#include "Liveness.h" +#include // For std::set_union, std::set_difference +#include +#include // Potentially for worklist, though not strictly needed for the iterative approach below +#include // For std::set + +namespace sysy { + +// 初始化静态 ID +void *LivenessAnalysisPass::ID = (void *)&LivenessAnalysisPass::ID; +// ============================================================== +// LivenessAnalysisResult 结果类的实现 +// ============================================================== + +const std::set *LivenessAnalysisResult::getLiveIn(BasicBlock *BB) const { + auto it = liveInSets.find(BB); + if (it != liveInSets.end()) { + return &(it->second); + } + // 返回一个空集合,表示未找到或不存在 + static const std::set emptySet; + return &emptySet; +} + +const std::set *LivenessAnalysisResult::getLiveOut(BasicBlock *BB) const { + auto it = liveOutSets.find(BB); + if (it != liveOutSets.end()) { + return &(it->second); + } + static const std::set emptySet; + return &emptySet; +} + +void LivenessAnalysisResult::computeDefUse(BasicBlock *BB, std::set &def, std::set &use) { + def.clear(); // 将持有在 BB 中定义的值 + use.clear(); // 将持有在 BB 中使用但在其定义之前的值 + + // 临时集合,用于跟踪当前基本块中已经定义过的变量 + std::set defined_in_block_so_far; + + // 按照指令在块中的顺序遍历 + for (const auto &inst_ptr : BB->getInstructions()) { + Instruction *inst = inst_ptr.get(); + + // 1. 
处理指令的操作数 (Use) - 在定义之前的使用 + for (const auto &use_ptr : inst->getOperands()) { // 修正迭代器类型 + Value *operand = use_ptr->getValue(); // 从 shared_ptr 获取 Value* + + // 过滤掉常量和全局变量,因为它们通常不被视为活跃变量 + ConstantValue *constValue = dynamic_cast(operand); + GlobalValue *globalValue = dynamic_cast(operand); + if (constValue || globalValue) { + continue; // 跳过常量和全局变量 + } + + // 如果操作数是一个变量(Instruction 或 Argument),并且它在此基本块的当前点之前尚未被定义 + if (defined_in_block_so_far.find(operand) == defined_in_block_so_far.end()) { + use.insert(operand); + } + } + + // 2. 处理指令自身产生的定义 (Def) + if (inst->isDefine()) { // 使用 isDefine() 方法 + // 指令自身定义了一个值。将其添加到块的 def 集合, + // 并添加到当前块中已定义的值的临时集合。 + def.insert(inst); // inst 本身就是被定义的值(例如,虚拟寄存器) + defined_in_block_so_far.insert(inst); + } + } +} + +void LivenessAnalysisResult::computeLiveness(Function *F) { + // 每次计算前清空旧结果 + liveInSets.clear(); // 直接清空 map,不再使用 F 作为键 + liveOutSets.clear(); // 直接清空 map + + // 初始化所有基本块的 LiveIn 和 LiveOut 集合为空 + for (const auto &bb_ptr : F->getBasicBlocks()) { + BasicBlock *bb = bb_ptr.get(); + liveInSets[bb] = {}; // 直接以 bb 为键 + liveOutSets[bb] = {}; // 直接以 bb 为键 + } + + bool changed = true; + while (changed) { + changed = false; + + // TODO : 目前为逆序遍历基本块,考虑反向拓扑序遍历基本块 + + // 逆序遍历基本块 + // std::list> basicBlocks(F->getBasicBlocks().begin(), F->getBasicBlocks().end()); + // std::reverse(basicBlocks.begin(), basicBlocks.end()); + // 然后遍历 basicBlocks + // 创建一个 BasicBlock* 的列表来存储指针,避免拷贝 unique_ptr + // Option 1: Using std::vector (preferred for performance with reverse) + std::vector basicBlocksPointers; + for (const auto& bb_ptr : F->getBasicBlocks()) { + basicBlocksPointers.push_back(bb_ptr.get()); + } + std::reverse(basicBlocksPointers.begin(), basicBlocksPointers.end()); + + for (auto bb_iter = basicBlocksPointers.begin(); bb_iter != basicBlocksPointers.end(); ++bb_iter) { + BasicBlock *bb = *bb_iter; // 获取 BasicBlock 指针 + if (!bb) + continue; // 避免空指针 + + std::set oldLiveIn = liveInSets[bb]; + std::set oldLiveOut = liveOutSets[bb]; 
+ + // 1. 计算 LiveOut(BB) = Union(LiveIn(Succ) for Succ in Successors(BB)) + std::set newLiveOut; + for (BasicBlock *succ : bb->getSuccessors()) { + const std::set *succLiveIn = getLiveIn(succ); // 获取后继的 LiveIn + if (succLiveIn) { + newLiveOut.insert(succLiveIn->begin(), succLiveIn->end()); + } + } + liveOutSets[bb] = newLiveOut; + + // 2. 计算 LiveIn(BB) = Use(BB) Union (LiveOut(BB) - Def(BB)) + std::set defSet, useSet; + computeDefUse(bb, defSet, useSet); // 计算当前块的 Def 和 Use + + std::set liveOutMinusDef; + std::set_difference(newLiveOut.begin(), newLiveOut.end(), defSet.begin(), defSet.end(), + std::inserter(liveOutMinusDef, liveOutMinusDef.begin())); + + std::set newLiveIn = useSet; + newLiveIn.insert(liveOutMinusDef.begin(), liveOutMinusDef.end()); + liveInSets[bb] = newLiveIn; + + // 检查是否发生变化 + if (oldLiveIn != newLiveIn || oldLiveOut != newLiveOut) { + changed = true; + } + } + } +} + +// ============================================================== +// LivenessAnalysisPass 的实现 +// ============================================================== + +bool LivenessAnalysisPass::runOnFunction(Function *F, AnalysisManager &AM) { + // 每次运行创建一个新的 LivenessAnalysisResult 对象来存储结果 + CurrentLivenessResult = std::make_unique(F); + + // 调用 LivenessAnalysisResult 内部的方法来计算分析结果 + CurrentLivenessResult->computeLiveness(F); + + // 分析遍通常不修改 IR,所以返回 false + return false; +} + +std::unique_ptr LivenessAnalysisPass::getResult() { + // 返回计算好的 LivenessAnalysisResult 实例,所有权转移给 AnalysisManager + return std::move(CurrentLivenessResult); +} + +} // namespace sysy \ No newline at end of file diff --git a/src/Mem2Reg.cpp b/src/Mem2Reg.cpp deleted file mode 100644 index db584ed..0000000 --- a/src/Mem2Reg.cpp +++ /dev/null @@ -1,801 +0,0 @@ -#include "Mem2Reg.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "IR.h" -#include "SysYIRAnalyser.h" -#include "SysYIRPrinter.h" - -namespace sysy { - -// 计算给定变量的定义块集合的迭代支配边界 -// TODO:优化Semi-Naive IDF 
-std::unordered_set Mem2Reg::computeIterDf(const std::unordered_set &blocks) { - std::unordered_set workList; - std::unordered_set ret_list; - workList.insert(blocks.begin(), blocks.end()); - - while (!workList.empty()) { - auto n = workList.begin(); - BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(*n); - auto DFs = blockInfo->getDomFrontiers(); - for (auto c : DFs) { - // 如果c不在ret_list中,则将其加入ret_list和workList - // 这里的c是n的支配边界 - // 也就是n的支配边界中的块 - // 需要注意的是,支配边界是一个集合,所以可能会有重复 - if (ret_list.count(c) == 0U) { - ret_list.emplace(c); - workList.emplace(c); - } - } - workList.erase(n); - } - return ret_list; -} - -/** - * 计算value2Blocks的映射,包括value2AllocBlocks、value2DefBlocks以及value2UseBlocks - * 其中value2DefBlocks可用于计算迭代支配边界来插入相应变量的phi结点 - * 这里的value2AllocBlocks、value2DefBlocks和value2UseBlocks改变了函数级别的分析信息 - */ -auto Mem2Reg::computeValue2Blocks() -> void { - SysYPrinter printer(pModule); // 初始化打印机 - // std::cout << "===== Start computeValue2Blocks =====" << std::endl; - - auto &functions = pModule->getFunctions(); - for (const auto &function : functions) { - auto func = function.second.get(); - // std::cout << "\nProcessing function: " << func->getName() << std::endl; - - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - if (!funcInfo) { - std::cerr << "ERROR: No analysis info for function " << func->getName() << std::endl; - continue; - } - - auto basicBlocks = func->getBasicBlocks(); - // std::cout << "BasicBlocks count: " << basicBlocks.size() << std::endl; - - for (auto &it : basicBlocks) { - auto basicBlock = it.get(); - // std::cout << "\nProcessing BB: " << basicBlock->getName() << std::endl; - // printer.printBlock(basicBlock); // 打印基本块内容 - - auto &instrs = basicBlock->getInstructions(); - for (auto &instr : instrs) { - // std::cout << " Analyzing instruction: "; - // printer.printInst(instr.get()); - // std::cout << std::endl; - - if (instr->isAlloca()) { - if (!(isArr(instr.get()) || 
isGlobal(instr.get()))) { - // std::cout << " Found alloca: "; - // printer.printInst(instr.get()); - // std::cout << " -> Adding to allocBlocks" << std::endl; - - funcInfo->addValue2AllocBlocks(instr.get(), basicBlock); - } else { - // std::cout << " Skip array/global alloca: "; - // printer.printInst(instr.get()); - // std::cout << std::endl; - } - } - else if (instr->isStore()) { - auto val = instr->getOperand(1); - // std::cout << " Store target: "; - // printer.printInst(dynamic_cast(val)); - - if (!(isArr(val) || isGlobal(val))) { - // std::cout << " Adding store to defBlocks for value: "; - // printer.printInst(dynamic_cast(instr.get())); - // std::cout << std::endl; - // 将store的目标值添加到defBlocks中 - funcInfo->addValue2DefBlocks(val, basicBlock); - } else { - // std::cout << " Skip array/global store" << std::endl; - } - } - else if (instr->isLoad()) { - auto val = instr->getOperand(0); - // std::cout << " Load source: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << std::endl; - - if (!(isArr(val) || isGlobal(val))) { - // std::cout << " Adding load to useBlocks for value: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << std::endl; - - funcInfo->addValue2UseBlocks(val, basicBlock); - } else { - // std::cout << " Skip array/global load" << std::endl; - } - } - } - } - - // 打印分析结果 - // std::cout << "\nAnalysis results for function " << func->getName() << ":" << std::endl; - - // auto &allocMap = funcInfo->getValue2AllocBlocks(); - // std::cout << "AllocBlocks (" << allocMap.size() << "):" << std::endl; - // for (auto &[val, bb] : allocMap) { - // std::cout << " "; - // printer.printInst(dynamic_cast(val)); - // std::cout << " in BB: " << bb->getName() << std::endl; - // } - - // auto &defMap = funcInfo->getValue2DefBlocks(); - // std::cout << "DefBlocks (" << defMap.size() << "):" << std::endl; - // for (auto &[val, bbs] : defMap) { - // std::cout << " "; - // printer.printInst(dynamic_cast(val)); - // for (const auto &[bb, count] : 
bbs) { - // std::cout << " in BB: " << bb->getName() << " (count: " << count << ")"; - // } - // } - - // auto &useMap = funcInfo->getValue2UseBlocks(); - // std::cout << "UseBlocks (" << useMap.size() << "):" << std::endl; - // for (auto &[val, bbs] : useMap) { - // std::cout << " "; - // printer.printInst(dynamic_cast(val)); - // for (const auto &[bb, count] : bbs) { - // std::cout << " in BB: " << bb->getName() << " (count: " << count << ")"; - // } - // } - } - // std::cout << "===== End computeValue2Blocks =====" << std::endl; -} - - -/** - * @brief 级联关系的顺带消除,用于llvm mem2reg类预优化1 - * - * 采用队列进行模拟,从某种程度上来看其实可以看作是UD链的反向操作; - * - * @param [in] instr store指令使用的指令 - * @param [in] changed 不动点法的判断标准,地址传递 - * @param [in] func 指令所在函数 - * @param [in] block 指令所在基本块 - * @param [in] instrs 基本块所在指令集合,地址传递 - * @return 无返回值,但满足条件的情况下会对指令进行删除 - */ -auto Mem2Reg::cascade(Instruction *instr, bool &changed, Function *func, BasicBlock *block, - std::list> &instrs) -> void { - if (instr != nullptr) { - if (instr->isUnary() || instr->isBinary() || instr->isLoad()) { - std::queue toRemove; - toRemove.push(instr); - while (!toRemove.empty()) { - auto top = toRemove.front(); - toRemove.pop(); - auto operands = top->getOperands(); - for (const auto &operand : operands) { - auto elem = dynamic_cast(operand->getValue()); - if (elem != nullptr) { - if ((elem->isUnary() || elem->isBinary() || elem->isLoad()) && elem->getUses().size() == 1 && - elem->getUses().front()->getUser() == top) { - toRemove.push(elem); - } else if (elem->isAlloca()) { - // value2UseBlock中该block对应次数-1,如果该变量的该useblock中count减为0了,则意味着 - // 该block其他地方也没用到该alloc了,故从value2UseBlock中删除 - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - auto res = funcInfo->removeValue2UseBlock(elem, block); - // 只要有一次返回了true,就说明有变化 - if (res) { - changed = true; - } - } - } - } - auto tofind = - std::find_if(instrs.begin(), instrs.end(), [&top](const auto &instr) { return instr.get() == top; }); - 
assert(tofind != instrs.end()); - usedelete(tofind->get()); - instrs.erase(tofind); - } - } - } -} - -/** - * llvm mem2reg预优化1: 删除不含load的alloc和store - * - * 1. 删除不含load的alloc和store; - * 2. 删除store指令,之前的用于作store指令第0个操作数的那些级联指令就冗余了,也要删除; - * 3. 删除之后,可能有些变量的load使用恰好又没有了,因此再次从第一步开始循环,这里使用不动点法 - * - * 由于删除了级联关系,所以这里的方法有点儿激进; - * 同时也考虑了级联关系时如果调用了函数,可能会有side effect,所以没有删除调用函数的级联关系; - * 而且关于函数参数的alloca不会在指令中删除,也不会在value2Alloca中删除; - * 同样地,我们不考虑数组和global,不过这里的代码是基于value2blocks的,在value2blocks中已经考虑了,所以不用显式指明 - *= - */ -auto Mem2Reg::preOptimize1() -> void { - SysYPrinter printer(pModule); // 初始化打印机 - - auto &functions = pModule->getFunctions(); - // std::cout << "===== Start preOptimize1 =====" << std::endl; - - for (const auto &function : functions) { - auto func = function.second.get(); - // std::cout << "\nProcessing function: " << func->getName() << std::endl; - - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - if (!funcInfo) { - // std::cerr << "ERROR: No analysis info for function " << func->getName() << std::endl; - continue; - } - - auto &vToDefB = funcInfo->getValue2DefBlocks(); - auto &vToUseB = funcInfo->getValue2UseBlocks(); - auto &vToAllocB = funcInfo->getValue2AllocBlocks(); - - // 打印初始状态 - // std::cout << "Initial allocas: " << vToAllocB.size() << std::endl; - // for (auto &[val, bb] : vToAllocB) { - // std::cout << " Alloca: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << " in BB: " << bb->getName() << std::endl; - // } - - // 阶段1:删除无store的alloca - // std::cout << "\nPhase 1: Remove unused allocas" << std::endl; - for (auto iter = vToAllocB.begin(); iter != vToAllocB.end();) { - auto val = iter->first; - auto bb = iter->second; - - // std::cout << "Checking alloca: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << " in BB: " << bb->getName() << std::endl; - - // 如果该alloca没有对应的store指令,且不在函数参数中 - // 这里的vToDefB是value2DefBlocks,vToUseB是value2UseBlocks - - // 打印vToDefB - // std::cout << 
"DefBlocks (" << vToDefB.size() << "):" << std::endl; - // for (auto &[val, bbs] : vToDefB) { - // std::cout << " "; - // printer.printInst(dynamic_cast(val)); - // for (const auto &[bb, count] : bbs) { - // std::cout << " in BB: " << bb->getName() << " (count: " << count << ")" << std::endl; - // } - // } - // std::cout << vToDefB.count(val) << std::endl; - - if (vToDefB.count(val) == 0U && - std::find(func->getEntryBlock()->getArguments().begin(), - func->getEntryBlock()->getArguments().end(), - val) == func->getEntryBlock()->getArguments().end()) { - - // std::cout << " Removing unused alloca: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << std::endl; - - auto tofind = std::find_if(bb->getInstructions().begin(), - bb->getInstructions().end(), - [val](const auto &instr) { - return instr.get() == val; - }); - if (tofind == bb->getInstructions().end()) { - // std::cerr << "ERROR: Alloca not found in BB!" << std::endl; - ++iter; - continue; - } - - usedelete(tofind->get()); - bb->getInstructions().erase(tofind); - iter = vToAllocB.erase(iter); - } else { - ++iter; - } - } - - // 阶段2:删除无load的store - // std::cout << "\nPhase 2: Remove dead stores" << std::endl; - bool changed = true; - int iteration = 0; - - while (changed) { - changed = false; - iteration++; - // std::cout << "\nIteration " << iteration << std::endl; - - for (auto iter = vToDefB.begin(); iter != vToDefB.end();) { - auto val = iter->first; - - // std::cout << "Checking value: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << std::endl; - - if (vToUseB.count(val) == 0U) { - // std::cout << " Found dead store for value: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << std::endl; - - auto blocks = funcInfo->getDefBlocksByValue(val); - for (auto block : blocks) { - // std::cout << " Processing BB: " << block->getName() << std::endl; - // printer.printBlock(block); // 打印基本块内容 - - auto &instrs = block->getInstructions(); - for (auto it = instrs.begin(); it != 
instrs.end();) { - if ((*it)->isStore() && (*it)->getOperand(1) == val) { - // std::cout << " Removing store: "; - // printer.printInst(it->get()); - std::cout << std::endl; - - auto valUsedByStore = dynamic_cast((*it)->getOperand(0)); - usedelete(it->get()); - - if (valUsedByStore != nullptr && - valUsedByStore->getUses().size() == 1 && - valUsedByStore->getUses().front()->getUser() == (*it).get()) { - // std::cout << " Cascade deleting: "; - // printer.printInst(valUsedByStore); - // std::cout << std::endl; - - cascade(valUsedByStore, changed, func, block, instrs); - } - it = instrs.erase(it); - changed = true; - } else { - ++it; - } - } - } - - // 删除对应的alloca - if (std::find(func->getEntryBlock()->getArguments().begin(), - func->getEntryBlock()->getArguments().end(), - val) == func->getEntryBlock()->getArguments().end()) { - auto bb = funcInfo->getAllocBlockByValue(val); - if (bb != nullptr) { - // std::cout << " Removing alloca: "; - // printer.printInst(dynamic_cast(val)); - // std::cout << " in BB: " << bb->getName() << std::endl; - - funcInfo->removeValue2AllocBlock(val); - auto tofind = std::find_if(bb->getInstructions().begin(), - bb->getInstructions().end(), - [val](const auto &instr) { - return instr.get() == val; - }); - if (tofind != bb->getInstructions().end()) { - usedelete(tofind->get()); - bb->getInstructions().erase(tofind); - } else { - std::cerr << "ERROR: Alloca not found in BB!" << std::endl; - } - } - } - iter = vToDefB.erase(iter); - } else { - ++iter; - } - } - } - } - // std::cout << "===== End preOptimize1 =====" << std::endl; -} - -/** - * llvm mem2reg预优化2: 针对某个变量的Defblocks只有一个块的情况 - * - * 1. 该基本块最后一次对该变量的store指令后的所有对该变量的load指令都可以替换为该基本块最后一次store指令的第0个操作数; - * 2. 以该基本块为必经结点的结点集合中的对该变量的load指令都可以替换为该基本块最后一次对该变量的store指令的第0个操作数; - * 3. - * 如果对该变量的所有load均替换掉了,删除该基本块中最后一次store指令,如果这个store指令是唯一的define,那么再删除alloca指令(不删除参数的alloca); - * 4. 
- * 如果对该value的所有load都替换掉了,对于该变量剩下还有store的话,就转换成了preOptimize1的情况,再调用preOptimize1进行删除; - * - * 同样不考虑数组和全局变量,因为这些变量不会被mem2reg优化,在value2blocks中已经考虑了,所以不用显式指明; - * 替换的操作采用了UD链进行简化和效率的提升 - * - */ -auto Mem2Reg::preOptimize2() -> void { - auto &functions = pModule->getFunctions(); - for (const auto &function : functions) { - auto func = function.second.get(); - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - auto values = funcInfo->getValuesOfDefBlock(); - for (auto val : values) { - auto blocks = funcInfo->getDefBlocksByValue(val); - // 该val只有一个defining block - if (blocks.size() == 1) { - auto block = *blocks.begin(); - auto &instrs = block->getInstructions(); - auto rit = std::find_if(instrs.rbegin(), instrs.rend(), - [val](const auto &instr) { return instr->isStore() && instr->getOperand(1) == val; }); - // 注意reverse_iterator求base后是指向下一个指令,因此要减一才是原来的指令 - assert(rit != instrs.rend()); - auto it = --rit.base(); - auto propogationVal = (*it)->getOperand(0); - // 其实该块中it后对该val的load指令也可以替换掉了 - for (auto curit = std::next(it); curit != instrs.end();) { - if ((*curit)->isLoad() && (*curit)->getOperand(0) == val) { - curit->get()->replaceAllUsesWith(propogationVal); - usedelete(curit->get()); - curit = instrs.erase(curit); - funcInfo->removeValue2UseBlock(val, block); - } else { - ++curit; - } - } - // 在支配树后继结点中替换load指令的操作数 - BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(block); - std::vector blkchildren; - // 获取该块的支配树后继结点 - std::queue q; - auto sdoms = blockInfo->getSdoms(); - for (auto sdom : sdoms) { - q.push(sdom); - blkchildren.push_back(sdom); - } - while (!q.empty()) { - auto blk = q.front(); - q.pop(); - BlockAnalysisInfo* blkInfo = controlFlowAnalysis->getBlockAnalysisInfo(blk); - for (auto sdom : blkInfo->getSdoms()) { - q.push(sdom); - blkchildren.push_back(sdom); - } - } - for (auto child : blkchildren) { - auto &childInstrs = child->getInstructions(); - for (auto childIter = childInstrs.begin(); 
childIter != childInstrs.end();) { - if ((*childIter)->isLoad() && (*childIter)->getOperand(0) == val) { - childIter->get()->replaceAllUsesWith(propogationVal); - usedelete(childIter->get()); - childIter = childInstrs.erase(childIter); - funcInfo->removeValue2UseBlock(val, child); - } else { - ++childIter; - } - } - } - // 如果对该val的所有load均替换掉了,那么对于该val的defining block中的最后一个define也可以删除了 - // 同时该块中前面对于该val的define也变成死代码了,可调用preOptimize1进行删除 - if (funcInfo->getUseBlocksByValue(val).empty()) { - usedelete(it->get()); - instrs.erase(it); - auto change = funcInfo->removeValue2DefBlock(val, block); - if (change) { - // 如果define是唯一的,且不是函数参数的alloca,直接删alloca - if (std::find(func->getEntryBlock()->getArguments().begin(), func->getEntryBlock()->getArguments().end(), - val) == func->getEntryBlock()->getArguments().end()) { - auto bb = funcInfo->getAllocBlockByValue(val); - assert(bb != nullptr); - auto tofind = std::find_if(bb->getInstructions().begin(), bb->getInstructions().end(), - [val](const auto &instr) { return instr.get() == val; }); - usedelete(tofind->get()); - bb->getInstructions().erase(tofind); - funcInfo->removeValue2AllocBlock(val); - } - } else { - // 如果该变量还有其他的define,那么前面的define也变成死代码了 - assert(!funcInfo->getDefBlocksByValue(val).empty()); - assert(funcInfo->getUseBlocksByValue(val).empty()); - preOptimize1(); - } - } - } - } - } -} - -/** - * @brief llvm mem2reg类预优化3:针对某个变量的所有读写都在同一个块中的情况 - * - * 1. 将每一个load替换成前一个store的值,并删除该load; - * 2. 如果在load前没有对该变量的store,则不删除该load; - * 3. 
如果一个store后没有任何对改变量的load,则删除该store; - * - * @note 额外说明:第二点不用显式处理,因为我们的方法是从找到第一个store开始; - * 第三点其实可以更激进一步地理解,即每次替换了load之后,它对应地那个store也可以删除了,同时注意这里不要使用preoptimize1进行处理,因为他们的级联关系是有用的:即用来求load的替换值; - * 同样地,我们这里不考虑数组和全局变量,因为这些变量不会被mem2reg优化,不过这里在计算value2DefBlocks时已经跳过了,所以不需要再显式处理了; - * 替换的操作采用了UD链进行简化和效率的提升 - * - * @param [in] void - * @return 无返回值,但满足条件的情况下会对指令的操作数进行替换以及对指令进行删除 - */ -auto Mem2Reg::preOptimize3() -> void { - auto &functions = pModule->getFunctions(); - for (const auto &function : functions) { - auto func = function.second.get(); - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - auto values = funcInfo->getValuesOfDefBlock(); - for (auto val : values) { - auto sblocks = funcInfo->getDefBlocksByValue(val); - auto lblocks = funcInfo->getUseBlocksByValue(val); - if (sblocks.size() == 1 && lblocks.size() == 1 && *sblocks.begin() == *lblocks.begin()) { - auto block = *sblocks.begin(); - auto &instrs = block->getInstructions(); - auto it = std::find_if(instrs.begin(), instrs.end(), - [val](const auto &instr) { return instr->isStore() && instr->getOperand(1) == val; }); - while (it != instrs.end()) { - auto propogationVal = (*it)->getOperand(0); - auto last = std::find_if(std::next(it), instrs.end(), [val](const auto &instr) { - return instr->isStore() && instr->getOperand(1) == val; - }); - for (auto curit = std::next(it); curit != last;) { - if ((*curit)->isLoad() && (*curit)->getOperand(0) == val) { - curit->get()->replaceAllUsesWith(propogationVal); - usedelete(curit->get()); - curit = instrs.erase(curit); - funcInfo->removeValue2UseBlock(val, block); - } else { - ++curit; - } - } - // 替换了load之后,它对应地那个store也可以删除了 - if (!(std::find_if(func->getEntryBlock()->getArguments().begin(), func->getEntryBlock()->getArguments().end(), - [val](const auto &instr) { return instr == val; }) != - func->getEntryBlock()->getArguments().end()) && - last == instrs.end()) { - usedelete(it->get()); - it = instrs.erase(it); - if 
(funcInfo->removeValue2DefBlock(val, block)) { - auto bb = funcInfo->getAllocBlockByValue(val); - if (bb != nullptr) { - auto tofind = std::find_if(bb->getInstructions().begin(), bb->getInstructions().end(), - [val](const auto &instr) { return instr.get() == val; }); - usedelete(tofind->get()); - bb->getInstructions().erase(tofind); - funcInfo->removeValue2AllocBlock(val); - } - } - } - it = last; - } - } - } - } -} - -/** - * 为所有变量的定义块集合的迭代支配边界插入phi结点 - * - * insertPhi是mem2reg的核心之一,这里是对所有变量的迭代支配边界的phi结点插入,无参数也无返回值; - * 同样跳过对数组和全局变量的处理,因为这些变量不会被mem2reg优化,刚好这里在计算value2DefBlocks时已经跳过了,所以不需要再显式处理了; - * 同时我们进行了剪枝处理,只有在基本块入口活跃的变量,才插入phi函数 - * - */ -auto Mem2Reg::insertPhi() -> void { - auto &functions = pModule->getFunctions(); - for (const auto &function : functions) { - auto func = function.second.get(); - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - const auto &vToDefB = funcInfo->getValue2DefBlocks(); - for (const auto &map_pair : vToDefB) { - // 首先为每个变量找到迭代支配边界 - auto val = map_pair.first; - auto blocks = funcInfo->getDefBlocksByValue(val); - auto itDFs = computeIterDf(blocks); - // 然后在每个变量相应的迭代支配边界上插入phi结点 - for (auto basicBlock : itDFs) { - const auto &actiTable = activeVarAnalysis->getActiveTable(); - auto dval = dynamic_cast(val); - // 只有在基本块入口活跃的变量,才插入phi函数 - if (actiTable.at(basicBlock).front().count(dval) != 0U) { - pBuilder->createPhiInst(val->getType(), val, basicBlock); - } - } - } - } -} - -/** - * 重命名 - * - * 重命名是mem2reg的核心之二,这里是对单个块的重命名,递归实现 - * 同样跳过对数组和全局变量的处理,因为这些变量不会被mem2reg优化 - * - */ -auto Mem2Reg::rename(BasicBlock *block, std::unordered_map &count, - std::unordered_map> &stacks) -> void { - auto &instrs = block->getInstructions(); - std::unordered_map valPop; - // 第一大步:对块中的所有指令遍历处理 - for (auto iter = instrs.begin(); iter != instrs.end();) { - auto instr = iter->get(); - // 对于load指令,变量用最新的那个 - if (instr->isLoad()) { - auto val = instr->getOperand(0); - if (!(isArr(val) || isGlobal(val))) { - if 
(!stacks[val].empty()) { - instr->replaceOperand(0, stacks[val].top()); - } - } - } - // 然后对于define的情况,看alloca、store和phi指令 - if (instr->isDefine()) { - if (instr->isAlloca()) { - // alloca指令名字不改了,命名就按x,x_1,x_2...来就行 - auto val = instr; - if (!(isArr(val) || isGlobal(val))) { - ++valPop[val]; - stacks[val].push(val); - ++count[val]; - } - } else if (instr->isPhi()) { - // Phi指令也是一条特殊的define指令 - auto val = dynamic_cast(instr)->getMapVal(); - if (!(isArr(val) || isGlobal(val))) { - auto i = count[val]; - if (i == 0) { - // 对还未alloca就有phi的指令的处理,直接删除 - usedelete(iter->get()); - iter = instrs.erase(iter); - continue; - } - auto newname = dynamic_cast(val)->getName() + "_" + std::to_string(i); - auto newalloca = pBuilder->createAllocaInstWithoutInsert(val->getType(), {}, block, newname); - FunctionAnalysisInfo* ParentfuncInfo = controlFlowAnalysis->getFunctionAnalysisInfo(block->getParent()); - ParentfuncInfo->addIndirectAlloca(newalloca); - instr->replaceOperand(0, newalloca); - ++valPop[val]; - stacks[val].push(newalloca); - ++count[val]; - } - } else { - // store指令看operand的名字,我们的实现是规定变量在operand的第二位,用一个新的alloca x_i代替 - auto val = instr->getOperand(1); - if (!(isArr(val) || isGlobal(val))) { - auto i = count[val]; - auto newname = dynamic_cast(val)->getName() + "_" + std::to_string(i); - auto newalloca = pBuilder->createAllocaInstWithoutInsert(val->getType(), {}, block, newname); - FunctionAnalysisInfo* ParentfuncInfo = controlFlowAnalysis->getFunctionAnalysisInfo(block->getParent()); - ParentfuncInfo->addIndirectAlloca(newalloca); - // block->getParent()->addIndirectAlloca(newalloca); - instr->replaceOperand(1, newalloca); - ++valPop[val]; - stacks[val].push(newalloca); - ++count[val]; - } - } - } - ++iter; - } - // 第二大步:把所有CFG中的该块的successor的phi指令的相应operand确定 - for (auto succ : block->getSuccessors()) { - auto position = getPredIndex(block, succ); - for (auto &instr : succ->getInstructions()) { - if (instr->isPhi()) { - auto val = dynamic_cast(instr.get())->getMapVal(); 
- if (!stacks[val].empty()) { - instr->replaceOperand(position + 1, stacks[val].top()); - } - } else { - // phi指令是添加在块的最前面的,因此过了之后就不会有phi了,直接break - break; - } - } - } - // 第三大步:递归支配树的后继,支配树才能表示define-use关系 - BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(block); - for (auto sdom : blockInfo->getSdoms()) { - rename(sdom, count, stacks); - } - // 第四大步:遍历块中的所有指令,如果涉及到define,就弹栈,这一步是必要的,可以从递归的整体性来思考原因 - // 注意这里count没清理,因为平级之间计数仍然是一直增加的,但是stack要清理,因为define-use关系来自直接 - // 支配结点而不是平级之间,不清理栈会被污染 - // 提前优化:知道变量对应的要弹栈的次数就可以了,没必要遍历所有instr. - for (auto val_pair : valPop) { - auto val = val_pair.first; - for (int i = 0; i < val_pair.second; ++i) { - stacks[val].pop(); - } - } -} - -/** - * 重命名所有块 - * - * 调用rename,自上而下实现所有rename - * - */ -auto Mem2Reg::renameAll() -> void { - auto &functions = pModule->getFunctions(); - for (const auto &function : functions) { - auto func = function.second.get(); - // 对于每个function都要SSA化,所以count和stacks定义在这并初始化 - std::unordered_map count; - std::unordered_map> stacks; - FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func); - for (const auto &map_pair : funcInfo->getValue2DefBlocks()) { - auto val = map_pair.first; - count[val] = 0; - } - rename(func->getEntryBlock(), count, stacks); - } -} - -/** - * mem2reg,对外的接口 - * - * 静态单一赋值 + mem2reg等pass的逻辑组合 - * - */ -auto Mem2Reg::mem2regPipeline() -> void { - // 首先进行mem2reg的前置分析 - controlFlowAnalysis->clear(); - controlFlowAnalysis->runControlFlowAnalysis(); - // 活跃变量分析 - activeVarAnalysis->clear(); - dataFlowAnalysisUtils.addBackwardAnalyzer(activeVarAnalysis); - dataFlowAnalysisUtils.backwardAnalyze(pModule); - - // 计算所有valueToBlocks的定义映射 - computeValue2Blocks(); - // SysYPrinter printer(pModule); - // 参考llvm的mem2reg遍,在插入phi结点之前,先做些优化 - preOptimize1(); - // printer.printIR(); - preOptimize2(); - // printer.printIR(); - // 优化三 可能会针对局部变量优化而删除整个块的alloca/store - preOptimize3(); - //再进行活跃变量分析 - // 报错? 
- - // printer.printIR(); - dataFlowAnalysisUtils.backwardAnalyze(pModule); - // 为所有变量插入phi结点 - insertPhi(); - // 重命名 - renameAll(); -} - -/** - * 计算块n是块s的第几个前驱 - * - * helperfunction,没有返回值,但是会将dom和other的交集赋值给dom - * - */ -auto Mem2Reg::getPredIndex(BasicBlock *n, BasicBlock *s) -> int { - int index = 0; - for (auto elem : s->getPredecessors()) { - if (elem == n) { - break; - } - ++index; - } - assert(index < static_cast(s->getPredecessors().size()) && "n is not a predecessor of s."); - return index; -} - -/** - * 判断一个value是不是全局变量 - */ -auto Mem2Reg::isGlobal(Value *val) -> bool { - auto gval = dynamic_cast(val); - return gval != nullptr; -} - -/** - * 判断一个value是不是数组 - */ -auto Mem2Reg::isArr(Value *val) -> bool { - auto aval = dynamic_cast(val); - return aval != nullptr && aval->getNumDims() != 0; -} - -/** - * 删除一个指令的operand对应的value的该条use - */ -auto Mem2Reg::usedelete(Instruction *instr) -> void { - for (auto &use : instr->getOperands()) { - auto val = use->getValue(); - val->removeUse(use); - } -} -} // namespace sysy diff --git a/src/Pass.cpp b/src/Pass.cpp new file mode 100644 index 0000000..bba8e36 --- /dev/null +++ b/src/Pass.cpp @@ -0,0 +1,175 @@ +#include "Dom.h" +#include "Liveness.h" +#include "SysYIRCFGOpt.h" +#include "SysYIRPrinter.h" +#include "DCE.h" +#include "Pass.h" +#include +#include +#include +#include +#include +#include + +extern int DEBUG; // 全局调试标志 +namespace sysy { + +// ====================================================================== +// 封装优化流程的函数:包含Pass注册和迭代运行逻辑 +// ====================================================================== + +void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR, int optLevel) { + if (DEBUG) std::cout << "--- Starting Middle-End Optimizations (Level -O" << optLevel << ") ---\n"; + + /* + 中端开发框架基本流程: + 1) 分析pass + 1. 实现分析pass并引入Pass.cpp + 2. 注册分析pass + 2) 优化pass + 1. 实现优化pass并引入Pass.cpp + 2. 注册优化pass + 3. 
添加优化passid + */ + // 注册分析遍 + registerAnalysisPass(); + registerAnalysisPass(); + + // 注册优化遍 + registerOptimizationPass(); + registerOptimizationPass(); + registerOptimizationPass(); + + registerOptimizationPass(builderIR); + registerOptimizationPass(builderIR); + registerOptimizationPass(builderIR); + + if (optLevel >= 1) { + //经过设计安排优化遍的执行顺序以及执行逻辑 + if (DEBUG) std::cout << "Applying -O1 optimizations.\n"; + if (DEBUG) std::cout << "--- Running custom optimization sequence ---\n"; + + this->clearPasses(); + this->addPass(&SysYDelInstAfterBrPass::ID); + this->addPass(&SysYDelNoPreBLockPass::ID); + this->addPass(&SysYBlockMergePass::ID); + this->addPass(&SysYDelEmptyBlockPass::ID); + this->addPass(&SysYCondBr2BrPass::ID); + this->addPass(&SysYAddReturnPass::ID); + this->run(); + + this->clearPasses(); + this->addPass(&DCE::ID); + this->run(); + + if (DEBUG) std::cout << "--- Custom optimization sequence finished ---\n"; + } + + // 2. 创建遍管理器 + // 3. 根据优化级别添加不同的优化遍 + // TODO : 根据 optLevel 添加不同的优化遍 + // 讨论 是不动点迭代进行优化遍还是手动客制化优化遍的顺序? 
+ + + if (DEBUG) { + std::cout << "=== Final IR After Middle-End Optimizations (Level -O" << optLevel << ") ===\n"; + SysYPrinter printer(moduleIR); + printer.printIR(); + } +} + +void PassManager::clearPasses() { + passes.clear(); +} + +void PassManager::addPass(void *passID) { + + PassRegistry ®istry = PassRegistry::getPassRegistry(); + std::unique_ptr P = registry.createPass(passID); + if (!P) { + // Error: Pass not found or failed to create + return; + } + + passes.push_back(std::move(P)); +} + +// 运行所有注册的遍 +bool PassManager::run() { + bool changed = false; + for (const auto &p : passes) { + bool passChanged = false; // 记录当前遍是否修改了 IR + + // 处理优化遍的分析依赖和失效 + if (p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + + // PassManager 不显式运行分析依赖。 + // 而是优化遍在 runOnFunction 内部通过 AnalysisManager.getAnalysisResult 按需请求。 + } + + if (p->getGranularity() == Pass::Granularity::Module) { + passChanged = p->runOnModule(pmodule, analysisManager); + } else if (p->getGranularity() == Pass::Granularity::Function) { + for (auto &funcPair : pmodule->getFunctions()) { + Function *F = funcPair.second.get(); + passChanged = p->runOnFunction(F, analysisManager) || passChanged; + + if (passChanged && p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + for (void *invalidationID : analysisInvalidations) { + analysisManager.invalidateAnalysis(invalidationID, F); + } + } + } + } else if (p->getGranularity() == Pass::Granularity::BasicBlock) { + for (auto &funcPair : pmodule->getFunctions()) { + Function *F = funcPair.second.get(); + for (auto &bbPtr : funcPair.second->getBasicBlocks()) { + passChanged = 
p->runOnBasicBlock(bbPtr.get(), analysisManager) || passChanged; + + if (passChanged && p->getPassKind() == Pass::PassKind::Optimization) { + OptimizationPass *optPass = static_cast(p.get()); + std::set analysisDependencies; + std::set analysisInvalidations; + optPass->getAnalysisUsage(analysisDependencies, analysisInvalidations); + for (void *invalidationID : analysisInvalidations) { + analysisManager.invalidateAnalysis(invalidationID, F); + } + } + } + } + } + changed = changed || passChanged; + } + return changed; + +} + + +template void registerAnalysisPass() { + PassRegistry::getPassRegistry().registerPass(&AnalysisPassType::ID, + []() { return std::make_unique(); }); +} + +template ::value, int>::type> +void registerOptimizationPass(IRBuilder* builder) { + PassRegistry::getPassRegistry().registerPass(&OptimizationPassType::ID, + [builder]() { return std::make_unique(builder); }); +} + +template ::value, int>::type> +void registerOptimizationPass() { + PassRegistry::getPassRegistry().registerPass(&OptimizationPassType::ID, + []() { return std::make_unique(); }); +} + +} // namespace sysy \ No newline at end of file diff --git a/src/RISCv64AsmPrinter.cpp b/src/RISCv64AsmPrinter.cpp index 1995024..65dbe5d 100644 --- a/src/RISCv64AsmPrinter.cpp +++ b/src/RISCv64AsmPrinter.cpp @@ -31,6 +31,8 @@ void RISCv64AsmPrinter::run(std::ostream& os, bool debug) { } } +// 在 RISCv64AsmPrinter.cpp 文件中 + void RISCv64AsmPrinter::printPrologue() { StackFrameInfo& frame_info = MFunc->getFrameInfo(); // 序言需要为保存ra和s0预留16字节 @@ -44,12 +46,16 @@ void RISCv64AsmPrinter::printPrologue() { *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; *OS << " addi s0, sp, " << aligned_stack_size << "\n"; } - - // 忠实还原保存函数入口参数的逻辑 + + // 为函数参数分配寄存器 Function* F = MFunc->getFunc(); if (F && F->getEntryBlock()) { int arg_idx = 0; RISCv64ISel* isel = MFunc->getISel(); + + // 获取函数所有参数的类型列表 + auto param_types = F->getParamTypes(); + for (AllocaInst* alloca_for_param : 
F->getEntryBlock()->getArguments()) { if (arg_idx >= 8) break; @@ -57,7 +63,25 @@ void RISCv64AsmPrinter::printPrologue() { if (frame_info.alloca_offsets.count(vreg)) { int offset = frame_info.alloca_offsets.at(vreg); auto arg_reg = static_cast(static_cast(PhysicalReg::A0) + arg_idx); - *OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n"; + + // 1. 获取当前参数的真实类型 + // 注意:F->getParamTypes() 返回的是一个 range-based view,需要转换为vector或直接使用 + Type* current_param_type = nullptr; + int temp_idx = 0; + for(auto p_type : param_types) { + if (temp_idx == arg_idx) { + current_param_type = p_type; + break; + } + temp_idx++; + } + assert(current_param_type && "Could not find parameter type."); + + // 2. 根据类型决定使用 "sw" 还是 "sd" + const char* store_op = current_param_type->isPointer() ? "sd" : "sw"; + + // 3. 打印正确的存储指令 + *OS << " " << store_op << " " << regToString(arg_reg) << ", " << offset << "(s0)\n"; } arg_idx++; } @@ -133,17 +157,23 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) { case RVOpcodes::SNEZ: *OS << "snez "; break; case RVOpcodes::CALL: *OS << "call "; break; case RVOpcodes::LABEL: - // printOperand(instr->getOperands()[0].get()); - // *OS << ":"; break; - case RVOpcodes::FRAME_LOAD: + case RVOpcodes::FRAME_LOAD_W: // It should have been eliminated by RegAlloc if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); - *OS << "frame_load "; break; - case RVOpcodes::FRAME_STORE: + *OS << "frame_load_w "; break; + case RVOpcodes::FRAME_LOAD_D: // It should have been eliminated by RegAlloc if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); - *OS << "frame_store "; break; + *OS << "frame_load_d "; break; + case RVOpcodes::FRAME_STORE_W: + // It should have been eliminated by RegAlloc + if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); + *OS << "frame_store_w "; break; + case 
RVOpcodes::FRAME_STORE_D: + // It should have been eliminated by RegAlloc + if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); + *OS << "frame_store_d "; break; case RVOpcodes::FRAME_ADDR: // It should have been eliminated by RegAlloc if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter"); diff --git a/src/RISCv64Backend.cpp b/src/RISCv64Backend.cpp index 4f45fde..c429a63 100644 --- a/src/RISCv64Backend.cpp +++ b/src/RISCv64Backend.cpp @@ -85,7 +85,7 @@ std::string RISCv64CodeGen::function_gen(Function* func) { std::stringstream ss; RISCv64AsmPrinter printer(mfunc.get()); printer.run(ss); - if (DEBUG) ss << ss1.str(); // 将指令选择阶段的结果也包含在最终输出中 + if (DEBUG) ss << "\n" << ss1.str(); // 将指令选择阶段的结果也包含在最终输出中 return ss.str(); } diff --git a/src/RISCv64ISel.cpp b/src/RISCv64ISel.cpp index 2ab47b5..db07908 100644 --- a/src/RISCv64ISel.cpp +++ b/src/RISCv64ISel.cpp @@ -10,7 +10,7 @@ namespace sysy { // DAG节点定义 (内部实现) struct RISCv64ISel::DAGNode { - enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET }; + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET, GET_ELEMENT_PTR}; NodeKind kind; Value* value = nullptr; std::vector operands; @@ -149,38 +149,51 @@ void RISCv64ISel::selectNode(DAGNode* node) { auto dest_vreg = getVReg(node->value); Value* ptr_val = node->operands[0]->value; - // [V1设计保留] 对于从栈变量加载,继续使用伪指令 FRAME_LOAD。 - // 这种设计将栈帧布局的具体计算推迟到后续的 `eliminateFrameIndices` 阶段,保持了模块化。 + // --- 修改点 --- + // 1. 获取加载结果的类型 (即这个LOAD指令自身的类型) + Type* loaded_type = node->value->getType(); + + // 2. 根据类型选择正确的伪指令或真实指令操作码 + RVOpcodes frame_opcode = loaded_type->isPointer() ? RVOpcodes::FRAME_LOAD_D : RVOpcodes::FRAME_LOAD_W; + RVOpcodes real_opcode = loaded_type->isPointer() ? 
RVOpcodes::LD : RVOpcodes::LW; + + if (auto alloca = dynamic_cast(ptr_val)) { - auto instr = std::make_unique(RVOpcodes::FRAME_LOAD); + // 3. 创建使用新的、区分宽度的伪指令 + auto instr = std::make_unique(frame_opcode); instr->addOperand(std::make_unique(dest_vreg)); instr->addOperand(std::make_unique(getVReg(alloca))); CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { - // 对于全局变量,先用 la 加载其地址,再用 lw 加载其值。 + // 对于全局变量,先用 la 加载其地址 auto addr_vreg = getNewVReg(); auto la = std::make_unique(RVOpcodes::LA); la->addOperand(std::make_unique(addr_vreg)); la->addOperand(std::make_unique(global->getName())); CurMBB->addInstruction(std::move(la)); - auto lw = std::make_unique(RVOpcodes::LW); - lw->addOperand(std::make_unique(dest_vreg)); - lw->addOperand(std::make_unique( + // 然后根据类型使用 ld 或 lw 加载其值 + auto load_instr = std::make_unique(real_opcode); + load_instr->addOperand(std::make_unique(dest_vreg)); + load_instr->addOperand(std::make_unique( std::make_unique(addr_vreg), std::make_unique(0) )); - CurMBB->addInstruction(std::move(lw)); + CurMBB->addInstruction(std::move(load_instr)); + } else { - // 对于已经在虚拟寄存器中的指针地址,直接通过该地址加载。 + // 对于已经在虚拟寄存器中的指针地址,直接通过该地址加载 auto ptr_vreg = getVReg(ptr_val); - auto lw = std::make_unique(RVOpcodes::LW); - lw->addOperand(std::make_unique(dest_vreg)); - lw->addOperand(std::make_unique( + + // 根据类型使用 ld 或 lw + auto load_instr = std::make_unique(real_opcode); + load_instr->addOperand(std::make_unique(dest_vreg)); + load_instr->addOperand(std::make_unique( std::make_unique(ptr_vreg), std::make_unique(0) )); - CurMBB->addInstruction(std::move(lw)); + CurMBB->addInstruction(std::move(load_instr)); } break; } @@ -189,13 +202,8 @@ void RISCv64ISel::selectNode(DAGNode* node) { Value* val_to_store = node->operands[0]->value; Value* ptr_val = node->operands[1]->value; - // [V2优点] 在STORE节点内部负责加载作为源的常量。 - // 如果要存储的值是一个常量,就在这里生成 `li` 指令加载它。 + // 如果要存储的值是一个常量,就在这里生成 `li` 指令加载它 if (auto val_const = dynamic_cast(val_to_store)) { - if 
(DEBUG) { - std::cout << "[DEBUG] selectNode-BINARY: Found constant operand with value " << val_const->getInt() - << ". Generating LI instruction." << std::endl; - } auto li = std::make_unique(RVOpcodes::LI); li->addOperand(std::make_unique(getVReg(val_const))); li->addOperand(std::make_unique(val_const->getInt())); @@ -203,37 +211,50 @@ void RISCv64ISel::selectNode(DAGNode* node) { } auto val_vreg = getVReg(val_to_store); - // [V1设计保留] 同样,对于向栈变量的存储,使用 FRAME_STORE 伪指令。 + // --- 修改点 --- + // 1. 获取被存储的值的类型 + Type* stored_type = val_to_store->getType(); + + // 2. 根据类型选择正确的伪指令或真实指令操作码 + RVOpcodes frame_opcode = stored_type->isPointer() ? RVOpcodes::FRAME_STORE_D : RVOpcodes::FRAME_STORE_W; + RVOpcodes real_opcode = stored_type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW; + if (auto alloca = dynamic_cast(ptr_val)) { - auto instr = std::make_unique(RVOpcodes::FRAME_STORE); + // 3. 创建使用新的、区分宽度的伪指令 + auto instr = std::make_unique(frame_opcode); instr->addOperand(std::make_unique(val_vreg)); instr->addOperand(std::make_unique(getVReg(alloca))); CurMBB->addInstruction(std::move(instr)); + } else if (auto global = dynamic_cast(ptr_val)) { - // 向全局变量存储。 + // 向全局变量存储 auto addr_vreg = getNewVReg(); auto la = std::make_unique(RVOpcodes::LA); la->addOperand(std::make_unique(addr_vreg)); la->addOperand(std::make_unique(global->getName())); CurMBB->addInstruction(std::move(la)); - auto sw = std::make_unique(RVOpcodes::SW); - sw->addOperand(std::make_unique(val_vreg)); - sw->addOperand(std::make_unique( + // 根据类型使用 sd 或 sw + auto store_instr = std::make_unique(real_opcode); + store_instr->addOperand(std::make_unique(val_vreg)); + store_instr->addOperand(std::make_unique( std::make_unique(addr_vreg), std::make_unique(0) )); - CurMBB->addInstruction(std::move(sw)); + CurMBB->addInstruction(std::move(store_instr)); + } else { - // 向一个指针(存储在虚拟寄存器中)指向的地址存储。 + // 向一个指针(存储在虚拟寄存器中)指向的地址存储 auto ptr_vreg = getVReg(ptr_val); - auto sw = std::make_unique(RVOpcodes::SW); - 
sw->addOperand(std::make_unique(val_vreg)); - sw->addOperand(std::make_unique( + + // 根据类型使用 sd 或 sw + auto store_instr = std::make_unique(real_opcode); + store_instr->addOperand(std::make_unique(val_vreg)); + store_instr->addOperand(std::make_unique( std::make_unique(ptr_vreg), std::make_unique(0) )); - CurMBB->addInstruction(std::move(sw)); + CurMBB->addInstruction(std::move(store_instr)); } break; } @@ -792,6 +813,108 @@ void RISCv64ISel::selectNode(DAGNode* node) { break; } + case DAGNode::GET_ELEMENT_PTR: { + auto gep = dynamic_cast(node->value); + auto result_vreg = getVReg(gep); + + // --- Step 1: 获取基地址 (此部分逻辑正确,保持不变) --- + auto base_ptr_node = node->operands[0]; + auto current_addr_vreg = getNewVReg(); + + if (auto alloca_base = dynamic_cast(base_ptr_node->value)) { + auto frame_addr_instr = std::make_unique(RVOpcodes::FRAME_ADDR); + frame_addr_instr->addOperand(std::make_unique(current_addr_vreg)); + frame_addr_instr->addOperand(std::make_unique(getVReg(alloca_base))); + CurMBB->addInstruction(std::move(frame_addr_instr)); + } else if (auto global_base = dynamic_cast(base_ptr_node->value)) { + auto la_instr = std::make_unique(RVOpcodes::LA); + la_instr->addOperand(std::make_unique(current_addr_vreg)); + la_instr->addOperand(std::make_unique(global_base->getName())); + CurMBB->addInstruction(std::move(la_instr)); + } else { + auto base_vreg = getVReg(base_ptr_node->value); + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(current_addr_vreg)); + mv->addOperand(std::make_unique(base_vreg)); + CurMBB->addInstruction(std::move(mv)); + } + + // --- Step 2: [最终权威版] 遵循LLVM GEP语义迭代计算地址 --- + + // 初始被索引的类型,是基指针指向的那个类型 (例如, [2 x i32]) + Type* current_type = gep->getBasePointer()->getType()->as()->getBaseType(); + + // 迭代处理 GEP 的每一个索引 + for (size_t i = 0; i < gep->getNumIndices(); ++i) { + Value* indexValue = gep->getIndex(i); + + // GEP的第一个索引以整个 `current_type` 的大小为步长。 + // 后续的索引则以 `current_type` 的元素大小为步长。 + // 这一步是计算地址偏移的关键。 + unsigned 
stride = getTypeSizeInBytes(current_type); + + // 如果步长为0(例如对一个void类型或空结构体索引),则不产生任何偏移 + if (stride != 0) { + // --- 为当前索引和步长生成偏移计算指令 --- + auto offset_vreg = getNewVReg(); + auto index_vreg = getVReg(indexValue); + + // 如果索引是常量,先用 LI 指令加载到虚拟寄存器 + if (auto const_index = dynamic_cast(indexValue)) { + auto li = std::make_unique(RVOpcodes::LI); + li->addOperand(std::make_unique(index_vreg)); + li->addOperand(std::make_unique(const_index->getInt())); + CurMBB->addInstruction(std::move(li)); + } + + // 优化:如果步长是1,可以直接移动(MV)作为偏移量,无需乘法 + if (stride == 1) { + auto mv = std::make_unique(RVOpcodes::MV); + mv->addOperand(std::make_unique(offset_vreg)); + mv->addOperand(std::make_unique(index_vreg)); + CurMBB->addInstruction(std::move(mv)); + } else { + // 步长不为1,需要生成乘法指令 + auto size_vreg = getNewVReg(); + auto li_size = std::make_unique(RVOpcodes::LI); + li_size->addOperand(std::make_unique(size_vreg)); + li_size->addOperand(std::make_unique(stride)); + CurMBB->addInstruction(std::move(li_size)); + + auto mul = std::make_unique(RVOpcodes::MULW); + mul->addOperand(std::make_unique(offset_vreg)); + mul->addOperand(std::make_unique(index_vreg)); + mul->addOperand(std::make_unique(size_vreg)); + CurMBB->addInstruction(std::move(mul)); + } + + // 将计算出的偏移量累加到当前地址上 + auto add = std::make_unique(RVOpcodes::ADD); + add->addOperand(std::make_unique(current_addr_vreg)); + add->addOperand(std::make_unique(current_addr_vreg)); + add->addOperand(std::make_unique(offset_vreg)); + CurMBB->addInstruction(std::move(add)); + } + + // --- 为下一次迭代更新类型:深入一层 --- + if (auto array_type = current_type->as()) { + current_type = array_type->getElementType(); + } else if (auto ptr_type = current_type->as()) { + // 这种情况不应该在第二次迭代后发生,但为了逻辑健壮性保留 + current_type = ptr_type->getBaseType(); + } + // 如果`current_type`已经是i32等基本类型,它会保持不变, + // 但下一次循环如果还有索引,`getTypeSizeInBytes(i32)`仍然能正确计算步长。 + } + + // --- Step 3: 将最终计算出的地址存入GEP的目标虚拟寄存器 (保持不变) --- + auto final_mv = std::make_unique(RVOpcodes::MV); + 
final_mv->addOperand(std::make_unique(result_vreg)); + final_mv->addOperand(std::make_unique(current_addr_vreg)); + CurMBB->addInstruction(std::move(final_mv)); + break; + } + default: throw std::runtime_error("Unsupported DAGNode kind in ISel"); } @@ -850,6 +973,21 @@ std::vector> RISCv64ISel::build_dag(BasicB std::cout << " -> Operand " << i << " has kind: " << memset_node->operands[i]->kind << std::endl; } } + } else if (auto gep = dynamic_cast(inst)) { + // 如果这个GEP指令已经创建过节点,则跳过 + if(value_to_node.count(gep)) continue; + + // 创建一个新的 GET_ELEMENT_PTR 类型的节点 + auto gep_node = create_node(DAGNode::GET_ELEMENT_PTR, gep, value_to_node, nodes_storage); + + // 第一个操作数是基指针(即数组本身) + gep_node->operands.push_back(get_operand_node(gep->getBasePointer(), value_to_node, nodes_storage)); + + // 依次添加所有索引作为后续的操作数 + for (auto index : gep->getIndices()) { + // [修复] 从 Use 对象中获取真正的 Value* + gep_node->operands.push_back(get_operand_node(index->getValue(), value_to_node, nodes_storage)); + } } else if (auto load = dynamic_cast(inst)) { auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); @@ -892,6 +1030,43 @@ std::vector> RISCv64ISel::build_dag(BasicB return nodes_storage; } +/** + * @brief 计算一个类型在内存中占用的字节数。 + * @param type 需要计算大小的IR类型。 + * @return 该类型占用的字节数。 + */ +unsigned RISCv64ISel::getTypeSizeInBytes(Type* type) { + if (!type) { + assert(false && "Cannot get size of a null type."); + return 0; + } + + switch (type->getKind()) { + // 对于SysY语言,基本类型int和float都占用4字节 + case Type::kInt: + case Type::kFloat: + return 4; + + // 指针类型在RISC-V 64位架构下占用8字节 + // 虽然SysY没有'int*'语法,但数组变量在IR层面本身就是指针类型 + case Type::kPointer: + return 8; + + // 数组类型的总大小 = 元素数量 * 单个元素的大小 + case Type::kArray: { + auto arrayType = type->as(); + // 递归调用以计算元素大小 + return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType()); + } + + // 其他类型,如Void, Label等不占用栈空间,或者不应该出现在这里 + 
default: + // 如果遇到未处理的类型,触发断言,方便调试 + assert(false && "Unsupported type for size calculation."); + return 0; + } +} + // [新] 打印DAG图以供调试的辅助函数 void RISCv64ISel::print_dag(const std::vector>& dag, const std::string& bb_name) { // 检查是否有DEBUG宏或者全局变量,避免在非调试模式下打印 diff --git a/src/RISCv64RegAlloc.cpp b/src/RISCv64RegAlloc.cpp index 0c0a4a3..5edcf98 100644 --- a/src/RISCv64RegAlloc.cpp +++ b/src/RISCv64RegAlloc.cpp @@ -27,24 +27,26 @@ void RISCv64RegAlloc::run() { void RISCv64RegAlloc::eliminateFrameIndices() { StackFrameInfo& frame_info = MFunc->getFrameInfo(); - int current_offset = 20; // 这里写20是为了在$s0和第一个变量之间留出20字节的安全区, - // 以防止一些函数调用方面的恶性bug。 + // 初始偏移量,为保存ra和s0留出空间。可以根据你的函数序言调整。 + // 假设序言是 addi sp, sp, -stack_size; sd ra, stack_size-8(sp); sd s0, stack_size-16(sp); + int current_offset = 16; + Function* F = MFunc->getFunc(); RISCv64ISel* isel = MFunc->getISel(); + // --- MODIFICATION START: 动态计算栈帧大小 --- + // 遍历AllocaInst来计算局部变量所需的总空间 for (auto& bb : F->getBasicBlocks()) { for (auto& inst : bb->getInstructions()) { if (auto alloca = dynamic_cast(inst.get())) { - int size = 4; - if (!alloca->getDims().empty()) { - int num_elements = 1; - for (const auto& dim_use : alloca->getDims()) { - if (auto const_dim = dynamic_cast(dim_use->getValue())) { - num_elements *= const_dim->getInt(); - } - } - size *= num_elements; - } + // 获取Alloca指令指向的类型 (例如 alloca i32* 中,获取 i32) + Type* allocated_type = alloca->getType()->as()->getBaseType(); + int size = getTypeSizeInBytes(allocated_type); + + // RISC-V要求栈地址8字节对齐 + size = (size + 7) & ~7; + if (size == 0) size = 8; // 至少分配8字节 + current_offset += size; unsigned alloca_vreg = isel->getVReg(alloca); frame_info.alloca_offsets[alloca_vreg] = -current_offset; @@ -52,50 +54,66 @@ void RISCv64RegAlloc::eliminateFrameIndices() { } } frame_info.locals_size = current_offset; + // --- MODIFICATION END --- + // 遍历所有机器指令,将伪指令展开为真实指令 for (auto& mbb : MFunc->getBlocks()) { std::vector> new_instructions; for (auto& instr_ptr : mbb->getInstructions()) { 
- if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) { + RVOpcodes opcode = instr_ptr->getOpcode(); + + // --- MODIFICATION START: 处理区分宽度的伪指令 --- + if (opcode == RVOpcodes::FRAME_LOAD_W || opcode == RVOpcodes::FRAME_LOAD_D) { + // 确定要生成的真实加载指令是 lw 还是 ld + RVOpcodes real_load_op = (opcode == RVOpcodes::FRAME_LOAD_W) ? RVOpcodes::LW : RVOpcodes::LD; + auto& operands = instr_ptr->getOperands(); unsigned dest_vreg = static_cast(operands[0].get())->getVRegNum(); unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); int offset = frame_info.alloca_offsets.at(alloca_vreg); auto addr_vreg = isel->getNewVReg(); + // 展开为: addi addr_vreg, s0, offset auto addi = std::make_unique(RVOpcodes::ADDI); addi->addOperand(std::make_unique(addr_vreg)); addi->addOperand(std::make_unique(PhysicalReg::S0)); addi->addOperand(std::make_unique(offset)); new_instructions.push_back(std::move(addi)); - auto lw = std::make_unique(RVOpcodes::LW); - lw->addOperand(std::make_unique(dest_vreg)); - lw->addOperand(std::make_unique( + // 展开为: lw/ld dest_vreg, 0(addr_vreg) + auto load_instr = std::make_unique(real_load_op); + load_instr->addOperand(std::make_unique(dest_vreg)); + load_instr->addOperand(std::make_unique( std::make_unique(addr_vreg), std::make_unique(0))); - new_instructions.push_back(std::move(lw)); + new_instructions.push_back(std::move(load_instr)); + + } else if (opcode == RVOpcodes::FRAME_STORE_W || opcode == RVOpcodes::FRAME_STORE_D) { + // 确定要生成的真实存储指令是 sw 还是 sd + RVOpcodes real_store_op = (opcode == RVOpcodes::FRAME_STORE_W) ? 
RVOpcodes::SW : RVOpcodes::SD; - } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) { auto& operands = instr_ptr->getOperands(); unsigned src_vreg = static_cast(operands[0].get())->getVRegNum(); unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); int offset = frame_info.alloca_offsets.at(alloca_vreg); auto addr_vreg = isel->getNewVReg(); + // 展开为: addi addr_vreg, s0, offset auto addi = std::make_unique(RVOpcodes::ADDI); addi->addOperand(std::make_unique(addr_vreg)); addi->addOperand(std::make_unique(PhysicalReg::S0)); addi->addOperand(std::make_unique(offset)); new_instructions.push_back(std::move(addi)); - auto sw = std::make_unique(RVOpcodes::SW); - sw->addOperand(std::make_unique(src_vreg)); - sw->addOperand(std::make_unique( + // 展开为: sw/sd src_vreg, 0(addr_vreg) + auto store_instr = std::make_unique(real_store_op); + store_instr->addOperand(std::make_unique(src_vreg)); + store_instr->addOperand(std::make_unique( std::make_unique(addr_vreg), std::make_unique(0))); - new_instructions.push_back(std::move(sw)); - } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_ADDR) { // [新] 处理FRAME_ADDR + new_instructions.push_back(std::move(store_instr)); + + } else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_ADDR) { auto& operands = instr_ptr->getOperands(); unsigned dest_vreg = static_cast(operands[0].get())->getVRegNum(); unsigned alloca_vreg = static_cast(operands[1].get())->getVRegNum(); @@ -104,12 +122,13 @@ void RISCv64RegAlloc::eliminateFrameIndices() { // 将 `frame_addr rd, rs` 展开为 `addi rd, s0, offset` auto addi = std::make_unique(RVOpcodes::ADDI); addi->addOperand(std::make_unique(dest_vreg)); - addi->addOperand(std::make_unique(PhysicalReg::S0)); // 基地址是帧指针 s0 + addi->addOperand(std::make_unique(PhysicalReg::S0)); addi->addOperand(std::make_unique(offset)); new_instructions.push_back(std::move(addi)); } else { new_instructions.push_back(std::move(instr_ptr)); } + // --- MODIFICATION END --- } mbb->getInstructions() = 
std::move(new_instructions); } @@ -119,30 +138,72 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& bool is_def = true; auto opcode = instr->getOpcode(); - // 预定义def和use规则 + // --- MODIFICATION START: 细化对指令的 use/def 定义 --- + + // 对于没有定义目标寄存器的指令,预先设置 is_def = false if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE || opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE || + opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU || opcode == RVOpcodes::RET || opcode == RVOpcodes::J) { is_def = false; } + + // 对 CALL 指令进行特殊处理 if (opcode == RVOpcodes::CALL) { - // CALL会杀死所有调用者保存寄存器,这是一个简化处理 - // 同时也使用了传入a0-a7的参数 + // CALL 指令的第一个操作数通常是目标函数标签,不是寄存器。 + // 它可能会有一个可选的返回值(def),以及一系列参数(use)。 + // 这里的处理假定 CALL 的机器指令操作数布局是: + // [可选: dest_vreg (def)], [函数标签], [可选: arg1_vreg (use)], [可选: arg2_vreg (use)], ... + + // 我们需要一种方法来识别哪些操作数是def,哪些是use。 + // 一个简单的约定:如果第一个操作数是寄存器,则它是def(返回值)。 + if (!instr->getOperands().empty() && instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) { + auto reg_op = static_cast(instr->getOperands().front().get()); + if (reg_op->isVirtual()) { + def.insert(reg_op->getVRegNum()); + } + } + + // 遍历所有操作数,非第一个寄存器操作数均视为use + bool first_reg_skipped = false; + for (const auto& op : instr->getOperands()) { + if (op->getKind() == MachineOperand::KIND_REG) { + if (!first_reg_skipped) { + first_reg_skipped = true; + continue; // 跳过我们已经作为def处理的返回值 + } + auto reg_op = static_cast(op.get()); + if (reg_op->isVirtual()) { + use.insert(reg_op->getVRegNum()); + } + } + } + + // **重要**: CALL指令还隐式定义(杀死)了所有调用者保存的寄存器。 + // 一个完整的实现会在这里将所有caller-saved寄存器标记为def, + // 以确保任何跨调用存活的变量都不会被分配到这些寄存器中。 + // 这个简化的实现暂不处理隐式def,但这是未来优化的关键点。 + + return; // CALL 指令处理完毕,直接返回 } + // --- MODIFICATION END --- + + // 对其他所有指令的通用处理逻辑 for (const auto& op : instr->getOperands()) { if (op->getKind() == MachineOperand::KIND_REG) { auto reg_op = static_cast(op.get()); if (reg_op->isVirtual()) { 
if (is_def) { def.insert(reg_op->getVRegNum()); - is_def = false; + is_def = false; // 一条指令通常只有一个目标寄存ator } else { use.insert(reg_op->getVRegNum()); } } } else if (op->getKind() == MachineOperand::KIND_MEM) { + // 内存操作数 `offset(base)` 中的 base 寄存器是 use auto mem_op = static_cast(op.get()); if (mem_op->getBase()->isVirtual()) { use.insert(mem_op->getBase()->getVRegNum()); @@ -151,6 +212,43 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& } } +/** + * @brief 计算一个类型在内存中占用的字节数。 + * @param type 需要计算大小的IR类型。 + * @return 该类型占用的字节数。 + */ +unsigned RISCv64RegAlloc::getTypeSizeInBytes(Type* type) { + if (!type) { + assert(false && "Cannot get size of a null type."); + return 0; + } + + switch (type->getKind()) { + // 对于SysY语言,基本类型int和float都占用4字节 + case Type::kInt: + case Type::kFloat: + return 4; + + // 指针类型在RISC-V 64位架构下占用8字节 + // 虽然SysY没有'int*'语法,但数组变量在IR层面本身就是指针类型 + case Type::kPointer: + return 8; + + // 数组类型的总大小 = 元素数量 * 单个元素的大小 + case Type::kArray: { + auto arrayType = type->as(); + // 递归调用以计算元素大小 + return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType()); + } + + // 其他类型,如Void, Label等不占用栈空间,或者不应该出现在这里 + default: + // 如果遇到未处理的类型,触发断言,方便调试 + assert(false && "Unsupported type for size calculation."); + return 0; + } +} + void RISCv64RegAlloc::analyzeLiveness() { bool changed = true; while (changed) { @@ -259,8 +357,21 @@ void RISCv64RegAlloc::colorGraph() { void RISCv64RegAlloc::rewriteFunction() { StackFrameInfo& frame_info = MFunc->getFrameInfo(); int current_offset = frame_info.locals_size; + + // --- FIX 1: 动态计算溢出槽大小 --- + // 根据溢出虚拟寄存器的真实类型,为其在栈上分配正确大小的空间。 for (unsigned vreg : spilled_vregs) { - current_offset += 4; + // 从反向映射中查找 vreg 对应的 IR Value + assert(vreg_to_value_map.count(vreg) && "Spilled vreg not found in map!"); + Value* val = vreg_to_value_map.at(vreg); + + // 使用辅助函数获取类型大小 + int size = getTypeSizeInBytes(val->getType()); + + // 保持栈8字节对齐 + current_offset += size; + current_offset = 
(current_offset + 7) & ~7; + frame_info.spill_offsets[vreg] = -current_offset; } frame_info.spill_size = current_offset - frame_info.locals_size; @@ -271,10 +382,16 @@ void RISCv64RegAlloc::rewriteFunction() { LiveSet use, def; getInstrUseDef(instr_ptr.get(), use, def); + // --- FIX 2: 为溢出的 'use' 操作数插入正确的加载指令 --- for (unsigned vreg : use) { if (spilled_vregs.count(vreg)) { + // 同样地,根据 vreg 的类型决定使用 lw 还是 ld + assert(vreg_to_value_map.count(vreg)); + Value* val = vreg_to_value_map.at(vreg); + RVOpcodes load_op = val->getType()->isPointer() ? RVOpcodes::LD : RVOpcodes::LW; + int offset = frame_info.spill_offsets.at(vreg); - auto load = std::make_unique(RVOpcodes::LW); + auto load = std::make_unique(load_op); load->addOperand(std::make_unique(vreg)); load->addOperand(std::make_unique( std::make_unique(PhysicalReg::S0), @@ -286,10 +403,16 @@ void RISCv64RegAlloc::rewriteFunction() { new_instructions.push_back(std::move(instr_ptr)); + // --- FIX 3: 为溢出的 'def' 操作数插入正确的存储指令 --- for (unsigned vreg : def) { if (spilled_vregs.count(vreg)) { + // 根据 vreg 的类型决定使用 sw 还是 sd + assert(vreg_to_value_map.count(vreg)); + Value* val = vreg_to_value_map.at(vreg); + RVOpcodes store_op = val->getType()->isPointer() ? 
RVOpcodes::SD : RVOpcodes::SW; + int offset = frame_info.spill_offsets.at(vreg); - auto store = std::make_unique(RVOpcodes::SW); + auto store = std::make_unique(store_op); store->addOperand(std::make_unique(vreg)); store->addOperand(std::make_unique( std::make_unique(PhysicalReg::S0), @@ -302,27 +425,39 @@ void RISCv64RegAlloc::rewriteFunction() { mbb->getInstructions() = std::move(new_instructions); } + // 最后的虚拟寄存器到物理寄存器的替换过程保持不变 for (auto& mbb : MFunc->getBlocks()) { for (auto& instr_ptr : mbb->getInstructions()) { for (auto& op_ptr : instr_ptr->getOperands()) { + + // 情况一:操作数本身就是一个寄存器 (例如 add rd, rs1, rs2 中的所有操作数) if(op_ptr->getKind() == MachineOperand::KIND_REG) { auto reg_op = static_cast(op_ptr.get()); if (reg_op->isVirtual()) { unsigned vreg = reg_op->getVRegNum(); if (color_map.count(vreg)) { + // 如果vreg被成功着色,替换为物理寄存器 reg_op->setPReg(color_map.at(vreg)); } else if (spilled_vregs.count(vreg)) { - reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6 + // 如果vreg被溢出,替换为专用的溢出物理寄存器t6 + reg_op->setPReg(PhysicalReg::T6); } } - } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { + } + // 情况二:操作数是一个内存地址 (例如 lw rd, offset(rs1) 中的 offset(rs1)) + else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { auto mem_op = static_cast(op_ptr.get()); + // 获取内存操作数内部的“基址寄存器” auto base_reg_op = mem_op->getBase(); + + // 对这个基址寄存器,执行与情况一完全相同的替换逻辑 if(base_reg_op->isVirtual()){ unsigned vreg = base_reg_op->getVRegNum(); if(color_map.count(vreg)) { + // 如果基址vreg被成功着色,替换 base_reg_op->setPReg(color_map.at(vreg)); } else if (spilled_vregs.count(vreg)) { + // 如果基址vreg被溢出,替换为t6 base_reg_op->setPReg(PhysicalReg::T6); } } diff --git a/src/Reg2Mem.cpp b/src/Reg2Mem.cpp deleted file mode 100644 index d44d1c8..0000000 --- a/src/Reg2Mem.cpp +++ /dev/null @@ -1,129 +0,0 @@ -#include "Reg2Mem.h" -#include -#include -#include -#include - -namespace sysy { - -/** - * 删除phi节点 - * 删除phi节点后可能会生成冗余存储代码 - */ -void Reg2Mem::DeletePhiInst(){ - auto &functions = pModule->getFunctions(); - for (auto 
&function : functions) { - auto basicBlocks = function.second->getBasicBlocks(); - for (auto &basicBlock : basicBlocks) { - - for (auto iter = basicBlock->begin(); iter != basicBlock->end();) { - auto &instruction = *iter; - if (instruction->isPhi()) { - auto predBlocks = basicBlock->getPredecessors(); - // 寻找源和目的 - // 目的就是phi指令的第一个操作数 - // 源就是phi指令的后续操作数 - auto destination = instruction->getOperand(0); - int predBlockindex = 0; - for (auto &predBlock : predBlocks) { - ++predBlockindex; - // 判断前驱块儿只有一个后继还是多个后继 - // 如果有多个 - auto source = instruction->getOperand(predBlockindex); - if (source == destination) { - continue; - } - // std::cout << predBlock->getNumSuccessors() << std::endl; - if (predBlock->getNumSuccessors() > 1) { - // 创建一个basicblock - auto newbasicBlock = function.second->addBasicBlock(); - std::stringstream ss; - ss << " phidel.L" << pBuilder->getLabelIndex(); - newbasicBlock->setName(ss.str()); - ss.str(""); - // // 修改前驱后继关系 - basicBlock->replacePredecessor(predBlock, newbasicBlock); - // predBlock = newbasicBlock; - newbasicBlock->addPredecessor(predBlock); - newbasicBlock->addSuccessor(basicBlock.get()); - predBlock->removeSuccessor(basicBlock.get()); - predBlock->addSuccessor(newbasicBlock); - // std::cout << "the block name is " << basicBlock->getName() << std::endl; - // for (auto pb : basicBlock->getPredecessors()) { - // // newbasicBlock->addPredecessor(pb); - // std::cout << pb->getName() << std::endl; - // } - // sysy::BasicBlock::conectBlocks(newbasicBlock, static_cast(basicBlock.get())); - // 若后为跳转指令,应该修改跳转指令所到达的位置 - auto thelastinst = predBlock->end(); - (--thelastinst); - - if (thelastinst->get()->isConditional() || thelastinst->get()->isUnconditional()) { // 如果是跳转指令 - auto opnum = thelastinst->get()->getNumOperands(); - for (size_t i = 0; i < opnum; i++) { - if (thelastinst->get()->getOperand(i) == basicBlock.get()) { - thelastinst->get()->replaceOperand(i, newbasicBlock); - } - } - } - // 在新块中插入store指令 - 
pBuilder->setPosition(newbasicBlock, newbasicBlock->end()); - // pBuilder->createStoreInst(source, destination); - if (source->isInt() || source->isFloat()) { - pBuilder->createStoreInst(source, destination); - } else { - auto loadInst = pBuilder->createLoadInst(source); - pBuilder->createStoreInst(loadInst, destination); - } - // pBuilder->createMoveInst(Instruction::kMove, destination->getType(), destination, source, - // newbasicBlock); - pBuilder->setPosition(newbasicBlock, newbasicBlock->end()); - pBuilder->createUncondBrInst(basicBlock.get(), {}); - } else { - // 如果前驱块只有一个后继 - auto thelastinst = predBlock->end(); - (--thelastinst); - // std::cout << predBlock->getName() << std::endl; - // std::cout << thelastinst->get() << std::endl; - // std::cout << "First point 11 " << std::endl; - if (thelastinst->get()->isConditional() || thelastinst->get()->isUnconditional()) { - // 在跳转语句前insert st指令 - pBuilder->setPosition(predBlock, thelastinst); - } else { - pBuilder->setPosition(predBlock, predBlock->end()); - } - - if (source->isInt() || source->isFloat()) { - pBuilder->createStoreInst(source, destination); - } else { - auto loadInst = pBuilder->createLoadInst(source); - pBuilder->createStoreInst(loadInst, destination); - } - } - } - // 删除phi指令 - auto &instructions = basicBlock->getInstructions(); - usedelete(iter->get()); - iter = instructions.erase(iter); - if (basicBlock->getNumInstructions() == 0) { - if (basicBlock->getNumSuccessors() == 1) { - pBuilder->setPosition(basicBlock.get(), basicBlock->end()); - pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {}); - } - } - } else { - break; - } - } - } - } -} - -void Reg2Mem::usedelete(Instruction *instr) { - for (auto &use : instr->getOperands()) { - auto val = use->getValue(); - val->removeUse(use); - } -} - -} // namespace sysy diff --git a/src/SysYIRAnalyser.cpp b/src/SysYIRAnalyser.cpp index 0761c0a..51e2b27 100644 --- a/src/SysYIRAnalyser.cpp +++ b/src/SysYIRAnalyser.cpp @@ -523,9 +523,6 @@ bool 
ActiveVarAnalysis::analyze(Module *pModule, BasicBlock *block) { } -auto ActiveVarAnalysis::getActiveTable() const -> const std::map>> & { - return activeTable; -} } // namespace sysy diff --git a/src/SysYIRCFGOpt.cpp b/src/SysYIRCFGOpt.cpp new file mode 100644 index 0000000..1a6c3a1 --- /dev/null +++ b/src/SysYIRCFGOpt.cpp @@ -0,0 +1,606 @@ +#include "SysYIRCFGOpt.h" +#include "SysYIROptUtils.h" +#include +#include +#include +#include +#include +#include +#include // 引入队列,SysYDelNoPreBLock需要 + +namespace sysy { + +// 定义静态ID +void *SysYDelInstAfterBrPass::ID = (void *)&SysYDelInstAfterBrPass::ID; +void *SysYDelEmptyBlockPass::ID = (void *)&SysYDelEmptyBlockPass::ID; +void *SysYDelNoPreBLockPass::ID = (void *)&SysYDelNoPreBLockPass::ID; +void *SysYBlockMergePass::ID = (void *)&SysYBlockMergePass::ID; +void *SysYAddReturnPass::ID = (void *)&SysYAddReturnPass::ID; +void *SysYCondBr2BrPass::ID = (void *)&SysYCondBr2BrPass::ID; + + +// ====================================================================== +// SysYCFGOptUtils: 辅助工具类,包含实际的CFG优化逻辑 +// ====================================================================== + +// 删除br后的无用指令 +bool SysYCFGOptUtils::SysYDelInstAfterBr(Function *func) { + bool changed = false; + + auto basicBlocks = func->getBasicBlocks(); + for (auto &basicBlock : basicBlocks) { + bool Branch = false; + auto &instructions = basicBlock->getInstructions(); + auto Branchiter = instructions.end(); + for (auto iter = instructions.begin(); iter != instructions.end(); ++iter) { + if ((*iter)->isTerminator()){ + Branch = true; + Branchiter = iter; + break; + } + } + if (Branchiter != instructions.end()) ++Branchiter; + while (Branchiter != instructions.end()) { + changed = true; + Branchiter = instructions.erase(Branchiter); + } + + if (Branch) { // 更新前驱后继关系 + auto thelastinstinst = basicBlock->getInstructions().end(); + --thelastinstinst; + auto &Successors = basicBlock->getSuccessors(); + for (auto iterSucc = Successors.begin(); iterSucc != 
Successors.end();) { + (*iterSucc)->removePredecessor(basicBlock.get()); + basicBlock->removeSuccessor(*iterSucc); + } + if (thelastinstinst->get()->isUnconditional()) { + BasicBlock* branchBlock = dynamic_cast(thelastinstinst->get()->getOperand(0)); + basicBlock->addSuccessor(branchBlock); + branchBlock->addPredecessor(basicBlock.get()); + } else if (thelastinstinst->get()->isConditional()) { + BasicBlock* thenBlock = dynamic_cast(thelastinstinst->get()->getOperand(1)); + BasicBlock* elseBlock = dynamic_cast(thelastinstinst->get()->getOperand(2)); + basicBlock->addSuccessor(thenBlock); + basicBlock->addSuccessor(elseBlock); + thenBlock->addPredecessor(basicBlock.get()); + elseBlock->addPredecessor(basicBlock.get()); + } + } + } + + return changed; +} + +// 合并基本块 +bool SysYCFGOptUtils::SysYBlockMerge(Function *func) { + bool changed = false; + + for (auto blockiter = func->getBasicBlocks().begin(); + blockiter != func->getBasicBlocks().end();) { + if (blockiter->get()->getNumSuccessors() == 1) { + // 如果当前块只有一个后继块 + // 且后继块只有一个前驱块 + // 则将当前块和后继块合并 + if (((blockiter->get())->getSuccessors()[0])->getNumPredecessors() == 1) { + // std::cout << "merge block: " << blockiter->get()->getName() << std::endl; + BasicBlock* block = blockiter->get(); + BasicBlock* nextBlock = blockiter->get()->getSuccessors()[0]; + auto nextarguments = nextBlock->getArguments(); + // 删除br指令 + if (block->getNumInstructions() != 0) { + auto thelastinstinst = block->end(); + (--thelastinstinst); + if (thelastinstinst->get()->isUnconditional()) { + SysYIROptUtils::usedelete(thelastinstinst->get()); + thelastinstinst = block->getInstructions().erase(thelastinstinst); + } else if (thelastinstinst->get()->isConditional()) { + // 如果是条件分支,判断条件是否相同,主要优化相同布尔表达式 + if (thelastinstinst->get()->getOperand(1)->getName() == thelastinstinst->get()->getOperand(1)->getName()) { + SysYIROptUtils::usedelete(thelastinstinst->get()); + thelastinstinst = block->getInstructions().erase(thelastinstinst); + } + } + } + 
// 将后继块的指令移动到当前块 + // 并将后继块的父指针改为当前块 + for (auto institer = nextBlock->begin(); institer != nextBlock->end();) { + institer->get()->setParent(block); + block->getInstructions().emplace_back(institer->release()); + institer = nextBlock->getInstructions().erase(institer); + } + // 合并参数 + // TODO:是否需要去重? + for (auto &argm : nextarguments) { + argm->setParent(block); + block->insertArgument(argm); + } + // 更新前驱后继关系,类似树节点操作 + block->removeSuccessor(nextBlock); + nextBlock->removePredecessor(block); + std::list succshoulddel; + for (auto &succ : nextBlock->getSuccessors()) { + block->addSuccessor(succ); + succ->replacePredecessor(nextBlock, block); + succshoulddel.push_back(succ); + } + for (auto del : succshoulddel) { + nextBlock->removeSuccessor(del); + } + + func->removeBasicBlock(nextBlock); + changed = true; + + } else { + blockiter++; + } + } else { + blockiter++; + } + } + + return changed; +} + +// 删除无前驱块,兼容SSA后的处理 +bool SysYCFGOptUtils::SysYDelNoPreBLock(Function *func) { + + bool changed = false; + + for (auto &block : func->getBasicBlocks()) { + block->setreachableFalse(); + } + // 对函数基本块做一个拓扑排序,排查不可达基本块 + auto entryBlock = func->getEntryBlock(); + entryBlock->setreachableTrue(); + std::queue blockqueue; + blockqueue.push(entryBlock); + while (!blockqueue.empty()) { + auto block = blockqueue.front(); + blockqueue.pop(); + for (auto &succ : block->getSuccessors()) { + if (!succ->getreachable()) { + succ->setreachableTrue(); + blockqueue.push(succ); + } + } + } + + // 删除不可达基本块指令 + for (auto blockIter = func->getBasicBlocks().begin(); blockIter != func->getBasicBlocks().end(); blockIter++) { + if (!blockIter->get()->getreachable()) { + for (auto instIter = blockIter->get()->getInstructions().begin(); + instIter != blockIter->get()->getInstructions().end();) { + SysYIROptUtils::usedelete(instIter->get()); + instIter = blockIter->get()->getInstructions().erase(instIter); + } + } + } + + + for (auto blockIter = func->getBasicBlocks().begin(); blockIter != 
func->getBasicBlocks().end();) { + if (!blockIter->get()->getreachable()) { + for (auto succblock : blockIter->get()->getSuccessors()) { + for (auto &phiinst : succblock->getInstructions()) { + if (phiinst->getKind() != Instruction::kPhi) { + break; + } + // 使用 delBlk 方法正确地删除对应于被删除基本块的传入值 + dynamic_cast(phiinst.get())->delBlk(blockIter->get()); + } + } + // 删除不可达基本块,注意迭代器不可达问题 + func->removeBasicBlock((blockIter++)->get()); + changed = true; + } else { + blockIter++; + } + } + + return changed; +} + +// 删除空块 +bool SysYCFGOptUtils::SysYDelEmptyBlock(Function *func, IRBuilder* pBuilder) { + bool changed = false; + + // 收集不可达基本块 + // 这里的不可达基本块是指没有实际指令的基本块 + // 当一个基本块没有实际指令例如只有phi指令和一个uncondbr指令时,也会被视作不可达 + auto basicBlocks = func->getBasicBlocks(); + std::map EmptyBlocks; + // 空块儿和后继的基本块的映射 + for (auto &basicBlock : basicBlocks) { + if (basicBlock->getNumInstructions() == 0) { + if (basicBlock->getNumSuccessors() == 1) { + EmptyBlocks[basicBlock.get()] = basicBlock->getSuccessors().front(); + } + } + else{ + // 如果只有phi指令和一个uncondbr。(phi)*(uncondbr)? 
+ // 判断除了最后一个指令之外是不是只有phi指令 + bool onlyPhi = true; + for (auto &inst : basicBlock->getInstructions()) { + if (!inst->isPhi() && !inst->isUnconditional()) { + onlyPhi = false; + break; + } + } + if(onlyPhi && basicBlock->getNumSuccessors() == 1) // 确保有后继且只有一个 + EmptyBlocks[basicBlock.get()] = basicBlock->getSuccessors().front(); + } + } + // 更新基本块信息,增加必要指令 + for (auto &basicBlock : basicBlocks) { + // 把空块转换成只有跳转指令的不可达块 (这段逻辑在优化遍中可能需要调整,这里是原样保留) + // 通常,DelEmptyBlock 应该在BlockMerge之后运行,如果存在完全空块,它会尝试填充一个Br指令。 + // 但是,它主要目的是重定向跳转。 + if (distance(basicBlock->begin(), basicBlock->end()) == 0) { + if (basicBlock->getNumSuccessors() == 0) { + continue; + } + if (basicBlock->getNumSuccessors() > 1) { + // 如果一个空块有多个后继,说明CFG结构有问题或者需要特殊处理,这里简单assert + assert(false && "Empty block with multiple successors found during SysYDelEmptyBlock"); + } + // 这里的逻辑有点问题,如果一个块是空的,且只有一个后继,应该直接跳转到后继。 + // 如果这个块最终被删除了,那么其前驱也需要重定向。 + // 这个循环的目的是重定向现有的跳转指令,而不是创建新的。 + // 所以下面的逻辑才是核心。 + // pBuilder->setPosition(basicBlock.get(), basicBlock->end()); + // pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {}); + continue; + } + + auto thelastinst = basicBlock->getInstructions().end(); + --thelastinst; + + // 根据br指令传递的后继块信息,跳过空块链 + if (thelastinst->get()->isUnconditional()) { + BasicBlock* OldBrBlock = dynamic_cast(thelastinst->get()->getOperand(0)); + BasicBlock *thelastBlockOld = nullptr; + // 如果空块链表为多个块 + while (EmptyBlocks.count(dynamic_cast(thelastinst->get()->getOperand(0)))) { + thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(0)); + thelastinst->get()->replaceOperand(0, EmptyBlocks[thelastBlockOld]); + } + + // 如果有重定向发生 + if (thelastBlockOld != nullptr) { + basicBlock->removeSuccessor(OldBrBlock); + OldBrBlock->removePredecessor(basicBlock.get()); + basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(0))); + dynamic_cast(thelastinst->get()->getOperand(0))->addPredecessor(basicBlock.get()); + changed = true; // 标记IR被修改 + } + + + if (thelastBlockOld != 
nullptr) { + for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(0))->getInstructions()) { + if (InstInNew->isPhi()) { + // 使用 delBlk 方法删除 oldBlock 对应的传入值 + dynamic_cast(InstInNew.get())->delBlk(thelastBlockOld); + } else { + break; + } + } + } + + } else if (thelastinst->get()->getKind() == Instruction::kCondBr) { + auto OldThenBlock = dynamic_cast(thelastinst->get()->getOperand(1)); + auto OldElseBlock = dynamic_cast(thelastinst->get()->getOperand(2)); + bool thenChanged = false; + bool elseChanged = false; + + + BasicBlock *thelastBlockOld = nullptr; + while (EmptyBlocks.count(dynamic_cast(thelastinst->get()->getOperand(1)))) { + thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(1)); + thelastinst->get()->replaceOperand( + 1, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(1))]); + thenChanged = true; + } + + if (thenChanged) { + basicBlock->removeSuccessor(OldThenBlock); + OldThenBlock->removePredecessor(basicBlock.get()); + basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(1))); + dynamic_cast(thelastinst->get()->getOperand(1))->addPredecessor(basicBlock.get()); + changed = true; // 标记IR被修改 + } + + // 处理 then 和 else 分支合并的情况 + if (dynamic_cast(thelastinst->get()->getOperand(1)) == + dynamic_cast(thelastinst->get()->getOperand(2))) { + auto thebrBlock = dynamic_cast(thelastinst->get()->getOperand(1)); + SysYIROptUtils::usedelete(thelastinst->get()); + thelastinst = basicBlock->getInstructions().erase(thelastinst); + pBuilder->setPosition(basicBlock.get(), basicBlock->end()); + pBuilder->createUncondBrInst(thebrBlock, {}); + changed = true; // 标记IR被修改 + continue; + } + + if (thelastBlockOld != nullptr) { + for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(1))->getInstructions()) { + if (InstInNew->isPhi()) { + // 使用 delBlk 方法删除 oldBlock 对应的传入值 + dynamic_cast(InstInNew.get())->delBlk(thelastBlockOld); + } else { + break; + } + } + } + + thelastBlockOld = nullptr; + while 
(EmptyBlocks.count(dynamic_cast(thelastinst->get()->getOperand(2)))) { + thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(2)); + thelastinst->get()->replaceOperand( + 2, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(2))]); + elseChanged = true; + } + + if (elseChanged) { + basicBlock->removeSuccessor(OldElseBlock); + OldElseBlock->removePredecessor(basicBlock.get()); + basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(2))); + dynamic_cast(thelastinst->get()->getOperand(2))->addPredecessor(basicBlock.get()); + changed = true; // 标记IR被修改 + } + + // 处理 then 和 else 分支合并的情况 + if (dynamic_cast(thelastinst->get()->getOperand(1)) == + dynamic_cast(thelastinst->get()->getOperand(2))) { + auto thebrBlock = dynamic_cast(thelastinst->get()->getOperand(1)); + SysYIROptUtils::usedelete(thelastinst->get()); + thelastinst = basicBlock->getInstructions().erase(thelastinst); + pBuilder->setPosition(basicBlock.get(), basicBlock->end()); + pBuilder->createUncondBrInst(thebrBlock, {}); + changed = true; // 标记IR被修改 + continue; + } + + + // 如果有重定向发生 + // 需要更新后继块的前驱关系 + if (thelastBlockOld != nullptr) { + for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(2))->getInstructions()) { + if (InstInNew->isPhi()) { + // 使用 delBlk 方法删除 oldBlock 对应的传入值 + dynamic_cast(InstInNew.get())->delBlk(thelastBlockOld); + } else { + break; + } + } + } + + } else { + // 如果不是终止指令,但有后继 (例如,末尾没有显式终止指令的块) + // 这段逻辑可能需要更严谨的CFG检查来确保正确性 + if (basicBlock->getNumSuccessors() == 1) { + // 这里的逻辑似乎是想为没有terminator的块添加一个,但通常这应该在CFG构建阶段完成。 + // 如果这里仍然执行,确保它符合预期。 + // pBuilder->setPosition(basicBlock.get(), basicBlock->end()); + // pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {}); + // auto thelastinst = basicBlock->getInstructions().end(); + // (--thelastinst); + // auto OldBrBlock = dynamic_cast(thelastinst->get()->getOperand(0)); + // sysy::BasicBlock *thelastBlockOld = nullptr; + // while 
(EmptyBlocks.find(dynamic_cast(thelastinst->get()->getOperand(0))) != + // EmptyBlocks.end()) { + // thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(0)); + + // thelastinst->get()->replaceOperand( + // 0, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(0))]); + // } + + // basicBlock->removeSuccessor(OldBrBlock); + // OldBrBlock->removePredecessor(basicBlock.get()); + // basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(0))); + // dynamic_cast(thelastinst->get()->getOperand(0))->addPredecessor(basicBlock.get()); + // changed = true; // 标记IR被修改 + // if (thelastBlockOld != nullptr) { + // int indexphi = 0; + // for (auto &pred : dynamic_cast(thelastinst->get()->getOperand(0))->getPredecessors()) { + // if (pred == thelastBlockOld) { + // break; + // } + // indexphi++; + // } + + // for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(0))->getInstructions()) { + // if (InstInNew->isPhi()) { + // dynamic_cast(InstInNew.get())->removeOperand(indexphi + 1); + // } else { + // break; + // } + // } + // } + } + } + } + + // 真正的删除空块 + for (auto iter = func->getBasicBlocks().begin(); iter != func->getBasicBlocks().end();) { + + if (EmptyBlocks.count(iter->get())) { + // EntryBlock跳过 + if (iter->get() == func->getEntryBlock()) { + ++iter; + continue; + } + + for (auto instIter = iter->get()->getInstructions().begin(); + instIter != iter->get()->getInstructions().end();) { + SysYIROptUtils::usedelete(instIter->get()); // 仅删除 use 关系 + // 显式地从基本块中删除指令并更新迭代器 + instIter = iter->get()->getInstructions().erase(instIter); + } + // 删除不可达基本块的phi指令的操作数 + for (auto &succ : iter->get()->getSuccessors()) { + for (auto &instinsucc : succ->getInstructions()) { + if (instinsucc->isPhi()) { + // iter->get() 就是当前被删除的空基本块,它作为前驱连接到这里的Phi指令 + dynamic_cast(instinsucc.get())->delBlk(iter->get()); + } else { + // Phi 指令通常在基本块的开头,如果不是 Phi 指令就停止检查 + break; + } + } + } + + func->removeBasicBlock((iter++)->get()); + changed = true; + } else { + 
++iter; + } + } + + return changed; +}
+
+// Makes sure every exit block of `func` (a block without successors) ends in
+// a return instruction.
+//
+// @param func      function whose basic blocks are inspected.
+// @param pBuilder  IR builder used to append the missing return instructions.
+// @return          true if at least one return instruction was added.
+//
+// The default return value follows the function's return type: integer 0,
+// floating 0.0F, or no value for every other type. An exit block that is
+// completely empty is treated exactly like a block whose last instruction is
+// not a return, so int/float functions never receive a value-less return.
+bool SysYCFGOptUtils::SysYAddReturn(Function *func, IRBuilder* pBuilder) {
+  bool changed = false;
+  auto basicBlocks = func->getBasicBlocks();
+  for (auto &block : basicBlocks) {
+    // Only blocks without successors can be missing a return.
+    if (block->getNumSuccessors() != 0) {
+      continue;
+    }
+
+    // An empty block has no last instruction to look at; otherwise check
+    // whether the last instruction already is a return.
+    bool endsWithReturn = false;
+    if (block->getNumInstructions() != 0) {
+      auto lastInst = block->getInstructions().end();
+      --lastInst;
+      endsWithReturn = lastInst->get()->getKind() == Instruction::kReturn;
+    }
+    if (endsWithReturn) {
+      continue;
+    }
+
+    pBuilder->setPosition(block.get(), block->end());
+    // TODO: decide whether a missing return in an int/float function should be
+    // reported as an error instead of being patched silently.
+    if (func->getReturnType()->isInt()) {
+      pBuilder->createReturnInst(ConstantInteger::get(0));
+    } else if (func->getReturnType()->isFloat()) {
+      pBuilder->createReturnInst(ConstantFloating::get(0.0F));
+    } else {
+      pBuilder->createReturnInst();
+    }
+    changed = true;  // the IR was modified
+  }
+
+  return changed;
+}
+
+// 条件分支转换为无条件分支 +// 主要针对已知条件值的分支转换为无条件分支 +// 例如 if (cond) { ... } else { ... 
} 中的 cond 已经 +// 确定为 true 或 false 的情况
+//
+// Folds every conditional branch whose condition operand is a constant into
+// an unconditional branch to the target that will actually be taken.
+//
+// @param func      function whose basic blocks are scanned.
+// @param pBuilder  IR builder used to emit the replacement branch.
+// @return          true if at least one branch was rewritten.
+//
+// A constant condition is true when it is non-zero (integer or float); only
+// zero selects the else target. After the rewrite the edge to the target that
+// is no longer reachable from this block is removed from the CFG, and that
+// target's leading phi instructions drop their incoming value for this block.
+//
+// NOTE(review): the template arguments of the dynamic_casts were missing from
+// this copy of the patch. ConstantValue and BasicBlock follow from the
+// surrounding declarations and uses; PhiInst is assumed -- confirm against IR.h.
+bool SysYCFGOptUtils::SysYCondBr2Br(Function *func, IRBuilder* pBuilder) {
+  bool changed = false;
+
+  for (auto &basicblock : func->getBasicBlocks()) {
+    if (basicblock->getNumInstructions() == 0)
+      continue;
+
+    auto thelast = basicblock->getInstructions().end();
+    --thelast;
+    if (!thelast->get()->isConditional())
+      continue;
+
+    // Only a constant condition can be decided here.
+    auto *constOperand = dynamic_cast<ConstantValue *>(thelast->get()->getOperand(0));
+    if (constOperand == nullptr)
+      continue;
+
+    // Any non-zero constant counts as true, not just the value 1.
+    const bool condIsTrue = constOperand->isFloat() ? (constOperand->getFloat() != 0.0F)
+                                                    : (constOperand->getInt() != 0);
+
+    // Read both targets before the branch instruction is destroyed.
+    auto *thenBlock = dynamic_cast<BasicBlock *>(thelast->get()->getOperand(1));
+    auto *elseBlock = dynamic_cast<BasicBlock *>(thelast->get()->getOperand(2));
+    BasicBlock *takenBlock = condIsTrue ? thenBlock : elseBlock;
+    BasicBlock *droppedBlock = condIsTrue ? elseBlock : thenBlock;
+
+    // Replace the conditional branch with an unconditional one.
+    SysYIROptUtils::usedelete(thelast->get());
+    basicblock->getInstructions().erase(thelast);
+    pBuilder->setPosition(basicblock.get(), basicblock->end());
+    pBuilder->createUncondBrInst(takenBlock, {});
+    changed = true;
+
+    // If both targets were the same block, the edge is still in use by the new
+    // branch: leave the CFG links and the phi inputs alone.
+    if (droppedBlock == takenBlock)
+      continue;
+
+    // Unlink the edge that can no longer be taken.
+    basicblock->removeSuccessor(droppedBlock);
+    droppedBlock->removePredecessor(basicblock.get());
+
+    // Phi instructions sit at the start of the block; stop at the first
+    // instruction that is not a phi.
+    for (auto &phiinst : droppedBlock->getInstructions()) {
+      if (phiinst->getKind() != Instruction::kPhi)
+        break;
+      dynamic_cast<PhiInst *>(phiinst.get())->delBlk(basicblock.get());
+    }
+  }
+
+  return changed;
+}
+
+// ======================================================================
+// Standalone CFG optimization passes.
+// Each runOnFunction below forwards to the matching SysYCFGOptUtils routine
+// and returns its "IR changed" flag; the AnalysisManager argument is not used.
+// NOTE(review): pBuilder is not a parameter here -- presumably a member of the
+// pass classes; confirm in the pass header.
+// ======================================================================
+
+bool SysYDelInstAfterBrPass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYDelInstAfterBr(F);
+}
+
+bool SysYDelEmptyBlockPass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYDelEmptyBlock(F, pBuilder);
+}
+
+bool SysYDelNoPreBLockPass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYDelNoPreBLock(F);
+}
+
+bool SysYBlockMergePass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYBlockMerge(F);
+}
+
+bool SysYAddReturnPass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYAddReturn(F, pBuilder);
+}
+
+bool SysYCondBr2BrPass::runOnFunction(Function *F, AnalysisManager& AM) {
+  return SysYCFGOptUtils::SysYCondBr2Br(F, pBuilder);
+}
+
+} // namespace sysy \ No newline at end of file diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 40b5f59..08332ec 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -15,6 +15,46 @@ using namespace std; namespace sysy { + +Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector& dims){ + Type* currentType = baseType; + // 从最内层维度开始构建 ArrayType + // 例如对于 int arr[2][3],先处理 [3],再处理 [2] + // 注意:SysY 的 dims 是从最外层到最内层,所以我们需要反向迭代 + // 或者调整逻辑,使得从内到外构建 ArrayType + // 假设 dims 列表是 [dim1, dim2, dim3...] 
(例如 [2, 3] for int[2][3]) + // 我们需要从最内层维度开始向外构建 ArrayType + for (int i = dims.size() - 1; i >= 0; --i) { + // 维度大小必须是常量,否则无法构建 ArrayType + ConstantInteger* constDim = dynamic_cast(dims[i]); + if (constDim == nullptr) { + // 如果维度不是常量,可能需要特殊处理,例如将其视为指针 + // 对于函数参数 int arr[] 这种,第一个维度可以为未知 + // 在这里,我们假设所有声明的数组维度都是常量 + assert(false && "Array dimension must be a constant integer!"); + return nullptr; + } + unsigned dimSize = constDim->getInt(); + currentType = Type::getArrayType(currentType, dimSize); + } + return currentType; +}
+
+// @brief: Builds the GEP instruction that computes an element address.
+// @param basePointer: base pointer of the GEP; must be non-null and of pointer
+//                     type. Callers in this file pass an AllocaInst /
+//                     GlobalValue / ConstantVariable directly, or the result of
+//                     loading a pointer stored in an AllocaInst.
+// @param indices: complete index list. The caller decides whether a leading 0
+//                 index is needed and has already added it; nothing is added
+//                 or removed here.
+// @return: the value returned by builder.createGetElementPtrInst.
+// NOTE(review): the element type of `indices` was missing from this copy of the
+// patch; Value* is taken from how callers fill the vector -- confirm against
+// the declaration in the header.
+Value* SysYIRGenerator::getGEPAddressInst(Value* basePointer, const std::vector<Value *>& indices) {
+  // Callers can reach this point with a base pointer that was never assigned;
+  // report that before dereferencing it in the type check below.
+  assert(basePointer != nullptr && "Base pointer must not be null!");
+  assert(basePointer->getType()->isPointer() && "Base pointer must be a pointer type!");
+
+  return builder.createGetElementPtrInst(basePointer, indices);
+}
+
/* * @brief: visit compUnit * @details: @@ -79,7 +119,11 @@ std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *c delete root; } // 创建全局变量,并更新符号表 - module->createGlobalValue(name, Type::getPointerType(type), dims, values); + Type* variableType = type; + if (!dims.empty()) { // 如果有维度,说明是数组 + variableType = buildArrayType(type, dims); // 构建完整的 ArrayType + } + module->createGlobalValue(name, Type::getPointerType(variableType), dims, values); } return std::any(); } @@ -118,24 +162,28 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { } } + Type* variableType = type; + if (!dims.empty()) { // 如果有维度,说明是数组 + variableType = buildArrayType(type, dims); // 构建完整的 ArrayType + } + + // 对于数组,alloca 的类型将是指针指向数组类型,例如 `int[2][3]*` + // 
对于标量,alloca 的类型将是指针指向标量类型,例如 `int*` AllocaInst* alloca = - builder.createAllocaInst(Type::getPointerType(type), dims, name); + builder.createAllocaInst(Type::getPointerType(variableType), {}, name); if (varDef->initVal() != nullptr) { ValueCounter values; - // 这里的varDef->initVal()可能是ScalarInitValue或ArrayInitValue ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; - if (dims.empty()) { - builder.createStoreInst(values.getValue(0), alloca); - } else{ - // **数组变量初始化** - const std::vector &counterValues = values.getValues(); - // 计算数组的**总元素数量**和**总字节大小** + if (dims.empty()) { // 标量变量初始化 + builder.createStoreInst(values.getValue(0), alloca); + } else { // 数组变量初始化 + const std::vector &counterValues = values.getValues(); + const std::vector &counterNumbers = values.getNumbers(); int numElements = 1; - // 存储每个维度的实际整数大小,用于索引计算 std::vector dimSizes; for (Value *dimVal : dims) { if (ConstantInteger *constInt = dynamic_cast(dimVal)) { @@ -145,12 +193,11 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { } // TODO else 错误处理:数组维度必须是常量(对于静态分配) } - unsigned int elementSizeInBytes = type->getSize(); // 获取单个元素的大小(字节) + unsigned int elementSizeInBytes = type->getSize(); unsigned int totalSizeInBytes = numElements * elementSizeInBytes; - // **判断是否可以进行全零初始化优化** bool allValuesAreZero = false; - if (counterValues.empty()) { // 例如 int arr[3] = {}; 或 int arr[3][4] = {}; + if (counterValues.empty()) { allValuesAreZero = true; } else { @@ -163,7 +210,6 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { } } else{ - // 如果值不是常量,我们通常不能确定它是否为零,所以不进行 memset 优化 allValuesAreZero = false; break; } @@ -171,64 +217,67 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { } if (allValuesAreZero) { - // 如果所有初始化值都是零(或没有明确初始化但语法允许),使用 memset 优化 builder.createMemsetInst( - alloca, // 目标数组的起始地址 - ConstantInteger::get(0), // 
偏移量(通常为0),后续删除 + alloca, + ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), - ConstantInteger::get(0)); // 填充的总字节数 + ConstantInteger::get(0)); } else { - // **逐元素存储:遍历所有初始值,并为每个值生成一个 store 指令** - for (size_t k = 0; k < counterValues.size(); ++k) { - // 用于存储当前元素的索引列表 - std::vector currentIndices; - int tempLinearIndex = k; // 临时线性索引,用于计算多维索引 + + int linearIndexOffset = 0; // 用于追踪当前处理的线性索引的偏移量 + for (int k = 0; k < counterValues.size(); ++k) { + // 当前 Value 的值和重复次数 + Value* currentValue = counterValues[k]; + unsigned currentRepeatNum = counterNumbers[k]; - // **将线性索引转换为多维索引** - // 这个循环从最内层维度开始倒推,计算每个维度的索引 - // 假设是行主序(row-major order),这是 C/C++ 数组的标准存储方式 - for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) - { - // 计算当前维度的索引,并插入到列表的最前面 - currentIndices.insert(currentIndices.begin(), - ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); - // 更新线性索引,用于计算下一个更高维度的索引 - tempLinearIndex /= dimSizes[dimIdx]; - } + for (unsigned i = 0; i < currentRepeatNum; ++i) { + std::vector currentIndices; + int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 - // **生成 store 指令,传入值、基指针和计算出的索引列表** - // 你的 builder.createStoreInst 签名需要能够接受这些参数 - // 假设你的 builder.createStoreInst(Value *val, Value *ptr, const std::vector &indices, ...) 
- builder.createStoreInst(counterValues[k], alloca, currentIndices); + // 将线性索引转换为多维索引 + for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { + currentIndices.insert(currentIndices.begin(), + ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); + tempLinearIndex /= dimSizes[dimIdx]; + } + + // 对于局部数组,alloca 本身就是 GEP 的基指针。 + // GEP 的第一个索引必须是 0,用于“步过”整个数组。 + std::vector gepIndicesForInit; + gepIndicesForInit.push_back(ConstantInteger::get(0)); + gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); + + // 计算元素的地址 + Value* elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); + // 生成 store 指令 + builder.createStoreInst(currentValue, elementAddress); + } + // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 + linearIndexOffset += currentRepeatNum; } + } } } - else - { // **如果没有显式初始化值,默认对数组进行零初始化** - if (!dims.empty()) - { // 只有数组才需要默认的零初始化 + else { // 如果没有显式初始化值,默认对数组进行零初始化 + if (!dims.empty()) { // 只有数组才需要默认的零初始化 int numElements = 1; - for (Value *dimVal : dims) - { - if (ConstantInteger *constInt = dynamic_cast(dimVal)) - { + for (Value *dimVal : dims) { + if (ConstantInteger *constInt = dynamic_cast(dimVal)) { numElements *= constInt->getInt(); } } unsigned int elementSizeInBytes = type->getSize(); unsigned int totalSizeInBytes = numElements * elementSizeInBytes; - // 使用 memset 将整个数组清零 builder.createMemsetInst( alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), ConstantInteger::get(0) - ); // 填充的总字节数 + ); } - // 标量变量如果没有初始化值,通常不生成额外的初始化指令,因为其内存已分配但未赋值。 } module->addVariable(name, alloca); @@ -284,58 +333,107 @@ std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ // 更新作用域 module->enterNewScope(); - HasReturnInst = false; auto name = ctx->Ident()->getText(); std::vector paramTypes; + std::vector paramActualTypes; std::vector paramNames; std::vector> paramDims; if (ctx->funcFParams() != nullptr) 
{ auto params = ctx->funcFParams()->funcFParam(); for (const auto ¶m : params) { - paramTypes.push_back(std::any_cast(visitBType(param->bType()))); - paramNames.push_back(param->Ident()->getText()); - std::vector dims = {}; - if (!param->LBRACK().empty()) { - dims.push_back(ConstantInteger::get(-1)); // 第一个维度不确定 + Type* baseBType = std::any_cast(visitBType(param->bType())); + std::string paramName = param->Ident()->getText(); + + // 用于收集当前参数的维度信息(如果它是数组) + std::vector currentParamDims; + if (!param->LBRACK().empty()) { // 如果参数声明中有方括号,说明是数组 + // SysY 数组参数的第一个维度可以是未知的(例如 int arr[] 或 int arr[][10]) + // 这里的 ConstantInteger::get(-1) 表示未知维度,但对于 LLVM 类型构建,我们主要关注已知维度 + currentParamDims.push_back(ConstantInteger::get(-1)); // 标记第一个维度为未知 for (const auto &exp : param->exp()) { - dims.push_back(std::any_cast(visitExp(exp))); + // 访问表达式以获取维度大小,这些维度必须是常量 + Value* dimVal = std::any_cast(visitExp(exp)); + // 确保维度是常量整数,否则 buildArrayType 会断言失败 + assert(dynamic_cast(dimVal) && "Array dimension in parameter must be a constant integer!"); + currentParamDims.push_back(dimVal); } } - paramDims.emplace_back(dims); + + // 根据解析出的信息,确定参数在 LLVM IR 中的实际类型 + Type* actualParamType; + if (currentParamDims.empty()) { // 情况1:标量参数 (e.g., int x) + actualParamType = baseBType; // 实际类型就是基本类型 + } else { // 情况2&3:数组参数 (e.g., int arr[] 或 int arr[][10]) + // 数组参数在函数传递时会退化为指针。 + // 这个指针指向的类型是除第一维外,由后续维度构成的数组类型。 + + // 从 currentParamDims 中移除第一个标记未知维度的 -1 + std::vector fixedDimsForTypeBuilding; + if (currentParamDims.size() > 1) { // 如果有固定维度 (e.g., int arr[][10]) + // 复制除第一个 -1 之外的所有维度 + fixedDimsForTypeBuilding.assign(currentParamDims.begin() + 1, currentParamDims.end()); + } + + Type* pointedToArrayType = baseBType; // 从基本类型开始构建 + // 从最内层维度向外层构建数组类型 + // buildArrayType 期望 dims 是从最外层到最内层,但它内部反向迭代,所以这里直接传入 + // 例如,对于 int arr[][10],fixedDimsForTypeBuilding 包含 [10],构建出 [10 x i32] + if (!fixedDimsForTypeBuilding.empty()) { + pointedToArrayType = buildArrayType(baseBType, fixedDimsForTypeBuilding); + } + + // 
实际参数类型是指向这个构建好的数组类型的指针 + actualParamType = Type::getPointerType(pointedToArrayType); // e.g., i32* 或 [10 x i32]* + } + + paramActualTypes.push_back(actualParamType); // 存储参数的实际 LLVM IR 类型 + paramNames.push_back(paramName); // 存储参数名称 + } } Type* returnType = std::any_cast(visitFuncType(ctx->funcType())); - Type* funcType = Type::getFunctionType(returnType, paramTypes); + Type* funcType = Type::getFunctionType(returnType, paramActualTypes); Function* function = module->createFunction(name, funcType); BasicBlock* entry = function->getEntryBlock(); builder.setPosition(entry, entry->end()); - for (size_t i = 0; i < paramTypes.size(); ++i) { - AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(paramTypes[i]), - paramDims[i], paramNames[i]); + for (int i = 0; i < paramActualTypes.size(); ++i) { + AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(paramActualTypes[i]), {},paramNames[i]); entry->insertArgument(alloca); module->addVariable(paramNames[i], alloca); } + // 在处理函数体之前,创建一个新的基本块作为函数体的实际入口 + // 这样 entryBB 就可以在完成初始化后跳转到这里 + BasicBlock* funcBodyEntry = function->addBasicBlock("funcBodyEntry_" + name); + + // 从 entryBB 无条件跳转到 funcBodyEntry + builder.createUncondBrInst(funcBodyEntry, {}); + builder.setPosition(funcBodyEntry,funcBodyEntry->end()); // 将插入点设置到 funcBodyEntry + for (auto item : ctx->blockStmt()->blockItem()) { visitBlockItem(item); } - if(HasReturnInst == false) { - // 如果没有return语句,则默认返回0 - if (returnType != Type::getVoidType()) { - Value* returnValue = ConstantInteger::get(0); - if (returnType == Type::getFloatType()) { - returnValue = ConstantFloating::get(0.0f); - } - builder.createReturnInst(returnValue); + // 如果函数没有显式的返回语句,且返回类型不是 void,则需要添加一个默认的返回值 + ReturnInst* retinst = nullptr; + retinst = dynamic_cast(builder.getBasicBlock()->terminator()->get()); + + if (!retinst) { + if (returnType->isVoid()) { + builder.createReturnInst(); + } else if (returnType->isInt()) { + builder.createReturnInst(ConstantInteger::get(0)); // 
默认返回 0 + } else if (returnType->isFloat()) { + builder.createReturnInst(ConstantFloating::get(0.0f)); // 默认返回 0.0f } else { - builder.createReturnInst(); + assert(false && "Function with no explicit return and non-void type should return a value."); } } + module->leaveScope(); return std::any(); @@ -352,33 +450,79 @@ std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { auto lVal = ctx->lValue(); std::string name = lVal->Ident()->getText(); - std::vector dims; - for (const auto &exp : lVal->exp()) { - dims.push_back(std::any_cast(visitExp(exp))); + Value* LValue = nullptr; + Value* variable = module->getVariable(name); // 左值 + + vector indices; + if (lVal->exp().size() > 0) { + // 如果有下标,访问表达式获取下标值 + for (const auto &exp : lVal->exp()) { + Value* indexValue = std::any_cast(visitExp(exp)); + indices.push_back(indexValue); + } + } + if (indices.empty()) { + // variable 本身就是指向标量的指针 (e.g., int* %a) + if (dynamic_cast(variable) || dynamic_cast(variable)) { + LValue = variable; + } + } + else { + // 对于数组或多维数组的左值处理 + // 需要获取 GEP 地址 + Value* gepBasePointer = nullptr; + std::vector gepIndices; + if (AllocaInst *alloc = dynamic_cast(variable)) { + Type* allocatedType = alloc->getType()->as()->getBaseType(); + if (allocatedType->isPointer()) { + gepBasePointer = builder.createLoadInst(alloc); + gepIndices = indices; + } else { + gepBasePointer = alloc; + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } + } else if (GlobalValue *glob = dynamic_cast(variable)) { + // 情况 B: 全局变量 (GlobalValue) + gepBasePointer = glob; + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } else if (ConstantVariable *constV = dynamic_cast(variable)) { + gepBasePointer = constV; + gepIndices.push_back(ConstantInteger::get(0)); + 
gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } + // 左值为地址 + LValue = getGEPAddressInst(gepBasePointer, gepIndices); } - auto variable = module->getVariable(name); - Value* value = std::any_cast(visitExp(ctx->exp())); - Type* variableType = dynamic_cast(variable->getType())->getBaseType(); + Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 - // 左值右值类型不同处理 - if (variableType != value->getType()) { - ConstantValue * constValue = dynamic_cast(value); + // 先推断 LValue 的类型 + // 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型 + // 如果 LValue 是标量,则直接使用其类型 + // 注意:LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断 + Type* LType = builder.getIndexedType(variable->getType(), indices); + Type* RType = RValue->getType(); + + if (LType != RType) { + ConstantValue * constValue = dynamic_cast(RValue); if (constValue != nullptr) { - if (variableType == Type::getFloatType()) { - value = ConstantInteger::get(static_cast(constValue->getInt())); - } else { - value = ConstantFloating::get(static_cast(constValue->getFloat())); + if (LType == Type::getFloatType()) { + RValue = ConstantFloating::get(static_cast(constValue->getFloat())); + } else { // 假设如果不是浮点型,就是整型 + RValue = ConstantInteger::get(static_cast(constValue->getInt())); } } else { - if (variableType == Type::getFloatType()) { - value = builder.createIToFInst(value); - } else { - value = builder.createFtoIInst(value); + if (LType == Type::getFloatType()) { + RValue = builder.createIToFInst(RValue); + } else { // 假设如果不是浮点型,就是整型 + RValue = builder.createFtoIInst(RValue); } } } - builder.createStoreInst(value, variable, dims, variable->getName()); + + builder.createStoreInst(RValue, LValue); return std::any(); } @@ -466,6 +610,7 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { ctx->stmt(0)->accept(this); module->leaveScope(); } + builder.createUncondBrInst(exitBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_exit.L" << 
builder.getLabelIndex(); @@ -487,6 +632,7 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { labelstring << "while_head.L" << builder.getLabelIndex(); BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); + builder.createUncondBrInst(headBlock, {}); BasicBlock::conectBlocks(curBlock, headBlock); builder.setPosition(headBlock, headBlock->end()); @@ -571,56 +717,141 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } builder.createReturnInst(returnValue); - HasReturnInst = true; + return std::any(); }
+// Helper: counts how many array dimensions are nested inside `type`.
+// A pointer type is looked through once before counting, for example:
+//   - i32*                 points to i32                 -> 0 dimensions
+//   - [10 x i32]*          points to [10 x i32]          -> 1 dimension
+//   - [20 x [10 x i32]]*   points to [20 x [10 x i32]]   -> 2 dimensions
+// A null `type` yields 0.
+// NOTE(review): the template arguments of as<...>() were missing from this copy
+// of the patch; PointerType / ArrayType are assumed -- confirm against IR.h.
+unsigned SysYIRGenerator::countArrayDimensions(Type* type) {
+  // Guard the first dereference the same way the loop below guards later ones.
+  if (type == nullptr) {
+    return 0;
+  }
+
+  // Step through one level of pointer to reach the pointed-to type.
+  Type* currentType = type;
+  if (currentType->isPointer()) {
+    currentType = currentType->as<PointerType>()->getBaseType();
+  }
+
+  // Peel one array layer per iteration.
+  unsigned dims = 0;
+  while (currentType && currentType->isArray()) {
+    dims++;
+    currentType = currentType->as<ArrayType>()->getElementType();
+  }
+  return dims;
+}
+
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::string name = ctx->Ident()->getText(); - User* variable = module->getVariable(name); + Value* variable = module->getVariable(name); Value* value = nullptr; + std::vector dims; for (const auto &exp : ctx->exp()) { dims.push_back(std::any_cast(visitExp(exp))); } - if (variable == nullptr) { - throw std::runtime_error("Variable " + name + " not found."); - } - - bool indicesConstant = true; - for (const auto &dim : dims) { - if (dynamic_cast(dim) == nullptr) { - indicesConstant = false; - break; - } - }
+  // Keep the undefined-variable check: `variable` is dereferenced right below,
+  // so an unknown name must be reported before that happens.
+  if (variable == nullptr) {
+    throw std::runtime_error("Variable " + name + " not found.");
+  }
+
+ // 1. 获取变量的声明维度数量 + unsigned declaredNumDims = countArrayDimensions(variable->getType()); + // 2. 
处理常量变量 (ConstantVariable) 且所有索引都是常量的情况 ConstantVariable* constVar = dynamic_cast(variable); - GlobalValue* globalVar = dynamic_cast(variable); - AllocaInst* localVar = dynamic_cast(variable); - if (constVar != nullptr && indicesConstant) { - // 如果是常量变量,且索引是常量,则直接获取子数组 - value = constVar->getByIndices(dims); - } else if (module->isInGlobalArea() && (globalVar != nullptr)) { - assert(indicesConstant); - value = globalVar->getByIndices(dims); - } else { - if ((globalVar != nullptr && globalVar->getNumDims() > dims.size()) || - (localVar != nullptr && localVar->getNumDims() > dims.size()) || - (constVar != nullptr && constVar->getNumDims() > dims.size())) { - // value = builder.createLaInst(variable, indices); - // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 - auto getArrayInst = - builder.createGetSubArray(dynamic_cast(variable), dims); - value = getArrayInst->getChildArray(); - } else { - value = builder.createLoadInst(variable, dims); + if (constVar != nullptr) { + bool allIndicesConstant = true; + for (const auto &dim : dims) { + if (dynamic_cast(dim) == nullptr) { + allIndicesConstant = false; + break; + } + } + if (allIndicesConstant) { + // 如果是常量变量且所有索引都是常量,直接通过 getByIndices 获取编译时值 + // 这个方法会根据索引深度返回最终的标量值或指向子数组的指针 (作为 ConstantValue/Variable) + return constVar->getByIndices(dims); } } + // 3. 
处理可变变量 (AllocaInst/GlobalValue) 或带非常量索引的常量变量 + // 这里区分标量访问和数组元素/子数组访问 + + // 检查是否是访问标量变量本身(没有索引,且声明维度为0) + if (dims.empty() && declaredNumDims == 0) { + // 对于标量变量,直接加载其值。 + // variable 本身就是指向标量的指针 (e.g., int* %a) + if (dynamic_cast(variable) || dynamic_cast(variable)) { + value = builder.createLoadInst(variable); + } else { + // 如果走到这里且不是AllocaInst/GlobalValue,但dims为空且declaredNumDims为0, + // 且又不是ConstantVariable (前面已处理),则可能是错误情况。 + assert(false && "Unhandled scalar variable type in LValue access."); + return static_cast(nullptr); + } + } else { + // 访问数组元素或子数组(有索引,或变量本身是数组/多维指针) + Value* gepBasePointer = nullptr; + std::vector gepIndices; // 准备传递给 getGEPAddressInst 的索引列表 + // GEP 的基指针就是变量本身(它是一个指向内存的指针) + if (AllocaInst *alloc = dynamic_cast(variable)) { + // 情况 A: 局部变量 (AllocaInst) + // 获取 AllocaInst 分配的内存的实际类型。 + // 例如:对于 `int b[10][20];`,`allocatedType` 是 `[10 x [20 x i32]]`。 + // 对于 `int b[][20]` 的函数参数,其 AllocaInst 存储的是一个指针, + // 此时 `allocatedType` 是 `[20 x i32]*`。 + Type* allocatedType = alloc->getType()->as()->getBaseType(); + + if (allocatedType->isPointer()) { + // 如果 AllocaInst 分配的是一个指针类型 (例如,用于存储函数参数的指针,如 int b[][20] 中的 b) + // 即 `allocatedType` 是一个指向数组指针的指针 (e.g., [20 x i32]**) + // 那么 GEP 的基指针是加载这个指针变量的值。 + gepBasePointer = builder.createLoadInst(alloc); // 加载出实际的指针值 (e.g., [20 x i32]*) + // 对于这种参数指针,用户提供的索引直接作用于它。不需要额外的 0。 + gepIndices = dims; + } else { + // 如果 AllocaInst 分配的是实际的数组数据 (例如,int b[10][20] 中的 b) + // 那么 AllocaInst 本身就是 GEP 的基指针。 + // 这里的 `alloc` 是指向数组的指针 (e.g., [10 x [20 x i32]]*) + gepBasePointer = alloc; // 类型是 [10 x [20 x i32]]* + // 对于这种完整的数组分配,GEP 的第一个索引必须是 0,用于“步过”整个数组。 + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); + } + } else if (GlobalValue *glob = dynamic_cast(variable)) { + // 情况 B: 全局变量 (GlobalValue) + // GlobalValue 总是指向全局数据的指针。 + gepBasePointer = glob; // 类型是 [61 x [67 x i32]]* + // 对于全局数组,GEP 的第一个索引必须是 0,用于“步过”整个数组。 + 
gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); + } else if (ConstantVariable *constV = dynamic_cast(variable)) { + // 情况 C: 常量变量 (ConstantVariable),如果它代表全局数组常量 + // 假设 ConstantVariable 可以直接作为 GEP 的基指针。 + gepBasePointer = constV; + // 对于常量数组,也需要 0 索引来“步过”整个数组。 + // 这里可以进一步检查 constV->getType()->as()->getBaseType()->isArray() + // 但为了简洁,假设所有 ConstantVariable 作为 GEP 基指针时都需要此 0。 + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), dims.begin(), dims.end()); + } else { + assert(false && "LValue variable type not supported for GEP base pointer."); + return static_cast(nullptr); + } + + // 现在调用 getGEPAddressInst,传入正确准备的基指针和索引列表 + Value *targetAddress = getGEPAddressInst(gepBasePointer, gepIndices); + + // 如果提供的索引数量少于声明的维度数量,则表示访问的是子数组,返回其地址 + if (dims.size() < declaredNumDims) { + value = targetAddress; + } else { + // 否则,表示访问的是最终的标量元素,加载其值 + // 假设 createLoadInst 接受 Value* pointer + value = builder.createLoadInst(targetAddress); + } + } return value; } @@ -662,32 +893,63 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { std::vector args = {}; if (funcName == "starttime" || funcName == "stoptime") { - // 如果是starttime或stoptime函数 - // TODO: 这里需要处理starttime和stoptime函数的参数 - // args.emplace_back() + args.emplace_back( + ConstantInteger::get(static_cast(ctx->getStart()->getLine()))); } else { if (ctx->funcRParams() != nullptr) { args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); } - auto params = function->getEntryBlock()->getArguments(); - for (size_t i = 0; i < args.size(); i++) { - // 参数类型转换 - if (params[i]->getType() != args[i]->getType() && - (params[i]->getNumDims() != 0 || - params[i]->getType()->as()->getBaseType() != args[i]->getType())) { - ConstantValue * constValue = dynamic_cast(args[i]); + // 获取形参列表。`getArguments()` 返回的是 `Argument*` 的集合, + // 每个 `Argument` 代表一个函数形参,其 `getType()` 就是指向形参的类型的指针类型。 + auto formalParamsAlloca = 
function->getEntryBlock()->getArguments(); + + // 检查实参和形参数量是否匹配。 + if (args.size() != formalParamsAlloca.size()) { + std::cerr << "Error: Function call argument count mismatch for function '" << funcName << "'." << std::endl; + assert(false && "Function call argument count mismatch!"); + } + + for (int i = 0; i < args.size(); i++) { + // 形参的类型 (e.g., i32, float, i32*, [10 x i32]*) + Type* formalParamExpectedValueType = formalParamsAlloca[i]->getType()->as()->getBaseType(); + // 实参的实际类型 (e.g., i32, float, i32*, [67 x i32]*) + Type* actualArgType = args[i]->getType(); + // 如果实参类型与形参类型不匹配,则尝试进行类型转换 + if (formalParamExpectedValueType != actualArgType) { + ConstantValue *constValue = dynamic_cast(args[i]); if (constValue != nullptr) { - if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { - args[i] = ConstantInteger::get(static_cast(constValue->getInt())); + if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { + args[i] = ConstantInteger::get(static_cast(constValue->getFloat())); + } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { + args[i] = ConstantFloating::get(static_cast(constValue->getInt())); } else { - args[i] = ConstantFloating::get(static_cast(constValue->getFloat())); + // 如果是常量但不是简单的 int/float 标量转换, + // 或者是指针常量需要 bitcast,则让它进入非常量转换逻辑。 + // 例如,一个常量数组的地址,需要 bitcast 成另一种指针类型。 + // 目前不知道样例有没有这种情况,所以这里不做处理。 } - } else { - if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { - args[i] = builder.createIToFInst(args[i]); - } else { + } + else { + // 1. 标量值类型转换 (例如:int_reg 到 float_reg,float_reg 到 int_reg) + if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { args[i] = builder.createFtoIInst(args[i]); + } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { + args[i] = builder.createIToFInst(args[i]); + } + // 2. 
指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO:不清楚有没有这种样例 + // 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型, + // 而形参是其退化后的基础指针类型。LLVM 的 `bitcast` 指令可以用于 + // 在相同大小的指针类型之间进行转换,这对于数组退化至关重要。 + // else if (formalParamType->isPointer() && actualArgType->isPointer()) { + // 检查指针基类型是否兼容,或者是否是数组退化导致的类型不同。 + // 使用 bitcast, + // args[i] = builder.createBitCastInst(args[i], formalParamType); + // } + // 3. 其他未预期的类型不匹配 + // 如果代码执行到这里,说明存在编译器前端未处理的类型不兼容或错误。 + else { + // assert(false && "Unhandled type mismatch for function call argument."); } } } @@ -757,7 +1019,7 @@ std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); - for (size_t i = 1; i < ctx->unaryExp().size(); i++) { + for (int i = 1; i < ctx->unaryExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); @@ -833,7 +1095,7 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); - for (size_t i = 1; i < ctx->mulExp().size(); i++) { + for (int i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); @@ -894,7 +1156,7 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); - for (size_t i = 1; i < ctx->addExp().size(); i++) { + for (int i = 1; i < ctx->addExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); @@ -966,7 +1228,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { Value * result 
= std::any_cast(visitRelExp(ctx->relExp(0))); - for (size_t i = 1; i < ctx->relExp().size(); i++) { + for (int i = 1; i < ctx->relExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); @@ -1040,7 +1302,7 @@ std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){ BasicBlock *falseBlock = builder.getFalseBlock(); auto conds = ctx->eqExp(); - for (size_t i = 0; i < conds.size() - 1; i++) { + for (int i = 0; i < conds.size() - 1; i++) { labelstring << "AND.L" << builder.getLabelIndex(); BasicBlock *newtrueBlock = function->addBasicBlock(labelstring.str()); @@ -1071,7 +1333,7 @@ auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { Function *function = curBlock->getParent(); auto conds = ctx->lAndExp(); - for (size_t i = 0; i < conds.size() - 1; i++) { + for (int i = 0; i < conds.size() - 1; i++) { labelstring << "OR.L" << builder.getLabelIndex(); BasicBlock *newFalseBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); @@ -1088,6 +1350,7 @@ auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { return std::any(); } +// attention : 这里的type是数组元素的type void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { @@ -1158,9 +1421,9 @@ void Utils::createExternalFunction( auto entry = function->getEntryBlock(); pBuilder->setPosition(entry, entry->end()); - for (size_t i = 0; i < paramTypes.size(); ++i) { + for (int i = 0; i < paramTypes.size(); ++i) { auto alloca = pBuilder->createAllocaInst( - Type::getPointerType(paramTypes[i]), paramDims[i], paramNames[i]); + Type::getPointerType(paramTypes[i]), {}, paramNames[i]); entry->insertArgument(alloca); // pModule->addVariable(paramNames[i], alloca); } diff --git a/src/SysYIROptPre.cpp b/src/SysYIROptPre.cpp deleted file mode 100644 index 771c474..0000000 --- a/src/SysYIROptPre.cpp +++ /dev/null @@ 
-1,484 +0,0 @@ -#include "SysYIROptPre.h" -#include -#include -#include -#include -#include -#include -#include "IR.h" -#include "IRBuilder.h" - -namespace sysy { - -/** - * use删除operand,以免扰乱后续分析 - * instr: 要删除的指令 - */ -void SysYOptPre::usedelete(Instruction *instr) { - for (auto &use : instr->getOperands()) { - Value* val = use->getValue(); - // std::cout << delete << val->getName() << std::endl; - val->removeUse(use); - } -} - - -// 删除br后的无用指令 -void SysYOptPre::SysYDelInstAfterBr() { - auto &functions = pModule->getFunctions(); - for (auto &function : functions) { - auto basicBlocks = function.second->getBasicBlocks(); - for (auto &basicBlock : basicBlocks) { - bool Branch = false; - auto &instructions = basicBlock->getInstructions(); - auto Branchiter = instructions.end(); - for (auto iter = instructions.begin(); iter != instructions.end(); ++iter) { - if (Branch) - usedelete(iter->get()); - else if ((*iter)->isTerminator()){ - Branch = true; - Branchiter = iter; - } - } - if (Branchiter != instructions.end()) ++Branchiter; - while (Branchiter != instructions.end()) - Branchiter = instructions.erase(Branchiter); - - if (Branch) { // 更新前驱后继关系 - auto thelastinstinst = basicBlock->getInstructions().end(); - --thelastinstinst; - auto &Successors = basicBlock->getSuccessors(); - for (auto iterSucc = Successors.begin(); iterSucc != Successors.end();) { - (*iterSucc)->removePredecessor(basicBlock.get()); - basicBlock->removeSuccessor(*iterSucc); - } - if (thelastinstinst->get()->isUnconditional()) { - BasicBlock* branchBlock = dynamic_cast(thelastinstinst->get()->getOperand(0)); - basicBlock->addSuccessor(branchBlock); - branchBlock->addPredecessor(basicBlock.get()); - } else if (thelastinstinst->get()->isConditional()) { - BasicBlock* thenBlock = dynamic_cast(thelastinstinst->get()->getOperand(1)); - BasicBlock* elseBlock = dynamic_cast(thelastinstinst->get()->getOperand(2)); - basicBlock->addSuccessor(thenBlock); - basicBlock->addSuccessor(elseBlock); - 
thenBlock->addPredecessor(basicBlock.get()); - elseBlock->addPredecessor(basicBlock.get()); - } - } - } - } -} - - -void SysYOptPre::SysYBlockMerge() { - auto &functions = pModule->getFunctions(); //std::map> - for (auto &function : functions) { - // auto basicBlocks = function.second->getBasicBlocks(); - auto &func = function.second; - for (auto blockiter = func->getBasicBlocks().begin(); - blockiter != func->getBasicBlocks().end();) { - if (blockiter->get()->getNumSuccessors() == 1) { - // 如果当前块只有一个后继块 - // 且后继块只有一个前驱块 - // 则将当前块和后继块合并 - if (((blockiter->get())->getSuccessors()[0])->getNumPredecessors() == 1) { - // std::cout << "merge block: " << blockiter->get()->getName() << std::endl; - BasicBlock* block = blockiter->get(); - BasicBlock* nextBlock = blockiter->get()->getSuccessors()[0]; - auto nextarguments = nextBlock->getArguments(); - // 删除br指令 - if (block->getNumInstructions() != 0) { - auto thelastinstinst = block->end(); - (--thelastinstinst); - if (thelastinstinst->get()->isUnconditional()) { - usedelete(thelastinstinst->get()); - block->getInstructions().erase(thelastinstinst); - } else if (thelastinstinst->get()->isConditional()) { - // 如果是条件分支,判断条件是否相同,主要优化相同布尔表达式 - if (thelastinstinst->get()->getOperand(1)->getName() == thelastinstinst->get()->getOperand(1)->getName()) { - usedelete(thelastinstinst->get()); - block->getInstructions().erase(thelastinstinst); - } - } - } - // 将后继块的指令移动到当前块 - // 并将后继块的父指针改为当前块 - for (auto institer = nextBlock->begin(); institer != nextBlock->end();) { - institer->get()->setParent(block); - block->getInstructions().emplace_back(institer->release()); - institer = nextBlock->getInstructions().erase(institer); - } - // 合并参数 - // TODO:是否需要去重? 
- for (auto &argm : nextarguments) { - argm->setParent(block); - block->insertArgument(argm); - } - // 更新前驱后继关系,类似树节点操作 - block->removeSuccessor(nextBlock); - nextBlock->removePredecessor(block); - std::list succshoulddel; - for (auto &succ : nextBlock->getSuccessors()) { - block->addSuccessor(succ); - succ->replacePredecessor(nextBlock, block); - succshoulddel.push_back(succ); - } - for (auto del : succshoulddel) { - nextBlock->removeSuccessor(del); - } - - func->removeBasicBlock(nextBlock); - - } else { - blockiter++; - } - } else { - blockiter++; - } - } - } -} - -// 删除无前驱块,兼容SSA后的处理 -void SysYOptPre::SysYDelNoPreBLock() { - - auto &functions = pModule->getFunctions(); // std::map> - for (auto &function : functions) { - auto &func = function.second; - - for (auto &block : func->getBasicBlocks()) { - block->setreachableFalse(); - } - // 对函数基本块做一个拓扑排序,排查不可达基本块 - auto entryBlock = func->getEntryBlock(); - entryBlock->setreachableTrue(); - std::queue blockqueue; - blockqueue.push(entryBlock); - while (!blockqueue.empty()) { - auto block = blockqueue.front(); - blockqueue.pop(); - for (auto &succ : block->getSuccessors()) { - if (!succ->getreachable()) { - succ->setreachableTrue(); - blockqueue.push(succ); - } - } - } - - // 删除不可达基本块指令 - for (auto blockIter = func->getBasicBlocks().begin(); blockIter != func->getBasicBlocks().end();blockIter++) { - - if (!blockIter->get()->getreachable()) - for (auto &iterInst : blockIter->get()->getInstructions()) - usedelete(iterInst.get()); - - } - - - for (auto blockIter = func->getBasicBlocks().begin(); blockIter != func->getBasicBlocks().end();) { - if (!blockIter->get()->getreachable()) { - for (auto succblock : blockIter->get()->getSuccessors()) { - int indexphi = 1; - for (auto pred : succblock->getPredecessors()) { - if (pred == blockIter->get()) { - break; - } - indexphi++; - } - for (auto &phiinst : succblock->getInstructions()) { - if (phiinst->getKind() != Instruction::kPhi) { - break; - } - 
phiinst->removeOperand(indexphi); - } - } - // 删除不可达基本块,注意迭代器不可达问题 - func->removeBasicBlock((blockIter++)->get()); - } else { - blockIter++; - } - } - } -} - -void SysYOptPre::SysYDelEmptyBlock() { - auto &functions = pModule->getFunctions(); - for (auto &function : functions) { - // 收集不可达基本块 - // 这里的不可达基本块是指没有实际指令的基本块 - // 当一个基本块没有实际指令例如只有phi指令和一个uncondbr指令时,也会被视作不可达 - auto basicBlocks = function.second->getBasicBlocks(); - std::map EmptyBlocks; - // 空块儿和后继的基本块的映射 - for (auto &basicBlock : basicBlocks) { - if (basicBlock->getNumInstructions() == 0) { - if (basicBlock->getNumSuccessors() == 1) { - EmptyBlocks[basicBlock.get()] = basicBlock->getSuccessors().front(); - } - } - else{ - // 如果只有phi指令和一个uncondbr。(phi)*(uncondbr)? - // 判断除了最后一个指令之外是不是只有phi指令 - bool onlyPhi = true; - for (auto &inst : basicBlock->getInstructions()) { - if (!inst->isPhi() && !inst->isUnconditional()) { - onlyPhi = false; - break; - } - } - if(onlyPhi) - EmptyBlocks[basicBlock.get()] = basicBlock->getSuccessors().front(); - } - - - } - // 更新基本块信息,增加必要指令 - for (auto &basicBlock : basicBlocks) { - // 把空块转换成只有跳转指令的不可达块 - if (distance(basicBlock->begin(), basicBlock->end()) == 0) { - if (basicBlock->getNumSuccessors() == 0) { - continue; - } - if (basicBlock->getNumSuccessors() > 1) { - assert(""); - } - pBuilder->setPosition(basicBlock.get(), basicBlock->end()); - pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {}); - continue; - } - - auto thelastinst = basicBlock->getInstructions().end(); - --thelastinst; - - // 根据br指令传递的后继块信息,跳过空块链 - if (thelastinst->get()->isUnconditional()) { - BasicBlock* OldBrBlock = dynamic_cast(thelastinst->get()->getOperand(0)); - BasicBlock *thelastBlockOld = nullptr; - // 如果空块链表为多个块 - while (EmptyBlocks.find(dynamic_cast(thelastinst->get()->getOperand(0))) != - EmptyBlocks.end()) { - thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(0)); - thelastinst->get()->replaceOperand(0, EmptyBlocks[thelastBlockOld]); - } - - 
basicBlock->removeSuccessor(OldBrBlock); - OldBrBlock->removePredecessor(basicBlock.get()); - basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(0))); - dynamic_cast(thelastinst->get()->getOperand(0))->addPredecessor(basicBlock.get()); - - if (thelastBlockOld != nullptr) { - int indexphi = 0; - for (auto &pred : dynamic_cast(thelastinst->get()->getOperand(0))->getPredecessors()) { - if (pred == thelastBlockOld) { - break; - } - indexphi++; - } - - // 更新phi指令的操作数 - // 移除thelastBlockOld对应的phi操作数 - for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(0))->getInstructions()) { - if (InstInNew->isPhi()) { - dynamic_cast(InstInNew.get())->removeOperand(indexphi + 1); - } else { - break; - } - } - } - - } else if (thelastinst->get()->getKind() == Instruction::kCondBr) { - auto OldThenBlock = dynamic_cast(thelastinst->get()->getOperand(1)); - auto OldElseBlock = dynamic_cast(thelastinst->get()->getOperand(2)); - - BasicBlock *thelastBlockOld = nullptr; - while (EmptyBlocks.find(dynamic_cast(thelastinst->get()->getOperand(1))) != - EmptyBlocks.end()) { - thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(1)); - thelastinst->get()->replaceOperand( - 1, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(1))]); - } - basicBlock->removeSuccessor(OldThenBlock); - OldThenBlock->removePredecessor(basicBlock.get()); - // 处理 then 和 else 分支合并的情况 - if (dynamic_cast(thelastinst->get()->getOperand(1)) == - dynamic_cast(thelastinst->get()->getOperand(2))) { - auto thebrBlock = dynamic_cast(thelastinst->get()->getOperand(1)); - usedelete(thelastinst->get()); - thelastinst = basicBlock->getInstructions().erase(thelastinst); - pBuilder->setPosition(basicBlock.get(), basicBlock->end()); - pBuilder->createUncondBrInst(thebrBlock, {}); - continue; - } - basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(1))); - dynamic_cast(thelastinst->get()->getOperand(1))->addPredecessor(basicBlock.get()); - // auto indexInNew = 
dynamic_cast(thelastinst->get()->getOperand(0))->getPredecessors(). - - if (thelastBlockOld != nullptr) { - int indexphi = 0; - for (auto &pred : dynamic_cast(thelastinst->get()->getOperand(1))->getPredecessors()) { - if (pred == thelastBlockOld) { - break; - } - indexphi++; - } - - for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(1))->getInstructions()) { - if (InstInNew->isPhi()) { - dynamic_cast(InstInNew.get())->removeOperand(indexphi + 1); - } else { - break; - } - } - } - - thelastBlockOld = nullptr; - while (EmptyBlocks.find(dynamic_cast(thelastinst->get()->getOperand(2))) != - EmptyBlocks.end()) { - thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(2)); - thelastinst->get()->replaceOperand( - 2, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(2))]); - } - basicBlock->removeSuccessor(OldElseBlock); - OldElseBlock->removePredecessor(basicBlock.get()); - // 处理 then 和 else 分支合并的情况 - if (dynamic_cast(thelastinst->get()->getOperand(1)) == - dynamic_cast(thelastinst->get()->getOperand(2))) { - auto thebrBlock = dynamic_cast(thelastinst->get()->getOperand(1)); - usedelete(thelastinst->get()); - thelastinst = basicBlock->getInstructions().erase(thelastinst); - pBuilder->setPosition(basicBlock.get(), basicBlock->end()); - pBuilder->createUncondBrInst(thebrBlock, {}); - continue; - } - basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(2))); - dynamic_cast(thelastinst->get()->getOperand(2))->addPredecessor(basicBlock.get()); - - if (thelastBlockOld != nullptr) { - int indexphi = 0; - for (auto &pred : dynamic_cast(thelastinst->get()->getOperand(2))->getPredecessors()) { - if (pred == thelastBlockOld) { - break; - } - indexphi++; - } - for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(2))->getInstructions()) { - if (InstInNew->isPhi()) { - dynamic_cast(InstInNew.get())->removeOperand(indexphi + 1); - } else { - break; - } - } - } - } else { - if (basicBlock->getNumSuccessors() == 1) { - 
pBuilder->setPosition(basicBlock.get(), basicBlock->end()); - pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {}); - auto thelastinst = basicBlock->getInstructions().end(); - (--thelastinst); - auto OldBrBlock = dynamic_cast(thelastinst->get()->getOperand(0)); - sysy::BasicBlock *thelastBlockOld = nullptr; - while (EmptyBlocks.find(dynamic_cast(thelastinst->get()->getOperand(0))) != - EmptyBlocks.end()) { - thelastBlockOld = dynamic_cast(thelastinst->get()->getOperand(0)); - - thelastinst->get()->replaceOperand( - 0, EmptyBlocks[dynamic_cast(thelastinst->get()->getOperand(0))]); - } - - basicBlock->removeSuccessor(OldBrBlock); - OldBrBlock->removePredecessor(basicBlock.get()); - basicBlock->addSuccessor(dynamic_cast(thelastinst->get()->getOperand(0))); - dynamic_cast(thelastinst->get()->getOperand(0))->addPredecessor(basicBlock.get()); - if (thelastBlockOld != nullptr) { - int indexphi = 0; - for (auto &pred : dynamic_cast(thelastinst->get()->getOperand(0))->getPredecessors()) { - if (pred == thelastBlockOld) { - break; - } - indexphi++; - } - - for (auto &InstInNew : dynamic_cast(thelastinst->get()->getOperand(0))->getInstructions()) { - if (InstInNew->isPhi()) { - dynamic_cast(InstInNew.get())->removeOperand(indexphi + 1); - } else { - break; - } - } - } - } - } - } - - for (auto iter = function.second->getBasicBlocks().begin(); iter != function.second->getBasicBlocks().end();) { - - if (EmptyBlocks.find(iter->get()) != EmptyBlocks.end()) { - // EntryBlock跳过 - if (iter->get() == function.second->getEntryBlock()) { - ++iter; - continue; - } - - for (auto &iterInst : iter->get()->getInstructions()) - usedelete(iterInst.get()); - // 删除不可达基本块的phi指令的操作数 - for (auto &succ : iter->get()->getSuccessors()) { - int index = 0; - for (auto &pred : succ->getPredecessors()) { - if (pred == iter->get()) { - break; - } - index++; - } - - for (auto &instinsucc : succ->getInstructions()) { - if (instinsucc->isPhi()) { - 
dynamic_cast(instinsucc.get())->removeOperand(index); - } else { - break; - } - } - } - - function.second->removeBasicBlock((iter++)->get()); - } else { - ++iter; - } - } - } -} - -// 如果函数没有返回指令,则添加一个默认返回指令(主要解决void函数没有返回指令的问题) -void SysYOptPre::SysYAddReturn() { - auto &functions = pModule->getFunctions(); - for (auto &function : functions) { - auto &func = function.second; - auto basicBlocks = func->getBasicBlocks(); - for (auto &block : basicBlocks) { - if (block->getNumSuccessors() == 0) { - // 如果基本块没有后继块,则添加一个返回指令 - if (block->getNumInstructions() == 0) { - pBuilder->setPosition(block.get(), block->end()); - pBuilder->createReturnInst(); - } - auto thelastinst = block->getInstructions().end(); - --thelastinst; - if (thelastinst->get()->getKind() != Instruction::kReturn) { - // std::cout << "Warning: Function " << func->getName() << " has no return instruction, adding default return." << std::endl; - - pBuilder->setPosition(block.get(), block->end()); - // TODO: 如果int float函数缺少返回值是否需要报错 - if (func->getReturnType()->isInt()) { - pBuilder->createReturnInst(ConstantInteger::get(0)); - } else if (func->getReturnType()->isFloat()) { - pBuilder->createReturnInst(ConstantFloating::get(0.0F)); - } else { - pBuilder->createReturnInst(); - } - } - } - } - } -} - -} // namespace sysy diff --git a/src/SysYIRPassManager.cpp b/src/SysYIRPassManager.cpp new file mode 100644 index 0000000..f66f74a --- /dev/null +++ b/src/SysYIRPassManager.cpp @@ -0,0 +1,36 @@ +// PassManager.cpp +#include "SysYIRPassManager.h" +#include + +namespace sysy { + +void PassManager::run(Module& M) { + // 首先运行Module级别的Pass + for (auto& pass : modulePasses) { + std::cout << "Running Module Pass: " << pass->getPassName() << std::endl; + pass->runOnModule(M); + } + + // 然后对每个函数运行Function级别的Pass + auto& functions = M.getFunctions(); + for (auto& pair : functions) { + Function& F = *(pair.second); // 获取Function的引用 + std::cout << " Processing Function: " << F.getName() << std::endl; + + // 
在每个函数上运行FunctionPasses + bool changedInFunction; + do { + changedInFunction = false; + for (auto& pass : functionPasses) { + // 对于FunctionPasses,可以考虑一个迭代执行的循环,直到稳定 + std::cout << " Running Function Pass: " << pass->getPassName() << std::endl; + changedInFunction |= pass->runOnFunction(F); + } + } while (changedInFunction); // 循环直到函数稳定,这模拟了您SysYCFGOpt的while(changed)逻辑 + } + + // 分析Pass的运行可以在其他Pass需要时触发,或者在特定的PassManager阶段触发 + // 对于依赖于分析结果的Pass,可以在其run方法中通过PassManager::getAnalysis()来获取 +} + +} // namespace sysy \ No newline at end of file diff --git a/src/SysYIRPrinter.cpp b/src/SysYIRPrinter.cpp index 7d92a0f..689fd50 100644 --- a/src/SysYIRPrinter.cpp +++ b/src/SysYIRPrinter.cpp @@ -3,12 +3,11 @@ #include #include #include -#include "IR.h" +#include "IR.h" // 确保IR.h包含了ArrayType、GetElementPtrInst等的定义 namespace sysy { void SysYPrinter::printIR() { - const auto &functions = pModule->getFunctions(); //TODO: Print target datalayout and triple (minimal required by LLVM) @@ -36,11 +35,18 @@ std::string SysYPrinter::getTypeString(Type *type) { return "i32"; } else if (type->isFloat()) { return "float"; - } else if (auto ptrType = dynamic_cast(type)) { + // 递归打印指针指向的类型,然后加上 '*' return getTypeString(ptrType->getBaseType()) + "*"; - } else if (auto ptrType = dynamic_cast(type)) { - return getTypeString(ptrType->getReturnType()); + } else if (auto funcType = dynamic_cast(type)) { + // 对于函数类型,打印其返回类型 + // 注意:这里可能需要更完整的函数签名打印,取决于你的IR表示方式 + // 比如:`retType (paramType1, paramType2, ...)` + // 但为了简化和LLVM IR兼容性,通常在定义时完整打印 + return getTypeString(funcType->getReturnType()); + } else if (auto arrayType = dynamic_cast(type)) { // 新增:处理数组类型 + // 打印格式为 [num_elements x element_type] + return "[" + std::to_string(arrayType->getNumElements()) + " x " + getTypeString(arrayType->getElementType()) + "]"; } assert(false && "Unsupported type"); return ""; @@ -51,15 +57,23 @@ std::string SysYPrinter::getValueName(Value *value) { return "@" + global->getName(); } else if (auto inst = 
dynamic_cast(value)) { return "%" + inst->getName(); - } else if (auto constVal = dynamic_cast(value)) { - if (constVal->isFloat()) { - return std::to_string(constVal->getFloat()); + } else if (auto constInt = dynamic_cast(value)) { // 优先匹配具体的常量类型 + return std::to_string(constInt->getInt()); + } else if (auto constFloat = dynamic_cast(value)) { // 优先匹配具体的常量类型 + return std::to_string(constFloat->getFloat()); + } else if (auto constUndef = dynamic_cast(value)) { // 如果有Undef类型 + return "undef"; + } else if (auto constVal = dynamic_cast(value)) { // fallback for generic ConstantValue + // 这里的逻辑可能需要根据你ConstantValue的实际设计调整 + // 确保它能处理所有可能的ConstantValue + if (constVal->getType()->isFloat()) { + return std::to_string(constVal->getFloat()); } return std::to_string(constVal->getInt()); } else if (auto constVar = dynamic_cast(value)) { - return constVar->getName(); + return constVar->getName(); // 假设ConstantVariable有自己的名字或通过getByIndices获取值 } - assert(false && "Unknown value type"); + assert(false && "Unknown value type or unable to get value name"); return ""; } @@ -77,44 +91,35 @@ void SysYPrinter::printGlobalVariable() { for (const auto &global : globals) { std::cout << "@" << global->getName() << " = global "; - auto baseType = dynamic_cast(global->getType())->getBaseType(); - printType(baseType); - - if (global->getNumDims() > 0) { - // Array type - std::cout << " ["; - for (unsigned i = 0; i < global->getNumDims(); i++) { - if (i > 0) std::cout << " x "; - std::cout << getValueName(global->getDim(i)); - } - std::cout << "]"; - } + // 全局变量的类型是一个指针,指向其基类型 (可能是 ArrayType 或 Integer/FloatType) + auto globalVarBaseType = dynamic_cast(global->getType())->getBaseType(); + printType(globalVarBaseType); // 打印全局变量的实际类型 (例如 i32 或 [10 x i32]) std::cout << " "; - if (global->getNumDims() > 0) { - // Array initializer - std::cout << "["; - auto values = global->getInitValues(); - auto counterValues = values.getValues(); - auto counterNumbers = values.getNumbers(); + // 检查是否是数组类型 (通过检查 
globalVarBaseType 是否是 ArrayType) + if (globalVarBaseType->isArray()) { + // 数组初始化器 + std::cout << "["; // LLVM IR 数组初始化器格式: [type value, type value, ...] + auto values = global->getInitValues(); // 假设 getInitValues() 返回一个 ValueCounter + const std::vector &counterValues = values.getValues(); // 获取所有值 - for (size_t i = 0; i < counterNumbers.size(); i++) { + for (size_t i = 0; i < counterValues.size(); i++) { if (i > 0) std::cout << ", "; - if (baseType->isFloat()) { - std::cout << "float " << dynamic_cast(counterValues[i])->getFloat(); - } else { - std::cout << "i32 " << dynamic_cast(counterValues[i])->getInt(); - } + // 打印元素类型,这个元素类型应该是数组的最终元素类型,例如 i32 或 float + // 可以从 globalVarBaseType 逐层剥离得到最终元素类型,但这里简化为直接从值获取 + printType(counterValues[i]->getType()); + std::cout << " "; + printValue(counterValues[i]); } std::cout << "]"; } else { - // Scalar initializer - if (baseType->isFloat()) { - std::cout << "float " << dynamic_cast(global->getByIndex(0))->getFloat(); - } else { - std::cout << "i32 " << dynamic_cast(global->getByIndex(0))->getInt(); - } + // 标量初始化器 + // 假设标量全局变量的初始化值通过 getByIndex(0) 获取 + Value* initVal = global->getByIndex(0); + printType(initVal->getType()); // 打印标量值的类型 + std::cout << " "; + printValue(initVal); // 打印标量值 } std::cout << ", align 4" << std::endl; @@ -145,9 +150,7 @@ void SysYPrinter::printFunction(Function *function) { for (const auto &blockIter : function->getBasicBlocks()) { // Basic block label BasicBlock* blockPtr = blockIter.get(); - if (blockPtr == function->getEntryBlock()) { - std::cout << "entry:" << std::endl; - } else if (!blockPtr->getName().empty()) { + if (!blockPtr->getName().empty()) { std::cout << blockPtr->getName() << ":" << std::endl; } @@ -209,19 +212,19 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kFDiv: std::cout << "fdiv"; break; case Kind::kICmpEQ: std::cout << "icmp eq"; break; case Kind::kICmpNE: std::cout << "icmp ne"; break; - case Kind::kICmpLT: std::cout << "icmp slt"; break; + case 
Kind::kICmpLT: std::cout << "icmp slt"; break; // LLVM uses slt/sgt for signed less/greater than case Kind::kICmpGT: std::cout << "icmp sgt"; break; case Kind::kICmpLE: std::cout << "icmp sle"; break; case Kind::kICmpGE: std::cout << "icmp sge"; break; - case Kind::kFCmpEQ: std::cout << "fcmp oeq"; break; - case Kind::kFCmpNE: std::cout << "fcmp one"; break; - case Kind::kFCmpLT: std::cout << "fcmp olt"; break; - case Kind::kFCmpGT: std::cout << "fcmp ogt"; break; - case Kind::kFCmpLE: std::cout << "fcmp ole"; break; - case Kind::kFCmpGE: std::cout << "fcmp oge"; break; + case Kind::kFCmpEQ: std::cout << "fcmp oeq"; break; // oeq for ordered equal + case Kind::kFCmpNE: std::cout << "fcmp one"; break; // one for ordered not equal + case Kind::kFCmpLT: std::cout << "fcmp olt"; break; // olt for ordered less than + case Kind::kFCmpGT: std::cout << "fcmp ogt"; break; // ogt for ordered greater than + case Kind::kFCmpLE: std::cout << "fcmp ole"; break; // ole for ordered less than or equal + case Kind::kFCmpGE: std::cout << "fcmp oge"; break; // oge for ordered greater than or equal case Kind::kAnd: std::cout << "and"; break; case Kind::kOr: std::cout << "or"; break; - default: break; + default: break; // Should not reach here } // Types and operands @@ -238,7 +241,6 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kNeg: case Kind::kNot: case Kind::kFNeg: - case Kind::kFNot: case Kind::kFtoI: case Kind::kBitFtoI: case Kind::kItoF: @@ -250,31 +252,39 @@ void SysYPrinter::printInst(Instruction *pInst) { } switch (pInst->getKind()) { - case Kind::kNeg: std::cout << "sub "; break; - case Kind::kNot: std::cout << "not "; break; - case Kind::kFNeg: std::cout << "fneg "; break; - case Kind::kFNot: std::cout << "fneg "; break; // FNot not standard, map to fneg - case Kind::kFtoI: std::cout << "fptosi "; break; - case Kind::kBitFtoI: std::cout << "bitcast "; break; - case Kind::kItoF: std::cout << "sitofp "; break; - case Kind::kBitItoF: std::cout << "bitcast "; 
break; - default: break; + case Kind::kNeg: std::cout << "sub "; break; // integer negation is `sub i32 0, operand` + case Kind::kNot: std::cout << "xor "; break; // logical/bitwise NOT is `xor i32 -1, operand` or `xor i1 true, operand` + case Kind::kFNeg: std::cout << "fneg "; break; // float negation + case Kind::kFtoI: std::cout << "fptosi "; break; // float to signed integer + case Kind::kBitFtoI: std::cout << "bitcast "; break; // bitcast float to int + case Kind::kItoF: std::cout << "sitofp "; break; // signed integer to float + case Kind::kBitItoF: std::cout << "bitcast "; break; // bitcast int to float + default: break; // Should not reach here } - printType(unyInst->getType()); + printType(unyInst->getOperand()->getType()); // Print operand type std::cout << " "; - // Special handling for negation - if (pInst->getKind() == Kind::kNeg || pInst->getKind() == Kind::kNot) { - std::cout << "i32 0, "; + // Special handling for integer negation and logical NOT + if (pInst->getKind() == Kind::kNeg) { + std::cout << "0, "; // for 'sub i32 0, operand' + } else if (pInst->getKind() == Kind::kNot) { + // For logical NOT (i1 -> i1), use 'xor i1 true, operand' + // For bitwise NOT (i32 -> i32), use 'xor i32 -1, operand' + if (unyInst->getOperand()->getType()->isInt()) { // Assuming i32 for bitwise NOT + std::cout << "NOT, "; // or specific bitmask for NOT + } else { // Assuming i1 for logical NOT + std::cout << "true, "; + } } printValue(pInst->getOperand(0)); - // For bitcast, need to specify destination type - if (pInst->getKind() == Kind::kBitFtoI || pInst->getKind() == Kind::kBitItoF) { + // For type conversions (fptosi, sitofp, bitcast), need to specify destination type + if (pInst->getKind() == Kind::kFtoI || pInst->getKind() == Kind::kItoF || + pInst->getKind() == Kind::kBitFtoI || pInst->getKind() == Kind::kBitItoF) { std::cout << " to "; - printType(unyInst->getType()); + printType(unyInst->getType()); // Print result type } std::cout << std::endl; @@ -289,7 
+299,7 @@ void SysYPrinter::printInst(Instruction *pInst) { } std::cout << "call "; - printType(callInst->getType()); + printType(callInst->getType()); // Return type of the call std::cout << " @" << function->getName() << "("; auto params = callInst->getArguments(); @@ -297,9 +307,9 @@ void SysYPrinter::printInst(Instruction *pInst) { for (auto ¶m : params) { if (!first) std::cout << ", "; first = false; - printType(param->getValue()->getType()); + printType(param->getValue()->getType()); // Type of argument std::cout << " "; - printValue(param->getValue()); + printValue(param->getValue()); // Value of argument } std::cout << ")" << std::endl; @@ -307,7 +317,7 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kCondBr: { auto condBrInst = dynamic_cast(pInst); - std::cout << "br i1 "; + std::cout << "br i1 "; // Condition type should be i1 printValue(condBrInst->getCondition()); std::cout << ", label %" << condBrInst->getThenBlock()->getName(); std::cout << ", label %" << condBrInst->getElseBlock()->getName(); @@ -337,14 +347,17 @@ void SysYPrinter::printInst(Instruction *pInst) { auto allocaInst = dynamic_cast(pInst); std::cout << "%" << allocaInst->getName() << " = alloca "; - auto baseType = dynamic_cast(allocaInst->getType())->getBaseType(); - printType(baseType); + // AllocaInst 的类型现在应该是一个 PointerType,指向正确的 ArrayType 或 ScalarType + // 例如:alloca i32, align 4 或者 alloca [10 x i32], align 4 + auto allocatedType = dynamic_cast(allocaInst->getType())->getBaseType(); + printType(allocatedType); - if (allocaInst->getNumDims() > 0) { + // 仍然打印维度信息,如果存在的话 + if (allocaInst->getNumDims() > 0) { std::cout << ", "; for (size_t i = 0; i < allocaInst->getNumDims(); i++) { if (i > 0) std::cout << ", "; - printType(Type::getIntType()); + printType(Type::getIntType()); // 维度大小通常是 i32 类型 std::cout << " "; printValue(allocaInst->getDim(i)); } @@ -356,70 +369,74 @@ void SysYPrinter::printInst(Instruction *pInst) { case Kind::kLoad: { auto loadInst = 
dynamic_cast(pInst); std::cout << "%" << loadInst->getName() << " = load "; - printType(loadInst->getType()); + printType(loadInst->getType()); // 加载的结果类型 std::cout << ", "; - printType(loadInst->getPointer()->getType()); + printType(loadInst->getPointer()->getType()); // 指针类型 std::cout << " "; - printValue(loadInst->getPointer()); + printValue(loadInst->getPointer()); // 要加载的地址 + // 仍然打印索引信息,如果存在的话 if (loadInst->getNumIndices() > 0) { - std::cout << ", "; + std::cout << ", indices "; // 或者其他分隔符,取决于你期望的格式 for (size_t i = 0; i < loadInst->getNumIndices(); i++) { - if (i > 0) std::cout << ", "; - printType(Type::getIntType()); - std::cout << " "; - printValue(loadInst->getIndex(i)); + if (i > 0) std::cout << ", "; + printType(loadInst->getIndex(i)->getType()); + std::cout << " "; + printValue(loadInst->getIndex(i)); } } std::cout << ", align 4" << std::endl; } break; - case Kind::kLa: { - auto laInst = dynamic_cast(pInst); - std::cout << "%" << laInst->getName() << " = getelementptr inbounds "; - - auto ptrType = dynamic_cast(laInst->getPointer()->getType()); - printType(ptrType->getBaseType()); - std::cout << ", "; - printType(laInst->getPointer()->getType()); - std::cout << " "; - printValue(laInst->getPointer()); - std::cout << ", "; - - for (size_t i = 0; i < laInst->getNumIndices(); i++) { - if (i > 0) std::cout << ", "; - printType(Type::getIntType()); - std::cout << " "; - printValue(laInst->getIndex(i)); - } - - std::cout << std::endl; - } break; - case Kind::kStore: { auto storeInst = dynamic_cast(pInst); std::cout << "store "; - printType(storeInst->getValue()->getType()); + printType(storeInst->getValue()->getType()); // 要存储的值的类型 std::cout << " "; - printValue(storeInst->getValue()); + printValue(storeInst->getValue()); // 要存储的值 std::cout << ", "; - printType(storeInst->getPointer()->getType()); + printType(storeInst->getPointer()->getType()); // 目标指针的类型 std::cout << " "; - printValue(storeInst->getPointer()); + printValue(storeInst->getPointer()); // 目标地址 
+ // 仍然打印索引信息,如果存在的话 if (storeInst->getNumIndices() > 0) { - std::cout << ", "; + std::cout << ", indices "; // 或者其他分隔符 for (size_t i = 0; i < storeInst->getNumIndices(); i++) { - if (i > 0) std::cout << ", "; - printType(Type::getIntType()); - std::cout << " "; - printValue(storeInst->getIndex(i)); + if (i > 0) std::cout << ", "; + printType(storeInst->getIndex(i)->getType()); + std::cout << " "; + printValue(storeInst->getIndex(i)); } } std::cout << ", align 4" << std::endl; } break; + + case Kind::kGetElementPtr: { // 新增:GetElementPtrInst 打印 + auto gepInst = dynamic_cast(pInst); + std::cout << "%" << gepInst->getName() << " = getelementptr inbounds "; // 假设总是 inbounds + + // GEP 的第一个操作数是基指针,其类型是一个指向聚合类型的指针 + // 第一个参数是基指针所指向的聚合类型的类型 (e.g., [10 x i32]) + auto basePtrType = dynamic_cast(gepInst->getBasePointer()->getType()); + printType(basePtrType->getBaseType()); // 打印基指针指向的类型 + + std::cout << ", "; + printType(gepInst->getBasePointer()->getType()); // 打印基指针自身的类型 (e.g., [10 x i32]*) + std::cout << " "; + printValue(gepInst->getBasePointer()); // 打印基指针 + + // 打印所有索引 + for (auto indexVal : gepInst->getIndices()) { // 使用 getIndices() 迭代器 + std::cout << ", "; + printType(indexVal->getValue()->getType()); // 打印索引的类型 (通常是 i32) + std::cout << " "; + printValue(indexVal->getValue()); // 打印索引值 + } + std::cout << std::endl; + } break; case Kind::kMemset: { auto memsetInst = dynamic_cast(pInst); @@ -433,51 +450,40 @@ void SysYPrinter::printInst(Instruction *pInst) { printValue(memsetInst->getValue()); std::cout << ", i32 "; printValue(memsetInst->getSize()); - std::cout << ", i1 false)" << std::endl; + std::cout << ", i1 false)" << std::endl; // alignment for memset is typically i1 } break; case Kind::kPhi: { auto phiInst = dynamic_cast(pInst); - printValue(phiInst->getOperand(0)); - std::cout << " = phi "; - printType(phiInst->getType()); + // Phi 指令的名称通常是结果变量 + std::cout << "%" << phiInst->getName() << " = phi "; + printType(phiInst->getType()); // Phi 结果类型 - for 
(unsigned i = 1; i < phiInst->getNumOperands(); i++) { - if (i > 0) std::cout << ", "; + // Phi 指令的操作数是成对的 [value, basic_block] + // 这里假设 getOperands() 返回的是 (val1, block1, val2, block2...) + // 如果你的 PhiInst 存储方式是 getIncomingValues() 和 getIncomingBlocks(),请相应调整 + // LLVM IR 格式: phi type [value1, block1], [value2, block2] + bool firstPair = true; + for (unsigned i = 0; i < phiInst->getNumOperands() / 2; ++i) { // 遍历成对的操作数 + if (!firstPair) std::cout << ", "; + firstPair = false; std::cout << "[ "; - printValue(phiInst->getOperand(i)); + printValue(phiInst->getOperand(i * 2)); // value + std::cout << ", %"; + printValue(phiInst->getOperand(i * 2 + 1)); // block std::cout << " ]"; } std::cout << std::endl; } break; - case Kind::kGetSubArray: { - auto getSubArrayInst = dynamic_cast(pInst); - std::cout << "%" << getSubArrayInst->getName() << " = getelementptr inbounds "; - - auto ptrType = dynamic_cast(getSubArrayInst->getFatherArray()->getType()); - printType(ptrType->getBaseType()); - std::cout << ", "; - printType(getSubArrayInst->getFatherArray()->getType()); - std::cout << " "; - printValue(getSubArrayInst->getFatherArray()); - std::cout << ", "; - bool firstIndex = true; - for (auto &index : getSubArrayInst->getIndices()) { - if (!firstIndex) std::cout << ", "; - firstIndex = false; - printType(Type::getIntType()); - std::cout << " "; - printValue(index->getValue()); - } - - std::cout << std::endl; - } break; + // 以下两个 Kind 应该删除或替换为 kGEP + // case Kind::kLa: { /* REMOVED */ } break; + // case Kind::kGetSubArray: { /* REMOVED */ } break; default: - assert(false && "Unsupported instruction kind"); + assert(false && "Unsupported instruction kind in SysYPrinter"); break; } } -} // namespace sysy +} // namespace sysy \ No newline at end of file diff --git a/src/include/DCE.h b/src/include/DCE.h new file mode 100644 index 0000000..41bc223 --- /dev/null +++ b/src/include/DCE.h @@ -0,0 +1,63 @@ +#pragma once + +#include "Pass.h" +#include "IR.h" +#include 
"SysYIROptUtils.h" +#include "Dom.h" +#include +#include + +namespace sysy { + +// 前向声明分析结果类,确保在需要时可以引用 +// class DominatorTreeAnalysisResult; // Pass.h 中已包含,这里不再需要 +class SideEffectInfoAnalysisResult; // 假设有副作用分析结果类 + +// DCEContext 类,用于封装DCE的内部逻辑和状态 +// 这样可以避免静态变量在多线程或多次运行时的冲突,并保持代码的模块化 +class DCEContext { +public: + // 运行DCE的主要方法 + // func: 当前要优化的函数 + // tp: 分析管理器,用于获取其他分析结果(如果需要) + void run(Function* func, AnalysisManager* AM, bool &changed); + +private: + // 存储活跃指令的集合 + std::unordered_set alive_insts; + + // 判断指令是否是“天然活跃”的(即总是保留的) + // inst: 要检查的指令 + // 返回值: 如果指令是天然活跃的,则为true,否则为false + bool isAlive(Instruction* inst); + + // 递归地将活跃指令及其依赖加入到 alive_insts 集合中 + // inst: 要标记为活跃的指令 + void addAlive(Instruction* inst); +}; + +// DCE 优化遍类,继承自 OptimizationPass +class DCE : public OptimizationPass { +public: + // 构造函数 + DCE() : OptimizationPass("DCE", Granularity::Function) {} + + // 静态成员,作为该遍的唯一ID + static void *ID; + + // 运行在函数上的优化逻辑 + // F: 当前要优化的函数 + // AM: 分析管理器,用于获取或使分析结果失效 + // 返回值: 如果IR被修改,则为true,否则为false + bool runOnFunction(Function *F, AnalysisManager& AM) override; + + // 声明该遍的分析依赖和失效信息 + // analysisDependencies: 该遍运行前需要哪些分析结果 + // analysisInvalidations: 该遍运行后会使哪些分析结果失效 + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override; + + // Pass 基类中的纯虚函数,必须实现 + void *getPassID() const override { return &ID; } +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/include/DeadCodeElimination.h b/src/include/DeadCodeElimination.h deleted file mode 100644 index 72b9935..0000000 --- a/src/include/DeadCodeElimination.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -#include "IR.h" -#include "SysYIRAnalyser.h" -#include "SysYIRPrinter.h" - -namespace sysy { - -class DeadCodeElimination { - private: - Module *pModule; - ControlFlowAnalysis *pCFA; // 控制流分析指针 - ActiveVarAnalysis *pAVA; // 活跃变量分析指针 - DataFlowAnalysisUtils dataFlowAnalysisUtils; // 数据流分析工具类 - - public: - explicit DeadCodeElimination(Module 
*pMoudle, - ControlFlowAnalysis *pCFA = nullptr, - ActiveVarAnalysis *pAVA = nullptr) - : pModule(pMoudle), pCFA(pCFA), pAVA(pAVA), dataFlowAnalysisUtils() {} // 构造函数 - - // TODO:根据参数传入的passes来运行不同的死代码删除流程 - // void runDCEPipeline(const std::vector& passes = { - // "dead-store", "redundant-load-store", "dead-load", "dead-alloca", "dead-global" - // }); - void runDCEPipeline(); // 运行死代码删除 - - void eliminateDeadStores(Function* func, bool& changed); // 消除无用存储 - void eliminateDeadLoads(Function* func, bool& changed); // 消除无用加载 - void eliminateDeadAllocas(Function* func, bool& changed); // 消除无用内存分配 - void eliminateDeadGlobals(bool& changed); // 消除无用全局变量 - void eliminateDeadIndirectiveAllocas(Function* func, bool& changed); // 消除无用间接内存分配(phi节点) - void eliminateDeadRedundantLoadStore(Function* func, bool& changed); // 消除冗余加载和存储 - bool isGlobal(Value *val); - bool isArr(Value *val); - void usedelete(Instruction *instr); - -}; -} // namespace sysy diff --git a/src/include/Dom.h b/src/include/Dom.h new file mode 100644 index 0000000..f69dcb9 --- /dev/null +++ b/src/include/Dom.h @@ -0,0 +1,52 @@ +#pragma once + +#include "Pass.h" // 包含 Pass 框架 +#include "IR.h" // 包含 IR 定义 +#include +#include +#include +#include + +namespace sysy { + +// 支配树分析结果类 (保持不变) +class DominatorTree : public AnalysisResultBase { +public: + DominatorTree(Function* F); + const std::set* getDominators(BasicBlock* BB) const; + BasicBlock* getImmediateDominator(BasicBlock* BB) const; + const std::set* getDominanceFrontier(BasicBlock* BB) const; + const std::map>& getDominatorsMap() const { return Dominators; } + const std::map& getIDomsMap() const { return IDoms; } + const std::map>& getDominanceFrontiersMap() const { return DominanceFrontiers; } + void computeDominators(Function* F); + void computeIDoms(Function* F); + void computeDominanceFrontiers(Function* F); +private: + Function* AssociatedFunction; + std::map> Dominators; + std::map IDoms; + std::map> DominanceFrontiers; +}; + + +// 支配树分析遍 +class 
DominatorTreeAnalysisPass : public AnalysisPass { +public: + // 唯一的 Pass ID + static void *ID; + + DominatorTreeAnalysisPass() : AnalysisPass("DominatorTreeAnalysis", Pass::Granularity::Function) {} + + // 实现 getPassID + void* getPassID() const override { return &ID; } + + bool runOnFunction(Function* F, AnalysisManager &AM) override; + + std::unique_ptr getResult() override; + +private: + std::unique_ptr CurrentDominatorTree; +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/include/IR.h b/src/include/IR.h index 060bdc5..8f0103d 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -49,6 +49,7 @@ class Type { kLabel, kPointer, kFunction, + kArray, }; Kind kind; ///< 表示具体类型的变量 @@ -65,6 +66,7 @@ class Type { static Type* getPointerType(Type *baseType); ///< 返回表示指向baseType类型的Pointer类型的Type指针 static Type* getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); ///< 返回表示返回类型为returnType,形参类型列表为paramTypes的函数类型的Type指针 + static Type* getArrayType(Type *elementType, unsigned numElements); public: Kind getKind() const { return kind; } ///< 返回Type对象代表原始标量类型 @@ -74,6 +76,7 @@ class Type { bool isLabel() const { return kind == kLabel; } ///< 判定是否为Label类型 bool isPointer() const { return kind == kPointer; } ///< 判定是否为Pointer类型 bool isFunction() const { return kind == kFunction; } ///< 判定是否为Function类型 + bool isArray() const { return kind == Kind::kArray; } unsigned getSize() const; ///< 返回类型所占的空间大小(字节) /// 尝试将一个变量转换为给定的Type及其派生类类型的变量 template @@ -115,6 +118,22 @@ class FunctionType : public Type { unsigned getNumParams() const { return paramTypes.size(); } ///< 获取形参数量 }; +class ArrayType : public Type { + public: + // elements:数组的元素类型 (例如,int[3] 的 elementType 是 int) + // numElements:该维度的大小 (例如,int[3] 的 numElements 是 3) + static ArrayType *get(Type *elementType, unsigned numElements); + + Type *getElementType() const { return elementType; } + unsigned getNumElements() const { return numElements; } + + protected: + ArrayType(Type *elementType, 
unsigned numElements) + : Type(Kind::kArray), elementType(elementType), numElements(numElements) {} + Type *elementType; + unsigned numElements; // 当前维度的大小 +}; + /*! * @} */ @@ -502,13 +521,22 @@ public: Function* getParent() const { return parent; } void setParent(Function *func) { parent = func; } inst_list& getInstructions() { return instructions; } + auto getInstructions_Range() const { return make_range(instructions); } arg_list& getArguments() { return arguments; } - const block_list& getPredecessors() const { return predecessors; } + block_list& getPredecessors() { return predecessors; } + void clearPredecessors() { predecessors.clear(); } block_list& getSuccessors() { return successors; } + void clearSuccessors() { successors.clear(); } iterator begin() { return instructions.begin(); } iterator end() { return instructions.end(); } iterator terminator() { return std::prev(end()); } void insertArgument(AllocaInst *inst) { arguments.push_back(inst); } + bool hasSuccessor(BasicBlock *block) const { + return std::find(successors.begin(), successors.end(), block) != successors.end(); + } ///< 判断是否有后继块 + bool hasPredecessor(BasicBlock *block) const { + return std::find(predecessors.begin(), predecessors.end(), block) != predecessors.end(); + } ///< 判断是否有前驱块 void addPredecessor(BasicBlock *block) { if (std::find(predecessors.begin(), predecessors.end(), block) == predecessors.end()) { predecessors.push_back(block); @@ -561,6 +589,15 @@ public: next->addPredecessor(prev); } void removeInst(iterator pos) { instructions.erase(pos); } + void removeInst(Instruction *inst) { + auto pos = std::find_if(instructions.begin(), instructions.end(), + [inst](const std::unique_ptr &i) { return i.get() == inst; }); + if (pos != instructions.end()) { + instructions.erase(pos); + } else { + assert(false && "Instruction not found in BasicBlock"); + } + } ///< 移除指定位置的指令 iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block); }; @@ -602,49 +639,6 @@ class User : 
public Value { void setOperand(unsigned index, Value *value); ///< 设置操作数 }; -class GetSubArrayInst; -/** - * 左值 具有地址的对象 - */ -class LVal { - friend class GetSubArrayInst; - - protected: - LVal *fatherLVal{}; ///< 父左值 - std::list> childrenLVals; ///< 子左值 - GetSubArrayInst *defineInst{}; /// 定义该左值的GetSubArray指令 - - protected: - LVal() = default; - - public: - virtual ~LVal() = default; - virtual std::vector getLValDims() const = 0; ///< 获取左值的维度 - virtual unsigned getLValNumDims() const = 0; ///< 获取左值的维度数量 - - public: - LVal* getFatherLVal() const { return fatherLVal; } ///< 获取父左值 - const std::list>& getChildrenLVals() const { - return childrenLVals; - } ///< 获取子左值列表 - LVal* getAncestorLVal() const { - auto curLVal = const_cast(this); - while (curLVal->getFatherLVal() != nullptr) { - curLVal = curLVal->getFatherLVal(); - } - return curLVal; - } ///< 获取祖先左值 - void setFatherLVal(LVal *father) { fatherLVal = father; } ///< 设置父左值 - void setDefineInst(GetSubArrayInst *inst) { defineInst = inst; } ///< 设置定义指令 - void addChild(LVal *child) { childrenLVals.emplace_back(child); } ///< 添加子左值 - void removeChild(LVal *child) { - auto iter = std::find_if(childrenLVals.begin(), childrenLVals.end(), - [child](const std::unique_ptr &ptr) { return ptr.get() == child; }); - childrenLVals.erase(iter); - } ///< 移除子左值 - GetSubArrayInst* getDefineInst() const { return defineInst; } ///< 获取定义指令 -}; - /*! * Base of all concrete instruction types. */ @@ -694,15 +688,15 @@ class Instruction : public User { kAlloca = 0x1UL << 33, kLoad = 0x1UL << 34, kStore = 0x1UL << 35, - kLa = 0x1UL << 36, + kGetElementPtr = 0x1UL << 36, kMemset = 0x1UL << 37, - kGetSubArray = 0x1UL << 38, + // kGetSubArray = 0x1UL << 38, // Constant Kind removed as Constants are now Values, not Instructions. 
// kConstant = 0x1UL << 37, // Conflicts with kMemset if kept as is // phi kPhi = 0x1UL << 39, kBitItoF = 0x1UL << 40, - kBitFtoI = 0x1UL << 41 + kBitFtoI = 0x1UL << 41, }; protected: @@ -793,14 +787,12 @@ public: return "Load"; case kStore: return "Store"; - case kLa: - return "La"; + case kGetElementPtr: + return "GetElementPtr"; case kMemset: return "Memset"; case kPhi: return "Phi"; - case kGetSubArray: - return "GetSubArray"; default: return "Unknown"; } @@ -853,9 +845,8 @@ public: bool isAlloca() const { return kind == kAlloca; } bool isLoad() const { return kind == kLoad; } bool isStore() const { return kind == kStore; } - bool isLa() const { return kind == kLa; } + bool isGetElementPtr() const { return kind == kGetElementPtr; } bool isMemset() const { return kind == kMemset; } - bool isGetSubArray() const { return kind == kGetSubArray; } bool isCall() const { return kind == kCall; } bool isReturn() const { return kind == kReturn; } bool isDefine() const { @@ -867,26 +858,6 @@ public: class Function; //! Function call. -class LaInst : public Instruction { - friend class Function; - friend class IRBuilder; - - protected: - explicit LaInst(Value *pointer, const std::vector &indices = {}, BasicBlock *parent = nullptr, - const std::string &name = "") - : Instruction(Kind::kLa, pointer->getType(), parent, name) { - assert(pointer); - addOperand(pointer); - addOperands(indices); - } - - public: - unsigned getNumIndices() const { return getNumOperands() - 1; } ///< 获取索引长度 - Value* getPointer() const { return getOperand(0); } ///< 获取目标变量的Value指针 - auto getIndices() const { return make_range(std::next(operand_begin()), operand_end()); } ///< 获取索引列表 - Value* getIndex(unsigned index) const { return getOperand(index + 1); } ///< 获取位置为index的索引分量 -}; - class PhiInst : public Instruction { friend class IRBuilder; friend class Function; @@ -1134,7 +1105,7 @@ public: }; // class CondBrInst //! 
Allocate memory for stack variables, used for non-global variable declartion -class AllocaInst : public Instruction , public LVal { +class AllocaInst : public Instruction { friend class IRBuilder; friend class Function; protected: @@ -1145,14 +1116,6 @@ protected: } public: - std::vector getLValDims() const override { - std::vector dims; - for (const auto &dim : getOperands()) { - dims.emplace_back(dim->getValue()); - } - return dims; - } ///< 获取作为左值的维度数组 - unsigned getLValNumDims() const override { return getNumOperands(); } int getNumDims() const { return getNumOperands(); } auto getDims() const { return getOperands(); } @@ -1161,37 +1124,40 @@ public: }; // class AllocaInst -class GetSubArrayInst : public Instruction { +class GetElementPtrInst : public Instruction { friend class IRBuilder; - friend class Function; - public: - GetSubArrayInst(LVal *fatherArray, LVal *childArray, const std::vector &indices, - BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(Kind::kGetSubArray, Type::getVoidType(), parent, name) { - auto predicate = [childArray](const std::unique_ptr &child) -> bool { return child.get() == childArray; }; - if (std::find_if(fatherArray->childrenLVals.begin(), fatherArray->childrenLVals.end(), predicate) == - fatherArray->childrenLVals.end()) { - fatherArray->childrenLVals.emplace_back(childArray); - } - childArray->fatherLVal = fatherArray; - childArray->defineInst = this; - auto fatherArrayValue = dynamic_cast(fatherArray); - auto childArrayValue = dynamic_cast(childArray); - assert(fatherArrayValue); - assert(childArrayValue); - addOperand(fatherArrayValue); - addOperand(childArrayValue); - addOperands(indices); +protected: + // GEP的构造函数: + // resultType: GEP计算出的地址的类型 (通常是指向目标元素类型的指针) + // basePointer: 基指针 (第一个操作数) + // indices: 索引列表 (后续操作数) + GetElementPtrInst(Type *resultType, + Value *basePointer, + const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") + : 
Instruction(Kind::kGetElementPtr, resultType, parent, name) { + assert(basePointer && "GEP base pointer cannot be null!"); + // TODO : 安全检查 + assert(basePointer->getType()->isPointer() ); + addOperand(basePointer); // 第一个操作数是基指针 + addOperands(indices); // 随后的操作数是索引 + } +public: + Value* getBasePointer() const { return getOperand(0); } + unsigned getNumIndices() const { return getNumOperands() - 1; } + auto getIndices() const { return make_range(std::next(operand_begin()), operand_end());} + Value* getIndex(unsigned index) const { + assert(index < getNumIndices() && "Index out of bounds for GEP!"); + return getOperand(index + 1); } - public: - Value* getFatherArray() const { return getOperand(0); } ///< 获取父数组 - Value* getChildArray() const { return getOperand(1); } ///< 获取子数组 - LVal* getFatherLVal() const { return dynamic_cast(getOperand(0)); } ///< 获取父左值 - LVal* getChildLVal() const { return dynamic_cast(getOperand(1)); } ///< 获取子左值 - auto getIndices() const { return make_range(std::next(operand_begin(), 2), operand_end()); } ///< 获取索引 - unsigned getNumIndices() const { return getNumOperands() - 2; } ///< 获取索引数量 + // 静态工厂方法,用于创建GEP指令 (如果需要外部直接创建而非通过IRBuilder) + static GetElementPtrInst* create(Type *resultType, Value *basePointer, + const std::vector &indices = {}, + BasicBlock *parent = nullptr, const std::string &name = "") { + return new GetElementPtrInst(resultType, basePointer, indices, parent, name); + } }; //! 
Load a value from memory address specified by a pointer value @@ -1215,22 +1181,7 @@ public: return make_range(std::next(operand_begin()), operand_end()); } Value* getIndex(int index) const { return getOperand(index + 1); } - std::list getAncestorIndices() const { - std::list indices; - for (const auto &index : getIndices()) { - indices.emplace_back(index->getValue()); - } - auto curPointer = dynamic_cast(getPointer()); - while (curPointer->getFatherLVal() != nullptr) { - auto inserter = std::next(indices.begin()); - for (const auto &index : curPointer->getDefineInst()->getIndices()) { - indices.insert(inserter, index->getValue()); - } - curPointer = curPointer->getFatherLVal(); - } - - return indices; - } ///< 获取相对于祖先数组的索引列表 + }; // class LoadInst //! Store a value to memory address specified by a pointer value @@ -1256,22 +1207,6 @@ public: return make_range(std::next(operand_begin(), 2), operand_end()); } Value* getIndex(int index) const { return getOperand(index + 2); } - std::list getAncestorIndices() const { - std::list indices; - for (const auto &index : getIndices()) { - indices.emplace_back(index->getValue()); - } - auto curPointer = dynamic_cast(getPointer()); - while (curPointer->getFatherLVal() != nullptr) { - auto inserter = std::next(indices.begin()); - for (const auto &index : curPointer->getDefineInst()->getIndices()) { - indices.insert(inserter, index->getValue()); - } - curPointer = curPointer->getFatherLVal(); - } - - return indices; - } ///< 获取相对于祖先数组的索引列表 }; // class StoreInst @@ -1314,7 +1249,7 @@ class Function : public Value { friend class Module; protected: Function(Module *parent, Type *type, const std::string &name) : Value(type, name), parent(parent) { - blocks.emplace_back(new BasicBlock(this)); + blocks.emplace_back(new BasicBlock(this, "entry_" + name)); ///< 创建一个入口基本块 } public: @@ -1373,7 +1308,7 @@ protected: }; //! 
Global value declared at file scope -class GlobalValue : public User, public LVal { +class GlobalValue : public User { friend class Module; protected: @@ -1407,16 +1342,6 @@ protected: } public: - unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 - std::vector getLValDims() const override { - std::vector dims; - for (const auto &dim : getOperands()) { - dims.emplace_back(dim->getValue()); - } - - return dims; - } ///< 获取作为左值的维度列表 - unsigned getNumDims() const { return numDims; } ///< 获取维度数量 Value* getDim(unsigned index) const { return getOperand(index); } ///< 获取位置为index的维度 auto getDims() const { return getOperands(); } ///< 获取维度列表 @@ -1438,7 +1363,7 @@ public: }; // class GlobalValue -class ConstantVariable : public User, public LVal { +class ConstantVariable : public User { friend class Module; protected: @@ -1457,15 +1382,6 @@ class ConstantVariable : public User, public LVal { } public: - unsigned getLValNumDims() const override { return numDims; } ///< 获取作为左值的维度数量 - std::vector getLValDims() const override { - std::vector dims; - for (const auto &dim : getOperands()) { - dims.emplace_back(dim->getValue()); - } - - return dims; - } ///< 获取作为左值的维度列表 Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值 Value* getByIndices(const std::vector &indices) const { int index = 0; @@ -1489,7 +1405,7 @@ class ConstantVariable : public User, public LVal { using SymbolTableNode = struct SymbolTableNode { SymbolTableNode *pNode; ///< 父节点 std::vector children; ///< 子节点列表 - std::map varList; ///< 变量列表 + std::map varList; ///< 变量列表 }; @@ -1504,8 +1420,8 @@ class SymbolTable { public: SymbolTable() = default; - User* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 - User* addVariable(const std::string &name, User *variable); ///< 添加变量 + Value* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 + Value* addVariable(const std::string &name, Value *variable); ///< 添加变量 
std::vector>& getGlobals(); ///< 获取全局变量列表 const std::vector>& getConsts() const; ///< 获取常量列表 void enterNewScope(); ///< 进入新的作用域 @@ -1567,7 +1483,7 @@ class Module { void addVariable(const std::string &name, AllocaInst *variable) { variableTable.addVariable(name, variable); } ///< 添加变量 - User* getVariable(const std::string &name) { + Value* getVariable(const std::string &name) { return variableTable.getVariable(name); } ///< 根据名字name和当前作用域获取变量 Function* getFunction(const std::string &name) const { diff --git a/src/include/IRBuilder.h b/src/include/IRBuilder.h index 6df82e7..d9e92ef 100644 --- a/src/include/IRBuilder.h +++ b/src/include/IRBuilder.h @@ -280,46 +280,6 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建load指令 - LaInst * createLaInst(Value *pointer, const std::vector &indices = {}, const std::string &name = "") { - std::string newName; - if (name.empty()) { - std::stringstream ss; - ss << tmpIndex; - newName = ss.str(); - tmpIndex++; - } else { - newName = name; - } - - auto inst = new LaInst(pointer, indices, block, newName); - assert(inst); - block->getInstructions().emplace(position, inst); - return inst; - } ///< 创建la指令 - GetSubArrayInst * createGetSubArray(LVal *fatherArray, const std::vector &indices, const std::string &name = "") { - assert(fatherArray->getLValNumDims() > indices.size()); - std::vector subDims; - auto dims = fatherArray->getLValDims(); - auto iter = std::next(dims.begin(), indices.size()); - while (iter != dims.end()) { - subDims.emplace_back(*iter); - iter++; - } - - std::string childArrayName; - std::stringstream ss; - ss << "A" - << "%" << tmpIndex; - childArrayName = ss.str(); - tmpIndex++; - - auto fatherArrayValue = dynamic_cast(fatherArray); - auto childArray = new AllocaInst(fatherArrayValue->getType(), subDims, block, childArrayName); - auto inst = new GetSubArrayInst(fatherArray, childArray, indices, block, childArrayName); - assert(inst); - 
block->getInstructions().emplace(position, inst); - return inst; - } ///< 创建获取部分数组指令 MemsetInst * createMemsetInst(Value *pointer, Value *begin, Value *size, Value *value, const std::string &name = "") { auto inst = new MemsetInst(pointer, begin, size, value, block, name); assert(inst); @@ -334,12 +294,102 @@ class IRBuilder { return inst; } ///< 创建store指令 PhiInst * createPhiInst(Type *type, const std::vector &vals = {}, const std::vector &blks = {}, const std::string &name = "") { - auto predNum = block->getNumPredecessors(); auto inst = new PhiInst(type, vals, blks, block, name); assert(inst); block->getInstructions().emplace(block->begin(), inst); return inst; } ///< 创建Phi指令 + // GetElementPtrInst* createGetElementPtrInst(Value *basePointer, + // const std::vector &indices = {}, + // const std::string &name = "") { + // std::string newName; + // if (name.empty()) { + // std::stringstream ss; + // ss << tmpIndex; + // newName = ss.str(); + // tmpIndex++; + // } else { + // newName = name; + // } + + // auto inst = new GetElementPtrInst(basePointer, indices, block, newName); + // assert(inst); + // block->getInstructions().emplace(position, inst); + // return inst; + // } + /** + * @brief 根据 LLVM 设计模式创建 GEP 指令。 + * 它会自动推断返回类型,无需手动指定。 + */ + GetElementPtrInst *createGetElementPtrInst(Value *basePointer, const std::vector &indices, + const std::string &name = "") { + Type *ResultElementType = getIndexedType(basePointer->getType(), indices); + if (!ResultElementType) { + assert(false && "Invalid GEP indexing!"); + return nullptr; + } + Type *ResultType = PointerType::get(ResultElementType); + std::string newName; + if (name.empty()) { + std::stringstream ss; + ss << tmpIndex; + newName = ss.str(); + tmpIndex++; + } else { + newName = name; + } + + auto inst = new GetElementPtrInst(ResultType, basePointer, indices, block, newName); + assert(inst); + block->getInstructions().emplace(position, inst); + return inst; + } + + static Type *getIndexedType(Type *pointerType, 
const std::vector &indices) { + assert(pointerType->isPointer() && "base must be a pointer type!"); + // GEP 的类型推断从基指针所指向的类型开始。 + // 例如: + // - 如果 pointerType 是 `[20 x [10 x i32]]*`,`currentWalkType` 初始为 `[20 x [10 x i32]]`。 + // - 如果 pointerType 是 `i32*`,`currentWalkType` 初始为 `i32`。 + // - 如果 pointerType 是 `i32**`,`currentWalkType` 初始为 `i32*`。 + Type *currentWalkType = pointerType->as()->getBaseType(); + + // 遍历所有索引来深入类型层次结构。 + // `indices` 向量包含了所有 GEP 索引,包括由 `visitLValue` 等函数添加的初始 `0` 索引。 + for (int i = 0; i < indices.size(); ++i) { + if (currentWalkType->isArray()) { + // 情况一:当前遍历类型是 `ArrayType`。 + // 索引用于选择数组元素,`currentWalkType` 更新为数组的元素类型。 + currentWalkType = currentWalkType->as()->getElementType(); + } else if (currentWalkType->isPointer()) { + // 情况二:当前遍历类型是 `PointerType`。 + // 这意味着我们正在通过一个指针来访问其指向的内存。 + // 索引用于选择该指针所指向的“数组”的元素。 + // `currentWalkType` 更新为该指针所指向的基础类型。 + // 例如:如果 `currentWalkType` 是 `i32*`,它将变为 `i32`。 + // 如果 `currentWalkType` 是 `[10 x i32]*`,它将变为 `[10 x i32]`。 + currentWalkType = currentWalkType->as()->getBaseType(); + } else { + // 情况三:当前遍历类型是标量类型 (例如 `i32`, `float` 等非聚合、非指针类型)。 + // + // 如果 `currentWalkType` 是标量,并且当前索引 `i` **不是** `indices` 向量中的最后一个索引, + // 这意味着尝试对一个标量类型进行进一步的结构性索引,这是**无效的**。 + // 例如:`int x; x[0];` 对应的 GEP 链中,`x` 的类型是 `i32`,再加 `[0]` 索引就是错误。 + // + // 如果 `currentWalkType` 是标量,且这是**最后一个索引** (`i == indices.size() - 1`), + // 那么 GEP 是合法的,它只是计算一个偏移地址,最终的类型就是这个标量类型。 + // 此时 `currentWalkType` 保持不变,循环结束。 + if (i < indices.size() - 1) { + assert(false && "Invalid GEP indexing: attempting to index into a non-aggregate/non-pointer type with further indices."); + return nullptr; // 返回空指针表示类型推断失败 + } + // 如果是最后一个索引,且当前类型是标量,则类型保持不变,这是合法的。 + // 循环会自然结束,返回正确的 `currentWalkType`。 + } + } + // 所有索引处理完毕后,`currentWalkType` 就是 GEP 指令最终计算出的地址所指向的元素的类型。 + return currentWalkType; + } }; } // namespace sysy diff --git a/src/include/Liveness.h b/src/include/Liveness.h new file mode 100644 index 0000000..f101b7c --- /dev/null +++ 
b/src/include/Liveness.h @@ -0,0 +1,72 @@ +#pragma once + +#include "IR.h" // 包含 IR 定义 +#include "Pass.h" // 包含 Pass 框架 +#include // for std::set_union, std::set_difference +#include +#include +#include + +namespace sysy { + +// 前向声明 +class Function; +class BasicBlock; +class Value; +class Instruction; + +// 活跃变量分析结果类 +// 它将包含 LiveIn 和 LiveOut 集合 +class LivenessAnalysisResult : public AnalysisResultBase { +public: + LivenessAnalysisResult(Function *F) : AssociatedFunction(F) {} + + // 获取给定基本块的 LiveIn 集合 + const std::set *getLiveIn(BasicBlock *BB) const; + + // 获取给定基本块的 LiveOut 集合 + const std::set *getLiveOut(BasicBlock *BB) const; + + // 暴露内部数据结构,如果需要更直接的访问 + const std::map> &getLiveInSets() const { return liveInSets; } + const std::map> &getLiveOutSets() const { return liveOutSets; } + + // 核心计算方法,由 LivenessAnalysisPass 调用 + void computeLiveness(Function *F); + +private: + Function *AssociatedFunction; // 这个活跃变量分析是为哪个函数计算的 + std::map> liveInSets; + std::map> liveOutSets; + + // 辅助函数:计算基本块的 Def 和 Use 集合 + // Def: 块内定义,且定义在所有使用之前的值 + // Use: 块内使用,且使用在所有定义之前的值 + void computeDefUse(BasicBlock *BB, std::set &def, std::set &use); +}; + +// 活跃变量分析遍 +class LivenessAnalysisPass : public AnalysisPass { +public: + // 唯一的 Pass ID + static void *ID; // LLVM 风格的唯一 ID + + LivenessAnalysisPass() : AnalysisPass("LivenessAnalysis", Pass::Granularity::Function) {} + + // 实现 getPassID + void *getPassID() const override { return &ID; } + + // 运行分析并返回结果。现在接受 AnalysisManager& AM 参数 + bool runOnFunction(Function *F, AnalysisManager &AM) override; + + // 获取分析结果的指针。 + // 注意:AnalysisManager 将会调用此方法来获取结果并进行缓存。 + std::unique_ptr getResult() override; + +private: + // 存储当前分析计算出的 LivenessAnalysisResult 实例 + // runOnFunction 每次调用都会创建新的 LivenessAnalysisResult 对象 + std::unique_ptr CurrentLivenessResult; +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/include/Mem2Reg.h b/src/include/Mem2Reg.h deleted file mode 100644 index 0004708..0000000 --- a/src/include/Mem2Reg.h +++ 
/dev/null @@ -1,59 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include "IR.h" -#include "IRBuilder.h" -#include "SysYIRAnalyser.h" - -namespace sysy { -/** - * 实现静态单变量赋值核心类 mem2reg - */ -class Mem2Reg { -private: - Module *pModule; - IRBuilder *pBuilder; - ControlFlowAnalysis *controlFlowAnalysis; // 控制流分析 - ActiveVarAnalysis *activeVarAnalysis; // 活跃变量分析 - DataFlowAnalysisUtils dataFlowAnalysisUtils; - -public: - Mem2Reg(Module *pMoudle, IRBuilder *pBuilder, - ControlFlowAnalysis *pCFA = nullptr, ActiveVarAnalysis *pAVA = nullptr) : - pModule(pMoudle), pBuilder(pBuilder), controlFlowAnalysis(pCFA), activeVarAnalysis(pAVA), dataFlowAnalysisUtils() - {} // 初始化函数 - - void mem2regPipeline(); ///< mem2reg - -private: - - // phi节点的插入需要计算IDF - std::unordered_set computeIterDf(const std::unordered_set &blocks); ///< 计算定义块集合的迭代支配边界 - - auto computeValue2Blocks() -> void; ///< 计算value2block的映射(不包括数组和global) - - auto preOptimize1() -> void; ///< llvm memtoreg预优化1: 删除不含load的alloc和store - auto preOptimize2() -> void; ///< llvm memtoreg预优化2: 针对某个变量的Defblocks只有一个块的情况 - auto preOptimize3() -> void; ///< llvm memtoreg预优化3: 针对某个变量的所有读写都在同一个块中的情况 - - auto insertPhi() -> void; ///< 为所有变量的迭代支配边界插入phi结点 - - auto rename(BasicBlock *block, std::unordered_map &count, - std::unordered_map> &stacks) -> void; ///< 单个块的重命名 - auto renameAll() -> void; ///< 重命名所有块 - - // private helper function. 
-private: - auto getPredIndex(BasicBlock *n, BasicBlock *s) -> int; ///< 获取前驱索引 - auto cascade(Instruction *instr, bool &changed, Function *func, BasicBlock *block, - std::list> &instrs) -> void; ///< 消除级联关系 - auto isGlobal(Value *val) -> bool; ///< 判断是否是全局变量 - auto isArr(Value *val) -> bool; ///< 判断是否是数组 - auto usedelete(Instruction *instr) -> void; ///< 删除指令相关的value-use-user关系 - -}; -} // namespace sysy diff --git a/src/include/Pass.h b/src/include/Pass.h new file mode 100644 index 0000000..d387e9e --- /dev/null +++ b/src/include/Pass.h @@ -0,0 +1,316 @@ +#pragma once + +#include // For std::function +#include +#include +#include +#include +#include // For std::type_index (although void* ID is more common in LLVM) +#include +#include +#include "IR.h" +#include "IRBuilder.h" + +namespace sysy { + +//前向声明 +class PassManager; +class AnalysisManager; + +// 抽象基类:分析结果 +class AnalysisResultBase { +public: + virtual ~AnalysisResultBase() = default; +}; + +// 抽象基类:Pass +class Pass { +public: + enum class Granularity { Module, Function, BasicBlock }; + + enum class PassKind { Analysis, Optimization }; + + Pass(const std::string &name, Granularity g, PassKind k) : Name(name), G(g), K(k) {} + virtual ~Pass() = default; + + const std::string &getName() const { return Name; } + Granularity getGranularity() const { return G; } + PassKind getPassKind() const { return K; } + + virtual bool runOnModule(Module *M, AnalysisManager& AM) { return false; } + virtual bool runOnFunction(Function *F, AnalysisManager& AM) { return false; } + virtual bool runOnBasicBlock(BasicBlock *BB, AnalysisManager& AM) { return false; } + + // 所有 Pass 都必须提供一个唯一的 ID + // 这通常是一个静态成员,并在 Pass 类外部定义 + virtual void *getPassID() const = 0; + +protected: + std::string Name; + Granularity G; + PassKind K; +}; + +// 抽象基类:分析遍 +class AnalysisPass : public Pass { +public: + AnalysisPass(const std::string &name, Granularity g) : Pass(name, g, PassKind::Analysis) {} + + virtual std::unique_ptr getResult() = 0; +}; + 
+// 抽象基类:优化遍 +class OptimizationPass : public Pass { +public: + OptimizationPass(const std::string &name, Granularity g) : Pass(name, g, PassKind::Optimization) {} + + virtual void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // 默认不依赖也不修改任何分析 + } +}; + +// ====================================================================== +// PassRegistry: 全局 Pass 注册表 (单例) +// ====================================================================== +class PassRegistry { +public: + // Pass 工厂函数类型:返回 Pass 的唯一指针 + using PassFactory = std::function()>; + + // 获取 PassRegistry 实例 (单例模式) + static PassRegistry &getPassRegistry() { + static PassRegistry instance; + return instance; + } + + // 注册一个 Pass + // passID 是 Pass 类的唯一静态 ID (例如 MyPass::ID 的地址) + // factory 是一个 lambda 或函数指针,用于创建该 Pass 的实例 + void registerPass(void *passID, PassFactory factory) { + if (factories.count(passID)) { + // Error: Pass with this ID already registered + // You might want to throw an exception or log an error + return; + } + factories[passID] = std::move(factory); + } + + // 通过 Pass ID 创建一个 Pass 实例 + std::unique_ptr createPass(void *passID) { + auto it = factories.find(passID); + if (it == factories.end()) { + // Error: Pass with this ID not registered + return nullptr; + } + return it->second(); // 调用工厂函数创建实例 + } + +private: + PassRegistry() = default; // 私有构造函数,实现单例 + ~PassRegistry() = default; + PassRegistry(const PassRegistry &) = delete; // 禁用拷贝构造 + PassRegistry &operator=(const PassRegistry &) = delete; // 禁用赋值操作 + + std::map factories; +}; + +// ====================================================================== +// AnalysisManager: 负责管理和提供分析结果 +// ====================================================================== +class AnalysisManager { +private: + Module *pModuleRef; // 指向被分析的Module + + // 缓存不同粒度的分析结果 + std::map> moduleCachedResults; + std::map, std::unique_ptr> functionCachedResults; + std::map, std::unique_ptr> basicBlockCachedResults; + + 
+public: + // 构造函数接收 Module 指针 + AnalysisManager(Module *M) : pModuleRef(M) {} + AnalysisManager() = delete; // 禁止无参构造 + + ~AnalysisManager() = default; + + // 获取分析结果的通用模板函数 + // T 是 AnalysisResult 的具体类型,E 是 AnalysisPass 的具体类型 + // F 和 BB 参数用于提供上下文,根据分析遍的粒度来使用 + template T *getAnalysisResult(Function *F = nullptr, BasicBlock *BB = nullptr) { + void *analysisID = E::ID; // 获取分析遍的唯一 ID + + // 尝试从注册表创建分析遍实例 + std::unique_ptr basePass = PassRegistry::getPassRegistry().createPass(analysisID); + if (!basePass) { + // Error: Analysis pass not registered + std::cerr << "Error: Analysis pass with ID " << analysisID << " not registered.\n"; + return nullptr; + } + AnalysisPass *analysisPass = static_cast(basePass.get()); + + // 根据分析遍的粒度处理 + switch (analysisPass->getGranularity()) { + case Pass::Granularity::Module: { + // 检查是否已存在有效结果 + auto it = moduleCachedResults.find(analysisID); + if (it != moduleCachedResults.end()) { + return static_cast(it->second.get()); // 返回缓存结果 + } + // 运行模块级分析遍 + if (!pModuleRef) { + std::cerr << "Error: Module reference not set for AnalysisManager to run Module Pass.\n"; + return nullptr; + } + analysisPass->runOnModule(pModuleRef, *this); + // 获取结果并缓存 + std::unique_ptr result = analysisPass->getResult(); + T *specificResult = static_cast(result.get()); + moduleCachedResults[analysisID] = std::move(result); // 缓存结果 + return specificResult; + } + case Pass::Granularity::Function: { + // 检查请求的上下文是否正确 + if (!F) { + std::cerr << "Error: Function context required for Function-level Analysis Pass.\n"; + return nullptr; + } + // 检查是否已存在有效结果 + auto it = functionCachedResults.find({F, analysisID}); + if (it != functionCachedResults.end()) { + return static_cast(it->second.get()); // 返回缓存结果 + } + // 运行函数级分析遍 + analysisPass->runOnFunction(F, *this); + // 获取结果并缓存 + std::unique_ptr result = analysisPass->getResult(); + T *specificResult = static_cast(result.get()); + functionCachedResults[{F, analysisID}] = std::move(result); // 缓存结果 + return specificResult; 
+ } + case Pass::Granularity::BasicBlock: { + // 检查请求的上下文是否正确 + if (!BB) { + std::cerr << "Error: BasicBlock context required for BasicBlock-level Analysis Pass.\n"; + return nullptr; + } + // 检查是否已存在有效结果 + auto it = basicBlockCachedResults.find({BB, analysisID}); + if (it != basicBlockCachedResults.end()) { + return static_cast(it->second.get()); // 返回缓存结果 + } + // 运行基本块级分析遍 + analysisPass->runOnBasicBlock(BB, *this); + // 获取结果并缓存 + std::unique_ptr result = analysisPass->getResult(); + T *specificResult = static_cast(result.get()); + basicBlockCachedResults[{BB, analysisID}] = std::move(result); // 缓存结果 + return specificResult; + } + } + return nullptr; // 不会到达这里 + } + + // 使所有分析结果失效 (当 IR 被修改时调用) + void invalidateAllAnalyses() { + moduleCachedResults.clear(); + functionCachedResults.clear(); + basicBlockCachedResults.clear(); + } + + // 使特定分析结果失效 + // void *analysisID: 要失效的分析的ID + // Function *F: 如果是函数级分析,指定函数;如果是模块级或基本块级,则为nullptr (取决于调用方式) + // BasicBlock *BB: 如果是基本块级分析,指定基本块;否则为nullptr + void invalidateAnalysis(void *analysisID, Function *F = nullptr, BasicBlock *BB = nullptr) { + if (BB) { + // 使特定基本块的特定分析结果失效 + basicBlockCachedResults.erase({BB, analysisID}); + } else if (F) { + // 使特定函数的特定分析结果失效 (也可能包含聚合的BasicBlock结果) + functionCachedResults.erase({F, analysisID}); + // 遍历所有属于F的基本块,使其BasicBlockCache失效 (如果该分析是BasicBlock粒度的) + // 这需要遍历F的所有基本块,效率较低,更推荐在BasicBlockPass的invalidateAnalysisUsage中精确指定 + // 或者在Function级别的invalidate时,清空该Function的所有BasicBlock分析 + // 这里的实现简单地清空该Function下所有该ID的BasicBlock缓存 + for (auto it = basicBlockCachedResults.begin(); it != basicBlockCachedResults.end(); ) { + // 假设BasicBlock::getParent()方法存在,可以获取所属Function + if (it->first.second == analysisID /* && it->first.first->getParent() == F */) { // 需要BasicBlock能获取其父函数 + it = basicBlockCachedResults.erase(it); + } else { + ++it; + } + } + + } else { + // 使所有函数的特定分析结果失效 (Module级和所有Function/BasicBlock级) + moduleCachedResults.erase(analysisID); + for (auto it = functionCachedResults.begin(); it 
!= functionCachedResults.end(); ) { + if (it->first.second == analysisID) { + it = functionCachedResults.erase(it); + } else { + ++it; + } + } + for (auto it = basicBlockCachedResults.begin(); it != basicBlockCachedResults.end(); ) { + if (it->first.second == analysisID) { + it = basicBlockCachedResults.erase(it); + } else { + ++it; + } + } + } + } +}; + +// ====================================================================== +// PassManager:遍管理器 +// ====================================================================== +class PassManager { +private: + std::vector> passes; + AnalysisManager analysisManager; + Module *pmodule; + IRBuilder *pBuilder; + +public: + PassManager() = default; + ~PassManager() = default; + + PassManager(Module *module, IRBuilder *builder) : pmodule(module) ,pBuilder(builder), analysisManager(module) {} + + // 运行所有注册的遍 + bool run(); + + // 运行优化管道主要负责注册和运行优化遍 + // 这里可以根据 optLevel 和 DEBUG 控制不同的优化遍 + void runOptimizationPipeline(Module* moduleIR, IRBuilder* builder, int optLevel); + + // 添加遍:现在接受 Pass 的 ID,而不是直接的 unique_ptr + void addPass(void *passID); + + AnalysisManager &getAnalysisManager() { return analysisManager; } + + void clearPasses(); +}; + +// ====================================================================== +// 辅助宏或函数,用于简化 Pass 的注册 +// ====================================================================== + +// 用于分析遍的注册 +template void registerAnalysisPass(); + +// (1) 针对需要 IRBuilder 参数的优化遍的重载 +// 这个模板只在 OptimizationPassType 可以通过 IRBuilder* 构造时才有效 +template ::value, int>::type = 0> +void registerOptimizationPass(IRBuilder* builder); + +// (2) 针对不需要 IRBuilder 参数的所有其他优化遍的重载 +// 这个模板只在 OptimizationPassType 不能通过 IRBuilder* 构造时才有效 +template ::value, int>::type = 0> +void registerOptimizationPass(); + +} // namespace sysy \ No newline at end of file diff --git a/src/include/RISCv64ISel.h b/src/include/RISCv64ISel.h index 0bb977a..472edfe 100644 --- a/src/include/RISCv64ISel.h +++ b/src/include/RISCv64ISel.h @@ -33,6 +33,8 @@ 
private: std::vector> build_dag(BasicBlock* bb); DAGNode* get_operand_node(Value* val_ir, std::map&, std::vector>&); DAGNode* create_node(int kind, Value* val, std::map&, std::vector>&); + // 用于计算类型大小的辅助函数 + unsigned getTypeSizeInBytes(Type* type); void print_dag(const std::vector>& dag, const std::string& bb_name); diff --git a/src/include/RISCv64LLIR.h b/src/include/RISCv64LLIR.h index 86de7d4..d8797bc 100644 --- a/src/include/RISCv64LLIR.h +++ b/src/include/RISCv64LLIR.h @@ -44,9 +44,11 @@ enum class RVOpcodes { // 特殊标记,非指令 LABEL, // 新增伪指令,用于解耦栈帧处理 - FRAME_LOAD, // 从栈帧加载 (AllocaInst) - FRAME_STORE, // 保存到栈帧 (AllocaInst) - FRAME_ADDR, // [新] 获取栈帧变量的地址 + FRAME_LOAD_W, // 从栈帧加载 32位 Word (对应 lw) + FRAME_LOAD_D, // 从栈帧加载 64位 Doubleword (对应 ld) + FRAME_STORE_W, // 保存 32位 Word 到栈帧 (对应 sw) + FRAME_STORE_D, // 保存 64位 Doubleword 到栈帧 (对应 sd) + FRAME_ADDR, // 获取栈帧变量的地址 }; class MachineOperand; diff --git a/src/include/RISCv64Passes.h b/src/include/RISCv64Passes.h index 7205b10..d2da152 100644 --- a/src/include/RISCv64Passes.h +++ b/src/include/RISCv64Passes.h @@ -6,13 +6,13 @@ namespace sysy { /** - * @class Pass + * @class BackendPass * @brief 所有优化Pass的抽象基类 (可选,但推荐) * * 定义一个通用的接口,所有优化都应该实现它。 */ -class Pass { +class BackendPass { public: - virtual ~Pass() = default; + virtual ~BackendPass() = default; virtual void runOnMachineFunction(MachineFunction* mfunc) = 0; }; @@ -25,7 +25,7 @@ public: * * 在虚拟寄存器上进行操作,此时调度自由度最大, * 主要目标是隐藏指令延迟,提高流水线效率。 */ -class PreRA_Scheduler : public Pass { +class PreRA_Scheduler : public BackendPass { public: void runOnMachineFunction(MachineFunction* mfunc) override; }; @@ -39,7 +39,7 @@ public: * * 在已分配物理寄存器的指令流上,通过一个小的滑动窗口来查找 * 并替换掉一些冗余或低效的指令模式。 */ -class PeepholeOptimizer : public Pass { +class PeepholeOptimizer : public BackendPass { public: void runOnMachineFunction(MachineFunction* mfunc) override; }; @@ -50,7 +50,7 @@ public: * * 主要目标是优化寄存器分配器插入的spill/fill代码(lw/sw), * 尝试将加载指令提前,以隐藏其访存延迟。 */ -class PostRA_Scheduler : public Pass { +class 
PostRA_Scheduler : public BackendPass { public: void runOnMachineFunction(MachineFunction* mfunc) override; }; diff --git a/src/include/RISCv64RegAlloc.h b/src/include/RISCv64RegAlloc.h index c786bde..724ad1c 100644 --- a/src/include/RISCv64RegAlloc.h +++ b/src/include/RISCv64RegAlloc.h @@ -49,6 +49,13 @@ private: // 可用的物理寄存器池 std::vector allocable_int_regs; + + // 存储vreg到IR Value*的反向映射 + // 这个map将在run()函数开始时被填充,并在rewriteFunction()中使用。 + std::map vreg_to_value_map; + + // 用于计算类型大小的辅助函数 + unsigned getTypeSizeInBytes(Type* type); }; } // namespace sysy diff --git a/src/include/Reg2Mem.h b/src/include/Reg2Mem.h deleted file mode 100644 index 6249d71..0000000 --- a/src/include/Reg2Mem.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#include "IR.h" -#include "IRBuilder.h" - -namespace sysy { -/** - * Reg2Mem(后端未做phi指令翻译) - */ -class Reg2Mem { -private: - Module *pModule; - IRBuilder *pBuilder; - -public: - Reg2Mem(Module *pMoudle, IRBuilder *pBuilder) : pModule(pMoudle), pBuilder(pBuilder) {} - - void DeletePhiInst(); - // 删除UD关系, 因为删除了phi指令会修改ud关系 - void usedelete(Instruction *instr); -}; - -} // namespace sysy \ No newline at end of file diff --git a/src/include/SCCP.h b/src/include/SCCP.h new file mode 100644 index 0000000..7db0a7b --- /dev/null +++ b/src/include/SCCP.h @@ -0,0 +1,196 @@ +#pragma once + +#include "IR.h" + +namespace sysy { + +// 稀疏条件常量传播类 +// Sparse Conditional Constant Propagation +/* +伪代码 +function SCCP_Optimization(Module): + for each Function in Module: + changed = true + while changed: + changed = false + // 阶段1: 常量传播与折叠 + changed |= PropagateConstants(Function) + // 阶段2: 控制流简化 + changed |= SimplifyControlFlow(Function) + end while + end for + +function PropagateConstants(Function): + // 初始化 + executableBlocks = {entryBlock} + valueState = map // 值->状态映射 + instWorkList = Queue() + edgeWorkList = Queue() + + // 初始化工作列表 + for each inst in entryBlock: + instWorkList.push(inst) + + // 迭代处理 + while !instWorkList.empty() || !edgeWorkList.empty(): + 
// 处理指令工作列表 + while !instWorkList.empty(): + inst = instWorkList.pop() + // 如果指令是可执行基本块中的 + if executableBlocks.contains(inst.parent): + ProcessInstruction(inst) + + // 处理边工作列表 + while !edgeWorkList.empty(): + edge = edgeWorkList.pop() + ProcessEdge(edge) + + // 应用常量替换 + for each inst in Function: + if valueState[inst] == CONSTANT: + ReplaceWithConstant(inst, valueState[inst].constant) + changed = true + + return changed + +function ProcessInstruction(Instruction inst): + switch inst.type: + //二元操作 + case BINARY_OP: + lhs = GetValueState(inst.operands[0]) + rhs = GetValueState(inst.operands[1]) + if lhs == CONSTANT && rhs == CONSTANT: + newState = ComputeConstant(inst.op, lhs.value, rhs.value) + UpdateState(inst, newState) + else if lhs == BOTTOM || rhs == BOTTOM: + UpdateState(inst, BOTTOM) + //phi + case PHI: + mergedState = ⊤ + for each incoming in inst.incomings: + // 检查每个输入的状态 + if executableBlocks.contains(incoming.block): + incomingState = GetValueState(incoming.value) + mergedState = Meet(mergedState, incomingState) + UpdateState(inst, mergedState) + // 条件分支 + case COND_BRANCH: + cond = GetValueState(inst.condition) + if cond == CONSTANT: + // 判断条件分支 + if cond.value == true: + AddEdgeToWorkList(inst.parent, inst.trueTarget) + else: + AddEdgeToWorkList(inst.parent, inst.falseTarget) + else if cond == BOTTOM: + AddEdgeToWorkList(inst.parent, inst.trueTarget) + AddEdgeToWorkList(inst.parent, inst.falseTarget) + + case UNCOND_BRANCH: + AddEdgeToWorkList(inst.parent, inst.target) + + // 其他指令处理... 
+ +function ProcessEdge(Edge edge): + fromBB, toBB = edge + if !executableBlocks.contains(toBB): + executableBlocks.add(toBB) + for each inst in toBB: + if inst is PHI: + instWorkList.push(inst) + else: + instWorkList.push(inst) // 非PHI指令 + + // 更新PHI节点的输入 + for each phi in toBB.phis: + instWorkList.push(phi) + +function SimplifyControlFlow(Function): + changed = false + // 标记可达基本块 + ReachableBBs = FindReachableBlocks(Function.entry) + + // 删除不可达块 + for each bb in Function.blocks: + if !ReachableBBs.contains(bb): + RemoveDeadBlock(bb) + changed = true + + // 简化条件分支 + for each bb in Function.blocks: + terminator = bb.terminator + if terminator is COND_BRANCH: + cond = GetValueState(terminator.condition) + if cond == CONSTANT: + SimplifyBranch(terminator, cond.value) + changed = true + + return changed + +function RemoveDeadBlock(BasicBlock bb): + // 1. 更新前驱块的分支指令 + for each pred in bb.predecessors: + UpdateTerminator(pred, bb) + + // 2. 更新后继块的PHI节点 + for each succ in bb.successors: + RemovePhiIncoming(succ, bb) + + // 3. 删除块内所有指令 + for each inst in bb.instructions: + inst.remove() + + // 4. 
从函数中移除基本块 + Function.removeBlock(bb) + +function Meet(State a, State b): + if a == ⊤: return b + if b == ⊤: return a + if a == ⊥ || b == ⊥: return ⊥ + if a.value == b.value: return a + return ⊥ + +function UpdateState(Value v, State newState): + oldState = valueState.get(v, ⊤) + if newState != oldState: + valueState[v] = newState + for each user in v.users: + if user is Instruction: + instWorkList.push(user) + +*/ + +enum class LatticeValue { + Top, // ⊤ (Unknown) + Constant, // c (Constant) + Bottom // ⊥ (Undefined / Varying) +}; +// LatticeValue: 用于表示值的状态,Top表示未知,Constant表示常量,Bottom表示未定义或变化的值。 +// 这里的LatticeValue用于跟踪每个SSA值(变量、指令结果)的状态, +// 以便在SCCP过程中进行常量传播和控制流简化。 + +//TODO: 下列数据结构考虑集成到类中,避免重命名问题 +static std::set Worklist; +static std::unordered_set Executable_Blocks; +static std::queue > Executable_Edges; +static std::map valueState; + +class SCCP { +private: + Module *pModule; + +public: + SCCP(Module *pMoudle) : pModule(pMoudle) {} + + void run(); + bool PropagateConstants(Function *function); + bool SimplifyControlFlow(Function *function); + void ProcessInstruction(Instruction *inst); + void ProcessEdge(const std::pair &edge); + void RemoveDeadBlock(BasicBlock *bb); + void UpdateState(Value *v, LatticeValue newState); + LatticeValue Meet(LatticeValue a, LatticeValue b); + LatticeValue GetValueState(Value *v); +}; + +} // namespace sysy diff --git a/src/include/SysYFormatter.h b/src/include/SysYFormatter.h deleted file mode 100644 index d4c9fb7..0000000 --- a/src/include/SysYFormatter.h +++ /dev/null @@ -1,340 +0,0 @@ -#pragma once - -#include "SysYBaseVisitor.h" -#include "SysYParser.h" -#include - -namespace sysy { - -class SysYFormatter : public SysYBaseVisitor { -protected: - std::ostream &os; - int indent = 0; - -public: - SysYFormatter(std::ostream &os) : os(os), indent(0) {} - -protected: - struct Indentor { - static constexpr int TabSize = 2; - int &indent; - Indentor(int &indent) : indent(indent) { indent += TabSize; } - ~Indentor() { indent -= 
TabSize; } - }; - std::ostream &space() { return os << std::string(indent, ' '); } - template - std::ostream &interleave(const T &container, const std::string sep = ", ") { - auto b = container.begin(), e = container.end(); - (*b)->accept(this); - for (b = std::next(b); b != e; b = std::next(b)) { - os << sep; - (*b)->accept(this); - } - return os; - } - -public: - // virtual std::any visitModule(SysYParser::ModuleContext *ctx) override { - // return visitChildren(ctx); - // } - - virtual std::any visitBtype(SysYParser::BtypeContext *ctx) override { - os << ctx->getText(); - return 0; - } - - virtual std::any visitDecl(SysYParser::DeclContext *ctx) override { - space(); - if (ctx->CONST()) - os << ctx->CONST()->getText() << ' '; - ctx->btype()->accept(this); - os << ' '; - interleave(ctx->varDef(), ", ") << ';' << '\n'; - return 0; - } - - virtual std::any visitVarDef(SysYParser::VarDefContext *ctx) override { - ctx->lValue()->accept(this); - if (ctx->initValue()) { - os << ' ' << '=' << ' '; - ctx->initValue()->accept(this); - } - return 0; - } - - virtual std::any visitInitValue(SysYParser::InitValueContext *ctx) override { - if (not ctx->exp()) { - os << '{'; - auto values = ctx->initValue(); - if (values.size()) - interleave(values, ", "); - os << '}'; - } - return 0; - } - - virtual std::any visitFunc(SysYParser::FuncContext *ctx) override { - ctx->funcType()->accept(this); - os << ' ' << ctx->ID()->getText() << '('; - if (ctx->funcFParams()) - ctx->funcFParams()->accept(this); - os << ')' << ' '; - ctx->blockStmt()->accept(this); - os << '\n'; - return 0; - } - - virtual std::any visitFuncType(SysYParser::FuncTypeContext *ctx) override { - os << ctx->getText(); - return 0; - } - - virtual std::any - visitFuncFParams(SysYParser::FuncFParamsContext *ctx) override { - interleave(ctx->funcFParam(), ", "); - return 0; - } - - virtual std::any - visitFuncFParam(SysYParser::FuncFParamContext *ctx) override { - ctx->btype()->accept(this); - os << ' ' << 
ctx->ID()->getText(); - if (not ctx->LBRACKET().empty()) { - os << '['; - auto exp = ctx->exp(); - if (not exp.empty()) { - os << '['; - interleave(exp, "][") << ']'; - } - } - return 0; - } - - virtual std::any visitBlockStmt(SysYParser::BlockStmtContext *ctx) override { - os << '{' << '\n'; - { - Indentor indentor(indent); - auto items = ctx->blockItem(); - if (not items.empty()) - interleave(items, ""); - } - space() << ctx->RBRACE()->getText() << '\n'; - return 0; - } - - // virtual std::any visitBlockItem(SysYParser::BlockItemContext *ctx) - // override { - // return visitChildren(ctx); - // } - - // virtual std::any visitStmt(SysYParser::StmtContext *ctx) override { - // return visitChildren(ctx); - // } - - virtual std::any - visitAssignStmt(SysYParser::AssignStmtContext *ctx) override { - space(); - ctx->lValue()->accept(this); - os << " = "; - ctx->exp()->accept(this); - os << ';' << '\n'; - return 0; - } - - virtual std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override { - space(); - ctx->exp()->accept(this); - os << ';' << '\n'; - return 0; - } - - void wrapBlock(SysYParser::StmtContext *stmt) { - bool isBlock = stmt->blockStmt(); - if (isBlock) { - stmt->accept(this); - } else { - os << "{\n"; - { - Indentor indentor(indent); - stmt->accept(this); - } - space() << "}\n"; - } - }; - virtual std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override { - space(); - os << ctx->IF()->getText() << " ("; - ctx->exp()->accept(this); - os << ") "; - auto stmt = ctx->stmt(); - auto ifStmt = stmt[0]; - wrapBlock(ifStmt); - if (stmt.size() == 2) { - auto elseStmt = stmt[1]; - wrapBlock(elseStmt); - } - return 0; - } - - virtual std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override { - space(); - os << ctx->WHILE()->getText() << " ("; - ctx->exp()->accept(this); - os << ") "; - wrapBlock(ctx->stmt()); - return 0; - } - - virtual std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override { - space() << ctx->BREAK()->getText() << 
';' << '\n'; - return 0; - } - - virtual std::any - visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override { - space() << ctx->CONTINUE()->getText() << ';' << '\n'; - return 0; - } - - virtual std::any - visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override { - space() << ctx->RETURN()->getText(); - if (ctx->exp()) { - os << ' '; - ctx->exp()->accept(this); - } - os << ';' << '\n'; - return 0; - } - - // virtual std::any visitEmptyStmt(SysYParser::EmptyStmtContext *ctx) - // override { - // return visitChildren(ctx); - // } - - virtual std::any - visitRelationExp(SysYParser::RelationExpContext *ctx) override { - auto lhs = ctx->exp(0); - auto rhs = ctx->exp(1); - std::string op = - ctx->LT() ? "<" : (ctx->LE() ? "<=" : (ctx->GT() ? ">" : ">=")); - lhs->accept(this); - os << ' ' << op << ' '; - rhs->accept(this); - return 0; - } - - virtual std::any - visitMultiplicativeExp(SysYParser::MultiplicativeExpContext *ctx) override { - auto lhs = ctx->exp(0); - auto rhs = ctx->exp(1); - std::string op = ctx->MUL() ? "*" : (ctx->DIV() ? "/" : "%"); - lhs->accept(this); - os << ' ' << op << ' '; - rhs->accept(this); - return 0; - } - - // virtual std::any visitLValueExp(SysYParser::LValueExpContext *ctx) - // override { - // return visitChildren(ctx); - // } - - // virtual std::any visitNumberExp(SysYParser::NumberExpContext *ctx) - // override { - // return visitChildren(ctx); - // } - - virtual std::any visitAndExp(SysYParser::AndExpContext *ctx) override { - ctx->exp(0)->accept(this); - os << " && "; - ctx->exp(1)->accept(this); - return 0; - } - - virtual std::any visitUnaryExp(SysYParser::UnaryExpContext *ctx) override { - std::string op = ctx->ADD() ? "+" : (ctx->SUB() ? 
"-" : "!"); - os << op; - ctx->exp()->accept(this); - return 0; - } - - virtual std::any visitParenExp(SysYParser::ParenExpContext *ctx) override { - os << '('; - ctx->exp()->accept(this); - os << ')'; - return 0; - } - - virtual std::any visitStringExp(SysYParser::StringExpContext *ctx) override { - return visitChildren(ctx); - } - - virtual std::any visitOrExp(SysYParser::OrExpContext *ctx) override { - ctx->exp(0)->accept(this); - os << " || "; - ctx->exp(1)->accept(this); - return 0; - } - - // virtual std::any visitCallExp(SysYParser::CallExpContext *ctx) override { - // return visitChildren(ctx); - // } - - virtual std::any - visitAdditiveExp(SysYParser::AdditiveExpContext *ctx) override { - auto lhs = ctx->exp(0); - auto rhs = ctx->exp(1); - std::string op = ctx->ADD() ? "+" : "-"; - lhs->accept(this); - os << ' ' << op << ' '; - rhs->accept(this); - return 0; - } - - virtual std::any visitEqualExp(SysYParser::EqualExpContext *ctx) override { - auto lhs = ctx->exp(0); - auto rhs = ctx->exp(1); - std::string op = ctx->EQ() ? 
"==" : "!="; - lhs->accept(this); - os << ' ' << op << ' '; - rhs->accept(this); - return 0; - } - - virtual std::any visitCall(SysYParser::CallContext *ctx) override { - os << ctx->ID()->getText() << '('; - if (ctx->funcRParams()) - ctx->funcRParams()->accept(this); - os << ')'; - return 0; - } - - virtual std::any visitLValue(SysYParser::LValueContext *ctx) override { - os << ctx->ID()->getText(); - auto exp = ctx->exp(); - if (not exp.empty()) { - os << '['; - interleave(exp, "][") << ']'; - } - return 0; - } - - virtual std::any visitNumber(SysYParser::NumberContext *ctx) override { - os << ctx->getText(); - return 0; - } - - virtual std::any visitString(SysYParser::StringContext *ctx) override { - os << ctx->getText(); - return 0; - } - - virtual std::any - visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override { - interleave(ctx->exp(), ", "); - return 0; - } -}; - -} // namespace sysy diff --git a/src/include/SysYIRCFGOpt.h b/src/include/SysYIRCFGOpt.h new file mode 100644 index 0000000..a7ba08b --- /dev/null +++ b/src/include/SysYIRCFGOpt.h @@ -0,0 +1,101 @@ +#pragma once + +#include "IR.h" +#include "IRBuilder.h" +#include "Pass.h" + +namespace sysy { + +// 优化前对SysY IR的预处理,也可以视作部分CFG优化 +// 主要包括删除无用指令、合并基本块、删除空块等 +// 这些操作可以在SysY IR生成时就完成,但为了简化IR生成过程, +// 这里将其放在SysY IR生成后进行预处理 +// 同时兼容phi节点的处理,可以再mem2reg后再次调用优化 + +//TODO: 可增加的CFG优化和方法 +// - 检查基本块跳转关系正确性 +// - 简化条件分支(Branch Simplification),如条件恒真/恒假转为直接跳转 +// - 合并连续的跳转指令(Jump Threading)在合并不可达块中似乎已经实现了 +// - 基本块重排序(Block Reordering),提升局部性 + +// 辅助工具类,包含实际的CFG优化逻辑 +// 这些方法可以被独立的Pass调用 +class SysYCFGOptUtils { +public: + static bool SysYDelInstAfterBr(Function *func); // 删除br后面的指令 + static bool SysYDelEmptyBlock(Function *func, IRBuilder* pBuilder); // 空块删除 + static bool SysYDelNoPreBLock(Function *func); // 删除无前驱块(不可达块) + static bool SysYBlockMerge(Function *func); // 合并基本块 + static bool SysYAddReturn(Function *func, IRBuilder* pBuilder); // 添加return指令 + static bool SysYCondBr2Br(Function *func, 
IRBuilder* pBuilder); // 条件分支转换为无条件分支 +}; + +// ====================================================================== +// 独立的CFG优化遍 +// ====================================================================== + +class SysYDelInstAfterBrPass : public OptimizationPass { +public: + static void *ID; // 唯一ID + SysYDelInstAfterBrPass() : OptimizationPass("SysYDelInstAfterBrPass", Granularity::Function) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override { + // 这个优化可能改变CFG结构,使一些CFG相关的分析失效 + // 可以在这里指定哪些分析会失效,例如支配树、活跃变量等 + // analysisInvalidations.insert(DominatorTreeAnalysisPass::ID); // 示例 + } + void *getPassID() const override { return &ID; } +}; + +class SysYDelEmptyBlockPass : public OptimizationPass { +private: + IRBuilder *pBuilder; +public: + static void *ID; + SysYDelEmptyBlockPass(IRBuilder *builder) : OptimizationPass("SysYDelEmptyBlockPass", Granularity::Function), pBuilder(builder) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override {}; + void *getPassID() const override { return &ID; } +}; + +class SysYDelNoPreBLockPass : public OptimizationPass { +public: + static void *ID; + SysYDelNoPreBLockPass() : OptimizationPass("SysYDelNoPreBLockPass", Granularity::Function) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override {}; + void *getPassID() const override { return &ID; } +}; + +class SysYBlockMergePass : public OptimizationPass { +public: + static void *ID; + SysYBlockMergePass() : OptimizationPass("SysYBlockMergePass", Granularity::Function) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const 
override {}; + void *getPassID() const override { return &ID; } +}; + +class SysYAddReturnPass : public OptimizationPass { +private: + IRBuilder *pBuilder; +public: + static void *ID; + SysYAddReturnPass(IRBuilder *builder) : OptimizationPass("SysYAddReturnPass", Granularity::Function), pBuilder(builder) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override {}; + void *getPassID() const override { return &ID; } +}; + +class SysYCondBr2BrPass : public OptimizationPass { +private: + IRBuilder *pBuilder; +public: + static void *ID; + SysYCondBr2BrPass(IRBuilder *builder) : OptimizationPass("SysYCondBr2BrPass", Granularity::Function), pBuilder(builder) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + void getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const override {}; + void *getPassID() const override { return &ID; } +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index fe309e8..aac6ec9 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -62,12 +62,11 @@ private: public: SysYIRGenerator() = default; - bool HasReturnInst; - public: Module *get() const { return module.get(); } IRBuilder *getBuilder(){ return &builder; } public: + std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; std::any visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx) override; @@ -134,6 +133,13 @@ public: // std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; +public: + // 获取GEP指令的地址 + Value* getGEPAddressInst(Value* basePointer, const std::vector& indices); + // 构建数组类型 + Type* buildArrayType(Type* baseType, const std::vector& dims); + + unsigned countArrayDimensions(Type* type); }; // class SysYIRGenerator diff --git a/src/include/SysYIROptPre.h b/src/include/SysYIROptPre.h 
deleted file mode 100644 index 4f0bdca..0000000 --- a/src/include/SysYIROptPre.h +++ /dev/null @@ -1,37 +0,0 @@ -#pragma once - -#include "IR.h" -#include "IRBuilder.h" - -namespace sysy { - -// 优化前对SysY IR的预处理,也可以视作部分CFG优化 -// 主要包括删除无用指令、合并基本块、删除空块等 -// 这些操作可以在SysY IR生成时就完成,但为了简化IR生成过程, -// 这里将其放在SysY IR生成后进行预处理 -// 同时兼容phi节点的处理,可以再mem2reg后再次调用优化 -class SysYOptPre { - private: - Module *pModule; - IRBuilder *pBuilder; - - public: - SysYOptPre(Module *pMoudle, IRBuilder *pBuilder) : pModule(pMoudle), pBuilder(pBuilder) {} - - void SysYOptimizateAfterIR(){ - SysYDelInstAfterBr(); - SysYBlockMerge(); - SysYDelNoPreBLock(); - SysYDelEmptyBlock(); - SysYAddReturn(); - } - void SysYDelInstAfterBr(); // 删除br后面的指令 - void SysYDelEmptyBlock(); // 空块删除 - void SysYDelNoPreBLock(); // 删除无前驱块 - void SysYBlockMerge(); // 合并基本块(主要针对嵌套if while的exit块, - // 也可以修改IR生成实现回填机制 - void SysYAddReturn(); // 添加return指令(主要针对Void函数) - void usedelete(Instruction *instr); // use删除 -}; - -} // namespace sysy diff --git a/src/include/SysYIROptUtils.h b/src/include/SysYIROptUtils.h new file mode 100644 index 0000000..1b764ec --- /dev/null +++ b/src/include/SysYIROptUtils.h @@ -0,0 +1,33 @@ +#pragma once + +#include "IR.h" + +namespace sysy { + +// 优化工具类,包含一些通用的优化方法 +// 这些方法可以在不同的优化 pass 中复用 +// 例如:删除use关系,判断是否是全局变量等 +class SysYIROptUtils{ + +public: + // 仅仅删除use关系 + static void usedelete(Instruction *instr) { + for (auto &use : instr->getOperands()) { + Value* val = use->getValue(); + val->removeUse(use); + } + } + + // 判断是否是全局变量 + static bool isGlobal(Value *val) { + auto gval = dynamic_cast(val); + return gval != nullptr; + } + // 判断是否是数组 + static bool isArr(Value *val) { + auto aval = dynamic_cast(val); + return aval != nullptr && aval->getNumDims() != 0; + } +}; + +}// namespace sysy \ No newline at end of file diff --git a/src/include/SysYIRPassManager.h b/src/include/SysYIRPassManager.h new file mode 100644 index 0000000..310b50f --- /dev/null +++ b/src/include/SysYIRPassManager.h @@ -0,0 
+1,58 @@ +// PassManager.h +#pragma once + +#include +#include +#include // For std::type_index +#include +#include "SysYIRPass.h" +#include "IR.h" // 假设你的IR.h定义了Module, Function等 + +namespace sysy { + +class PassManager { +public: + PassManager() = default; + + // 添加一个FunctionPass + void addPass(std::unique_ptr pass) { + functionPasses.push_back(std::move(pass)); + } + + // 添加一个ModulePass + void addPass(std::unique_ptr pass) { + modulePasses.push_back(std::move(pass)); + } + + // 添加一个AnalysisPass + template + T* addAnalysisPass(Args&&... args) { + static_assert(std::is_base_of::value, "T must derive from AnalysisPass"); + auto analysis = std::make_unique(std::forward(args)...); + T* rawPtr = analysis.get(); + analysisPasses[std::type_index(typeid(T))] = std::move(analysis); + return rawPtr; + } + + // 获取分析结果(用于其他Pass访问) + template + T* getAnalysis() { + static_assert(std::is_base_of::value, "T must derive from AnalysisPass"); + auto it = analysisPasses.find(std::type_index(typeid(T))); + if (it != analysisPasses.end()) { + return static_cast(it->second.get()); + } + return nullptr; // 或者抛出异常 + } + + // 运行所有注册的遍 + void run(Module& M); + +private: + std::vector> functionPasses; + std::vector> modulePasses; + std::unordered_map> analysisPasses; + // 未来可以添加AnalysisPass的缓存机制 +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/sysyc.cpp b/src/sysyc.cpp index cac39a9..3f79108 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -13,13 +13,10 @@ using namespace antlr4; #include "SysYIRGenerator.h" #include "SysYIRPrinter.h" -#include "SysYIROptPre.h" +#include "SysYIRCFGOpt.h" // 包含 CFG 优化 #include "RISCv64Backend.h" -#include "SysYIRAnalyser.h" -// #include "DeadCodeElimination.h" +#include "Pass.h" // 包含新的 Pass 框架 #include "AddressCalculationExpansion.h" -// #include "Mem2Reg.h" -// #include "Reg2Mem.h" using namespace sysy; @@ -131,19 +128,20 @@ int main(int argc, char **argv) { if (argStopAfter == "ird") { DEBUG = 1; // 这里可能需要更精细地控制 DEBUG 的开启时机和范围 } - 
// 默认优化 pass (在所有优化级别都会执行) - SysYOptPre optPre(moduleIR, builder); - optPre.SysYOptimizateAfterIR(); - - ControlFlowAnalysis cfa(moduleIR); - cfa.init(); - ActiveVarAnalysis ava; - ava.init(moduleIR); - + if (DEBUG) { - cout << "=== After CFA & AVA (Default) ===\n"; + cout << "=== Init IR ===\n"; SysYPrinter(moduleIR).printIR(); // 临时打印器用于调试 } + + // 创建 Pass 管理器并运行优化管道 + PassManager passManager(moduleIR, builder); // 创建 Pass 管理器 + // 好像都不用传递module和builder了,因为 PassManager 初始化了 + passManager.runOptimizationPipeline(moduleIR, builder, optLevel); + + + + AddressCalculationExpansion ace(moduleIR, builder); if (ace.run()) { if (DEBUG) cout << "AddressCalculationExpansion made changes.\n";