From 969a78a08817501f72e95265c556b0590d520c49 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sun, 17 Aug 2025 14:37:27 +0800 Subject: [PATCH] =?UTF-8?q?[midend-GVN]segmentation=20fault=E6=98=AFGVN?= =?UTF-8?q?=E5=BC=95=E5=85=A5=E7=9A=84=E5=B7=B2=E4=BF=AE=E5=A4=8D=EF=BC=8C?= =?UTF-8?q?LICM=E4=BB=8D=E7=84=B6=E6=9C=89=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/midend/Pass/Optimize/GVN.cpp | 253 ++++++++++++++++++++++++++++--- src/midend/Pass/Pass.cpp | 14 +- 2 files changed, 243 insertions(+), 24 deletions(-) diff --git a/src/midend/Pass/Optimize/GVN.cpp b/src/midend/Pass/Optimize/GVN.cpp index 9f28609..a2f1c57 100644 --- a/src/midend/Pass/Optimize/GVN.cpp +++ b/src/midend/Pass/Optimize/GVN.cpp @@ -176,18 +176,35 @@ void GVNContext::dfs(BasicBlock *bb) { } Value *GVNContext::checkHashtable(Value *value) { + // 避免无限递归:如果已经在哈希表中,直接返回映射的值 if (auto it = hashtable.find(value); it != hashtable.end()) { + if (DEBUG >= 2) { + std::cout << " Found " << value->getName() << " in hashtable, mapped to " + << it->second->getName() << std::endl; + } return it->second; } + // 如果是指令,尝试获取其值编号 if (auto inst = dynamic_cast(value)) { if (auto valueNumber = getValueNumber(inst)) { - hashtable[value] = valueNumber; - return valueNumber; + // 如果找到了等价的值,建立映射关系 + if (valueNumber != inst) { + hashtable[value] = valueNumber; + if (DEBUG >= 2) { + std::cout << " Mapping " << value->getName() << " to equivalent value " + << valueNumber->getName() << std::endl; + } + return valueNumber; + } } } + // 没有找到等价值,将自己映射到自己 hashtable[value] = value; + if (DEBUG >= 2) { + std::cout << " Mapping " << value->getName() << " to itself (unique)" << std::endl; + } return value; } @@ -227,11 +244,73 @@ Value *GVNContext::getValueNumber(BinaryInst *inst) { if (binary->getKind() == inst->getKind()) { // 检查操作数是否匹配 - if ((lhs == binLhs && rhs == binRhs) || (inst->isCommutative() && lhs == binRhs && rhs == binLhs)) { - if (DEBUG) { - std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + bool operandsMatch = false; + if (lhs == binLhs && rhs == binRhs) { + operandsMatch = true; + } else if (inst->isCommutative() && lhs == binRhs && rhs == binLhs) { + operandsMatch = true; + } + + if (operandsMatch) { + // 检查支配关系,确保替换是安全的 + if (canReplace(inst, binary)) { + // 对于涉及load指令的情况,需要特别检查 + bool hasLoadOperands = (dynamic_cast(lhs) != nullptr) || + (dynamic_cast(rhs) != nullptr); + + if (hasLoadOperands) { + // 检查是否有任何load操作数之间有intervening store + bool hasIntervening = false; + + auto loadLhs = dynamic_cast(lhs); + auto loadRhs = dynamic_cast(rhs); + auto binLoadLhs = dynamic_cast(binLhs); + auto binLoadRhs = dynamic_cast(binRhs); + + if (loadLhs && binLoadLhs) { + if (hasInterveningStore(binLoadLhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { + hasIntervening = true; + } + } + + if (!hasIntervening && loadRhs && binLoadRhs) { + if (hasInterveningStore(binLoadRhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { + hasIntervening = true; + } + } + + // 对于交换操作数的情况,也需要检查 + if (!hasIntervening && inst->isCommutative()) { + if (loadLhs && binLoadRhs) { + if (hasInterveningStore(binLoadRhs, loadLhs, checkHashtable(loadLhs->getPointer()))) { + hasIntervening = true; + } + } + + if (!hasIntervening && loadRhs && binLoadLhs) { + if (hasInterveningStore(binLoadLhs, loadRhs, checkHashtable(loadRhs->getPointer()))) { + hasIntervening = true; + } + } + } + + if (hasIntervening) { + if (DEBUG) { + std::cout << " Found equivalent binary but load operands have intervening store, skipping" << std::endl; + } + continue; + } + } + + if (DEBUG) { + std::cout << " Found equivalent binary instruction: " << binary->getName() << std::endl; + } + return value; + } else { + if (DEBUG) { + std::cout << " Found equivalent binary but dominance check failed: " << binary->getName() << std::endl; + } } - return value; } } } @@ -294,26 +373,47 @@ Value *GVNContext::getValueNumber(GetElementPtrInst *inst) { Value *GVNContext::getValueNumber(LoadInst *inst) { auto ptr = checkHashtable(inst->getPointer()); + if (DEBUG) { + std::cout << " Checking load instruction: " << inst->getName() + << " from address: " << ptr->getName() << std::endl; + } + for (auto [key, value] : hashtable) { if (auto load = dynamic_cast(key)) { auto loadPtr = checkHashtable(load->getPointer()); if (ptr == loadPtr && inst->getType() == load->getType()) { - // 检查两次load之间是否有store指令修改了内存 + if (DEBUG) { + std::cout << " Found potential equivalent load: " << load->getName() << std::endl; + } + + // 检查支配关系:load 必须支配 inst + if (!canReplace(inst, load)) { + if (DEBUG) { + std::cout << " Equivalent load does not dominate current load, skipping" << std::endl; + } + continue; + } + + // 检查是否有中间的store指令影响 if (hasInterveningStore(load, inst, ptr)) { if (DEBUG) { std::cout << " Found intervening store, cannot reuse load value" << std::endl; } continue; // 如果有store指令,不能复用之前的load } + if (DEBUG) { - std::cout << " No intervening store found, can reuse load value" << std::endl; + std::cout << " Can safely reuse load value from: " << load->getName() << std::endl; } return value; } } } + if (DEBUG) { + std::cout << " No equivalent load found" << std::endl; + } return inst; } @@ -427,8 +527,21 @@ bool GVNContext::canReplace(Instruction *original, Value *replacement) { auto replIt = std::find_if(insts.begin(), insts.end(), [replInst](const auto &ptr) { return ptr.get() == replInst; }); - // 替换指令必须在原指令之前 - return std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + if (origIt == insts.end() || replIt == insts.end()) { + if (DEBUG) { + std::cout << " Cannot find instructions in basic block for dominance check" << std::endl; + } + return false; + } + + // 替换指令必须在原指令之前(支配原指令) + bool canRepl = std::distance(insts.begin(), replIt) < std::distance(insts.begin(), origIt); + if (DEBUG) { + std::cout << " Same block dominance check: " << (canRepl ? "PASS" : "FAIL") + << " (repl at " << std::distance(insts.begin(), replIt) + << ", orig at " << std::distance(insts.begin(), origIt) << ")" << std::endl; + } + return canRepl; } // 使用支配关系检查(如果支配树分析可用) @@ -450,6 +563,9 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, if (earlierBB != laterBB) { // 跨基本块的情况:为了安全起见,暂时认为有intervening store // 这是保守的做法,可能会错过一些优化机会,但确保正确性 + if (DEBUG) { + std::cout << " Cross-block load optimization: conservatively assuming intervening store" << std::endl; + } return true; } @@ -463,11 +579,28 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, [laterLoad](const auto &ptr) { return ptr.get() == laterLoad; }); if (earlierIt == insts.end() || laterIt == insts.end()) { + if (DEBUG) { + std::cout << " Could not find load instructions in basic block" << std::endl; + } return true; // 找不到指令,保守返回true } + // 确定实际的执行顺序(哪个load在前,哪个在后) + auto firstIt = earlierIt; + auto secondIt = laterIt; + + if (std::distance(insts.begin(), earlierIt) > std::distance(insts.begin(), laterIt)) { + // 如果"earlier"实际上在"later"之后,交换它们 + firstIt = laterIt; + secondIt = earlierIt; + if (DEBUG) { + std::cout << " Swapped load order: " << laterLoad->getName() + << " actually comes before " << earlierLoad->getName() << std::endl; + } + } + // 检查两个load之间的所有指令 - for (auto it = std::next(earlierIt); it != laterIt; ++it) { + for (auto it = std::next(firstIt); it != secondIt; ++it) { auto inst = it->get(); // 检查是否是store指令 @@ -477,27 +610,34 @@ bool GVNContext::hasInterveningStore(LoadInst* earlierLoad, LoadInst* laterLoad, // 如果store的目标地址与load的地址相同,说明内存被修改了 if (storePtr == ptr) { if (DEBUG) { - std::cout << " Found intervening store to same address, cannot optimize load" << std::endl; + std::cout << " Found intervening store to same address: " << storeInst->getName() << std::endl; } return true; } + + // TODO: 这里还应该检查别名分析,看store是否可能影响load的地址 + // 为了简化,现在只检查精确匹配 } - // TODO: 还需要检查函数调用是否可能修改内存 - // 对于全局变量,任何函数调用都可能修改它 + // 检查函数调用是否可能修改内存 if (auto callInst = dynamic_cast(inst)) { if (sideEffectAnalysis && !sideEffectAnalysis->isPureFunction(callInst->getCallee())) { // 如果是有副作用的函数调用,且load的是全局变量,则可能被修改 if (auto globalPtr = dynamic_cast(ptr)) { if (DEBUG) { - std::cout << " Found function call that may modify global variable, cannot optimize load" << std::endl; + std::cout << " Found function call that may modify global variable: " << callInst->getName() << std::endl; } return true; } + // TODO: 这里还应该检查函数是否可能修改通过指针参数传递的内存 } } } + if (DEBUG) { + std::cout << " No intervening store found between loads" << std::endl; + } + return false; // 没有找到会修改内存的指令 } @@ -508,9 +648,11 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { std::cout << " Invalidating loads affected by store to address" << std::endl; } - // 查找hashtable中所有可能被这个store影响的load指令 + // 查找hashtable中所有可能被这个store影响的指令 std::vector toRemove; + std::set invalidatedLoads; + // 第一步:找到所有被直接影响的load指令 for (auto& [key, value] : hashtable) { if (auto loadInst = dynamic_cast(key)) { auto loadPtr = checkHashtable(loadInst->getPointer()); @@ -518,6 +660,7 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { // 如果load的地址与store的地址相同,则需要从hashtable中移除 if (loadPtr == storePtr) { toRemove.push_back(key); + invalidatedLoads.insert(loadInst); if (DEBUG) { std::cout << " Invalidating load from same address: " << loadInst->getName() << std::endl; } @@ -525,10 +668,86 @@ void GVNContext::invalidateLoadsAffectedByStore(StoreInst* storeInst) { } } - // 从hashtable中移除被影响的load指令 + // 第二步:找到所有依赖被失效load的指令(如binary指令) + bool foundMore = true; + while (foundMore) { + foundMore = false; + std::vector additionalToRemove; + + for (auto& [key, value] : hashtable) { + // 跳过已经标记要删除的指令 + if (std::find(toRemove.begin(), toRemove.end(), key) != toRemove.end()) { + continue; + } + + bool shouldInvalidate = false; + + // 检查binary指令的操作数 + if (auto binaryInst = dynamic_cast(key)) { + auto lhs = checkHashtable(binaryInst->getLhs()); + auto rhs = checkHashtable(binaryInst->getRhs()); + + if (invalidatedLoads.count(lhs) || invalidatedLoads.count(rhs)) { + shouldInvalidate = true; + if (DEBUG) { + std::cout << " Invalidating binary instruction due to invalidated operand: " + << binaryInst->getName() << std::endl; + } + } + } + // 检查unary指令的操作数 + else if (auto unaryInst = dynamic_cast(key)) { + auto operand = checkHashtable(unaryInst->getOperand()); + if (invalidatedLoads.count(operand)) { + shouldInvalidate = true; + if (DEBUG) { + std::cout << " Invalidating unary instruction due to invalidated operand: " + << unaryInst->getName() << std::endl; + } + } + } + // 检查GEP指令的操作数 + else if (auto gepInst = dynamic_cast(key)) { + auto basePtr = checkHashtable(gepInst->getBasePointer()); + if (invalidatedLoads.count(basePtr)) { + shouldInvalidate = true; + } else { + // 检查索引操作数 + for (unsigned i = 0; i < gepInst->getNumIndices(); ++i) { + if (invalidatedLoads.count(checkHashtable(gepInst->getIndex(i)))) { + shouldInvalidate = true; + break; + } + } + } + if (shouldInvalidate && DEBUG) { + std::cout << " Invalidating GEP instruction due to invalidated operand: " + << gepInst->getName() << std::endl; + } + } + + if (shouldInvalidate) { + additionalToRemove.push_back(key); + if (auto inst = dynamic_cast(key)) { + invalidatedLoads.insert(inst); + } + foundMore = true; + } + } + + // 将新找到的失效指令加入移除列表 + toRemove.insert(toRemove.end(), additionalToRemove.begin(), additionalToRemove.end()); + } + + // 从hashtable中移除所有被影响的指令 for (auto key : toRemove) { hashtable.erase(key); } + + if (DEBUG && toRemove.size() > invalidatedLoads.size()) { + std::cout << " Total invalidated instructions: " << toRemove.size() + << " (including " << (toRemove.size() - invalidatedLoads.size()) << " dependent instructions)" << std::endl; + } } std::string GVNContext::getCanonicalExpression(Instruction *inst) { diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 09de26e..69f9790 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -161,14 +161,14 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR } - this->clearPasses(); - this->addPass(&LICM::ID); - this->run(); + // this->clearPasses(); + // this->addPass(&LICM::ID); + // this->run(); - if(DEBUG) { - std::cout << "=== IR After LICM ===\n"; - printPasses(); - } + // if(DEBUG) { + // std::cout << "=== IR After LICM ===\n"; + // printPasses(); + // } this->clearPasses(); this->addPass(&LoopStrengthReduction::ID);