From 467f2f6b242c704dda62439e7fd79b622a0a9cfb Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 16 Aug 2025 15:38:41 +0800 Subject: [PATCH] =?UTF-8?q?[midend-GVN]=E5=88=9D=E6=AD=A5=E6=9E=84?= =?UTF-8?q?=E5=BB=BAGVN=EF=BC=8C=E8=83=BD=E5=A4=9F=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=83=A8=E5=88=86CSE=E6=97=A0=E6=B3=95=E5=A4=84=E7=90=86?= =?UTF-8?q?=E7=9A=84=E5=AD=90=E8=A1=A8=E8=BE=BE=E5=BC=8F=E4=BD=86=E6=98=AF?= =?UTF-8?q?=E6=9C=89=E9=94=99=E8=AF=AF=E9=9C=80=E8=A6=81debug=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/midend/CMakeLists.txt | 1 + src/midend/IR.cpp | 4 +- src/midend/Pass/Optimize/GVN.cpp | 450 +++++++++++++++++++++++++++++++ src/midend/Pass/Pass.cpp | 13 + 4 files changed, 466 insertions(+), 2 deletions(-) create mode 100644 src/midend/Pass/Optimize/GVN.cpp diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index b3b86cc..66fc461 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -15,6 +15,7 @@ add_library(midend_lib STATIC Pass/Optimize/DCE.cpp Pass/Optimize/Mem2Reg.cpp Pass/Optimize/Reg2Mem.cpp + Pass/Optimize/GVN.cpp Pass/Optimize/SysYIRCFGOpt.cpp Pass/Optimize/SCCP.cpp Pass/Optimize/LoopNormalization.cpp diff --git a/src/midend/IR.cpp b/src/midend/IR.cpp index 39293f2..d35e16b 100644 --- a/src/midend/IR.cpp +++ b/src/midend/IR.cpp @@ -847,7 +847,7 @@ void CondBrInst::print(std::ostream &os) const { os << "%tmp_cond_" << condName << "_" << uniqueSuffix << " = icmp ne i32 "; printOperand(os, condition); - os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix; + os << ", 0\n br i1 %tmp_cond_" << condName << "_" << uniqueSuffix; os << ", label %"; printBlockName(os, getThenBlock()); @@ -886,7 +886,7 @@ void MemsetInst::print(std::ostream &os) const { // This is done at print time to avoid modifying the IR structure os << "%tmp_bitcast_" << ptr->getName() << " = bitcast " << *ptr->getType() << " "; printOperand(os, ptr); - os << " to i8*\n "; + os << " to i8*\n "; // Now call memset with the bitcast result os << "call void @llvm.memset.p0i8.i32(i8* %tmp_bitcast_" << ptr->getName() << ", i8 "; diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp new file mode 100644 index 0000000..a06ec5f --- /dev/null +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -0,0 +1,450 @@ +#include "GVN.h" +#include "Dom.h" +#include "SysYIROptUtils.h" +#include +#include +#include + +extern int DEBUG; + +namespace sysy { + +// GVN 遍的静态 ID +void *GVN::ID = (void *)&GVN::ID; + +// ====================================================================== +// GVN 类的实现 +// ====================================================================== + +bool GVN::runOnFunction(Function *func, AnalysisManager &AM) { + if (func->getBasicBlocks().empty()) { + return false; + } + + if (DEBUG) { + std::cout << "\n=== Running GVN on function: " << func->getName() << " ===" << std::endl; + } + + bool changed = false; + GVNContext context; + context.run(func, &AM, changed); + + if (DEBUG) { + if (changed) { + std::cout << "GVN: Function " << func->getName() << " was modified" << std::endl; + } else { + std::cout << "GVN: Function " << func->getName() << " was not modified" << std::endl; + } + std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl; + } + + return changed; +} + +void GVN::getAnalysisUsage(std::set &analysisDependencies, std::set &analysisInvalidations) const { + // GVN依赖以下分析: + // 1. 支配树分析 - 用于检查指令的支配关系,确保替换的安全性 + analysisDependencies.insert(&DominatorTreeAnalysisPass::ID); + + // 2. 副作用分析 - 用于判断函数调用是否可以进行GVN + analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID); + + // GVN不会使任何分析失效,因为: + // - GVN只删除冗余计算,不改变CFG结构 + // - GVN不修改程序的语义,只是消除重复计算 + // - 支配关系保持不变 + // - 副作用分析结果保持不变 + // analysisInvalidations 保持为空 + + if (DEBUG) { + std::cout << "GVN: Declared analysis dependencies (DominatorTree, SideEffectAnalysis)" << std::endl; + } +} + +// ====================================================================== +// GVNContext 类的实现 +// ====================================================================== + +void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) { + if (DEBUG) { + std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl; + } + + // 获取分析结果 + if (AM) { + domTree = AM->getAnalysisResult(func); + sideEffectAnalysis = AM->getAnalysisResult(); + + if (DEBUG) { + if (domTree) { + std::cout << " GVN: Using dominator tree analysis" << std::endl; + } else { + std::cout << " GVN: Warning - dominator tree analysis not available" << std::endl; + } + if (sideEffectAnalysis) { + std::cout << " GVN: Using side effect analysis" << std::endl; + } else { + std::cout << " GVN: Warning - side effect analysis not available" << std::endl; + } + } + } + + // 清空状态 + hashtable.clear(); + visited.clear(); + rpoBlocks.clear(); + needRemove.clear(); + + // 计算逆后序遍历 + computeRPO(func); + + if (DEBUG) { + std::cout << " Computed RPO with " << rpoBlocks.size() << " blocks" << std::endl; + } + + // 按逆后序遍历基本块进行GVN + int blockCount = 0; + for (auto bb : rpoBlocks) { + if (DEBUG) { + std::cout << " Processing block " << ++blockCount << "/" << rpoBlocks.size() + << ": " << bb->getName() << std::endl; + } + + int instCount = 0; + for (auto &instPtr : bb->getInstructions()) { + if (DEBUG) { + std::cout << " Processing instruction " << ++instCount + << ": " << instPtr->getName() << std::endl; + } + visitInstruction(instPtr.get()); + } + } + + if (DEBUG) { + std::cout << " Found " << needRemove.size() << " redundant instructions to remove" << std::endl; + } + + // 删除冗余指令 + int removeCount = 0; + for (auto inst : needRemove) { + auto bb = inst->getParent(); + if (DEBUG) { + std::cout << " Removing redundant instruction " << ++removeCount + << "/" << needRemove.size() << ": " << inst->getName() << std::endl; + } + // 删除指令前先断开所有使用关系 + inst->replaceAllUsesWith(nullptr); + // 使用基本块的删除方法 + // bb->removeInst(inst); + SysYIROptUtils::usedelete(inst); + changed = true; + } + + if (DEBUG) { + std::cout << " GVN analysis completed for function: " << func->getName() << std::endl; + std::cout << " Total instructions analyzed: " << hashtable.size() << std::endl; + std::cout << " Instructions eliminated: " << needRemove.size() << std::endl; + } +} + +void GVNContext::computeRPO(Function *func) { + rpoBlocks.clear(); + visited.clear(); + + auto entry = func->getEntryBlock(); + if (entry) { + dfs(entry); + std::reverse(rpoBlocks.begin(), rpoBlocks.end()); + } +} + +void GVNContext::dfs(BasicBlock *bb) { + if (!bb || visited.count(bb)) { + return; + } + + visited.insert(bb); + + // 访问所有后继基本块 + for (auto succ : bb->getSuccessors()) { + if (visited.find(succ) == visited.end()) { + dfs(succ); + } + } + + rpoBlocks.push_back(bb); +} + +Value *GVNContext::checkHashtable(Value *value) { + if (auto it = hashtable.find(value); it != hashtable.end()) { + return it->second; + } + + if (auto inst = dynamic_cast(value)) { + if (auto valueNumber = getValueNumber(inst)) { + hashtable[value] = valueNumber; + return valueNumber; + } + } + + hashtable[value] = value; + return value; +} + +Value *GVNContext::getValueNumber(Instruction *inst) { + if (auto binary = dynamic_cast(inst)) { + return getValueNumber(binary); + } else if (auto unary = dynamic_cast(inst)) { + return getValueNumber(unary); + } else if (auto gep = dynamic_cast(inst)) { + return getValueNumber(gep); + } else if (auto load = dynamic_cast(inst)) { + return getValueNumber(load); + } else if (auto call = dynamic_cast(inst)) { + return getValueNumber(call); + } + + return nullptr; +} + +Value *GVNContext::getValueNumber(BinaryInst *inst) { + auto lhs = checkHashtable(inst->getLhs()); + auto rhs = checkHashtable(inst->getRhs()); + + if (DEBUG) { + std::cout << " Checking binary instruction: " << inst->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } + + for (auto [key, value] : hashtable) { + if (auto binary = dynamic_cast(key)) { + auto binLhs = checkHashtable(binary->getLhs()); + auto binRhs = checkHashtable(binary->getRhs()); + + if (binary->getKind() == inst->getKind()) { + // 检查操作数是否匹配 + if ((lhs == binLhs && rhs == binRhs) || (inst->isCommutative() && lhs == binRhs && rhs == binLhs)) { + if (DEBUG) { + std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + } + return value; + } + } + } + } + + if (DEBUG) { + std::cout << " No equivalent binary instruction found" << std::endl; + } + return inst; +} + +Value *GVNContext::getValueNumber(UnaryInst *inst) { + auto operand = checkHashtable(inst->getOperand()); + + for (auto [key, value] : hashtable) { + if (auto unary = dynamic_cast(key)) { + auto unOperand = checkHashtable(unary->getOperand()); + + if (unary->getKind() == inst->getKind() && operand == unOperand) { + return value; + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(GetElementPtrInst *inst) { + auto ptr = checkHashtable(inst->getBasePointer()); + std::vector indices; + + // 使用正确的索引访问方法 + for (unsigned i = 0; i < inst->getNumIndices(); ++i) { + indices.push_back(checkHashtable(inst->getIndex(i))); + } + + for (auto [key, value] : hashtable) { + if (auto gep = dynamic_cast(key)) { + auto gepPtr = checkHashtable(gep->getBasePointer()); + + if (ptr == gepPtr && gep->getNumIndices() == inst->getNumIndices()) { + bool indicesMatch = true; + for (unsigned i = 0; i < inst->getNumIndices(); ++i) { + if (checkHashtable(gep->getIndex(i)) != indices[i]) { + indicesMatch = false; + break; + } + } + + if (indicesMatch && inst->getType() == gep->getType()) { + return value; + } + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(LoadInst *inst) { + auto ptr = checkHashtable(inst->getPointer()); + + for (auto [key, value] : hashtable) { + if (auto load = dynamic_cast(key)) { + auto loadPtr = checkHashtable(load->getPointer()); + + if (ptr == loadPtr && inst->getType() == load->getType()) { + return value; + } + } + } + + return inst; +} + +Value *GVNContext::getValueNumber(CallInst *inst) { + // 只为无副作用的函数调用进行GVN + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(inst->getCallee())) { + return nullptr; + } + + for (auto [key, value] : hashtable) { + if (auto call = dynamic_cast(key)) { + if (call->getCallee() == inst->getCallee() && call->getNumOperands() == inst->getNumOperands()) { + + bool argsMatch = true; + // 跳过第一个操作数(函数指针),从参数开始比较 + for (size_t i = 1; i < inst->getNumOperands(); ++i) { + if (checkHashtable(inst->getOperand(i)) != checkHashtable(call->getOperand(i))) { + argsMatch = false; + break; + } + } + + if (argsMatch) { + return value; + } + } + } + } + + return inst; +} + +void GVNContext::visitInstruction(Instruction *inst) { + // 跳过分支指令 + if (inst->isBranch()) { + if (DEBUG) { + std::cout << " Skipping branch instruction: " << inst->getName() << std::endl; + } + return; + } + + if (DEBUG) { + std::cout << " Visiting instruction: " << inst->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } + + auto value = checkHashtable(inst); + + if (inst != value) { + if (auto instValue = dynamic_cast(value)) { + if (canReplace(inst, instValue)) { + inst->replaceAllUsesWith(instValue); + needRemove.insert(inst); + + if (DEBUG) { + std::cout << " GVN: Replacing redundant instruction " << inst->getName() + << " with existing instruction " << instValue->getName() << std::endl; + } + } else { + if (DEBUG) { + std::cout << " Cannot replace instruction " << inst->getName() + << " with " << instValue->getName() << " (dominance check failed)" << std::endl; + } + } + } + } else { + if (DEBUG) { + std::cout << " Instruction " << inst->getName() << " is unique" << std::endl; + } + } +} + +bool GVNContext::canReplace(Instruction *original, Value *replacement) { + auto replInst = dynamic_cast(replacement); + if (!replInst) { + return true; // 替换为常量总是安全的 + } + + auto originalBB = original->getParent(); + auto replBB = replInst->getParent(); + + // 如果replacement是Call指令,需要特殊处理 + if (auto callInst = dynamic_cast(replInst)) { + if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { + // 对于有副作用的函数,只有在同一个基本块且相邻时才能替换 + if (originalBB != replBB) { + return false; + } + + // 检查指令顺序 + auto &insts = originalBB->getInstructions(); + auto origIt = + std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); + auto replIt = + std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); + + if (origIt == insts.end() || replIt == insts.end()) { + return false; + } + + return std::abs(std::distance(origIt, replIt)) == 1; + } + } + + // 简单的支配关系检查:如果在同一个基本块,检查指令顺序 + if (originalBB == replBB) { + auto &insts = originalBB->getInstructions(); + auto origIt = + std::find_if(insts.begin(), insts.end(), [original](const auto &ptr) { return ptr.get() == original; }); + auto replIt = + std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); + + // 替换指令必须在原指令之前 + return std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + } + + // 使用支配关系检查(如果支配树分析可用) + if (domTree) { + auto dominators = domTree->getDominators(originalBB); + if (dominators && dominators->count(replBB)) { + return true; + } + } + + return false; +} + +std::string GVNContext::getCanonicalExpression(Instruction *inst) { + std::ostringstream oss; + + if (auto binary = dynamic_cast(inst)) { + oss << "binary_" << static_cast(binary->getKind()) << "_"; + oss << checkHashtable(binary->getLhs()) << "_"; + oss << checkHashtable(binary->getRhs()); + } else if (auto unary = dynamic_cast(inst)) { + oss << "unary_" << static_cast(unary->getKind()) << "_"; + oss << checkHashtable(unary->getOperand()); + } else if (auto gep = dynamic_cast(inst)) { + oss << "gep_" << checkHashtable(gep->getBasePointer()); + for (unsigned i = 0; i < gep->getNumIndices(); ++i) { + oss << "_" << checkHashtable(gep->getIndex(i)); + } + } + + return oss.str(); +} + +} // namespace sysy diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 440ce0c..449508e 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -10,6 +10,7 @@ #include "DCE.h" #include "Mem2Reg.h" #include "Reg2Mem.h" +#include "GVN.h" #include "SCCP.h" #include "BuildCFG.h" #include "LargeArrayToGlobal.h" @@ -59,6 +60,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR // 注册优化遍 registerOptimizationPass(); registerOptimizationPass(); + + registerOptimizationPass(); registerOptimizationPass(); registerOptimizationPass(); @@ -129,6 +132,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR printPasses(); } + // 添加GVN优化遍 + this->clearPasses(); + this->addPass(&GVN::ID); + this->run(); + + if(DEBUG) { + std::cout << "=== IR After GVN Optimizations ===\n"; + printPasses(); + } + this->clearPasses(); this->addPass(&SCCP::ID); this->run();