[midend-GVN]初步构建GVN,能够优化部分CSE无法处理的子表达式但是有错误需要debug。

This commit is contained in:
rain2133
2025-08-16 15:38:41 +08:00
parent fa33bf5134
commit 467f2f6b24
4 changed files with 466 additions and 2 deletions

View File

@@ -15,6 +15,7 @@ add_library(midend_lib STATIC
Pass/Optimize/DCE.cpp
Pass/Optimize/Mem2Reg.cpp
Pass/Optimize/Reg2Mem.cpp
Pass/Optimize/GVN.cpp
Pass/Optimize/SysYIRCFGOpt.cpp
Pass/Optimize/SCCP.cpp
Pass/Optimize/LoopNormalization.cpp

View File

@@ -0,0 +1,450 @@
#include "GVN.h"
#include "Dom.h"
#include "SysYIROptUtils.h"
#include <algorithm>
#include <cassert>
#include <iostream>
extern int DEBUG;
namespace sysy {
// Static ID of the GVN pass. The variable is initialised to its own address;
// the pass is referred to elsewhere by `&GVN::ID`.
void *GVN::ID = (void *)&GVN::ID;
// ======================================================================
// GVN pass implementation
// ======================================================================

/// Entry point of the pass for a single function.
///
/// Functions without basic blocks are skipped. Otherwise a fresh GVNContext
/// is created and run over the function.
///
/// @param func Function to optimise.
/// @param AM   Analysis manager handed through to the context.
/// @return true when the context reported that the function was modified.
bool GVN::runOnFunction(Function *func, AnalysisManager &AM) {
  // Nothing to do for a function that has no body.
  if (func->getBasicBlocks().empty()) {
    return false;
  }

  if (DEBUG) {
    std::cout << "\n=== Running GVN on function: " << func->getName() << " ===" << std::endl;
  }

  bool modified = false;
  GVNContext gvn;
  gvn.run(func, &AM, modified);

  if (DEBUG) {
    std::cout << "GVN: Function " << func->getName()
              << (modified ? " was modified" : " was not modified") << std::endl;
    std::cout << "=== GVN completed for function: " << func->getName() << " ===" << std::endl;
  }

  return modified;
}
/// Declares which analyses this pass needs and which it invalidates.
///
/// Required analyses:
///   1. Dominator tree - used to check that a replacement dominates the
///      instruction it replaces.
///   2. Side-effect analysis - used to decide whether a call may take part
///      in value numbering.
///
/// Nothing is added to `analysisInvalidations`: the pass only removes
/// redundant computations and does not restructure the CFG.
void GVN::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
  void *const dominatorTreePass = &DominatorTreeAnalysisPass::ID;
  void *const sideEffectPass = &SysYSideEffectAnalysisPass::ID;

  analysisDependencies.insert(dominatorTreePass);
  analysisDependencies.insert(sideEffectPass);

  // Intentionally left untouched: no analysis is declared invalid.
  (void)analysisInvalidations;

  if (DEBUG) {
    std::cout << "GVN: Declared analysis dependencies (DominatorTree, SideEffectAnalysis)" << std::endl;
  }
}
// ======================================================================
// GVNContext implementation
// ======================================================================

/// Runs value numbering over one function.
///
/// Steps:
///   1. Fetch the dominator tree and side-effect analysis when an analysis
///      manager is supplied.
///   2. Reset all per-run state.
///   3. Compute a reverse post-order of the reachable blocks.
///   4. Visit every instruction in that order; redundant instructions have
///      their uses redirected and are queued in `needRemove`.
///   5. Delete the queued instructions.
///
/// @param func    Function to process.
/// @param AM      Analysis manager; may be null, in which case no analysis
///                results are fetched here.
/// @param changed Set to true when at least one instruction is deleted;
///                never reset to false by this function.
void GVNContext::run(Function *func, AnalysisManager *AM, bool &changed) {
  if (DEBUG) {
    std::cout << " Starting GVN analysis for function: " << func->getName() << std::endl;
  }

  // Fetch analysis results.
  if (AM) {
    domTree = AM->getAnalysisResult<DominatorTree, DominatorTreeAnalysisPass>(func);
    sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();

    if (DEBUG) {
      std::cout << (domTree ? " GVN: Using dominator tree analysis"
                            : " GVN: Warning - dominator tree analysis not available")
                << std::endl;
      std::cout << (sideEffectAnalysis ? " GVN: Using side effect analysis"
                                       : " GVN: Warning - side effect analysis not available")
                << std::endl;
    }
  }

  // Reset per-run state.
  hashtable.clear();
  visited.clear();
  rpoBlocks.clear();
  needRemove.clear();

  // Reverse post-order of the CFG.
  computeRPO(func);
  if (DEBUG) {
    std::cout << " Computed RPO with " << rpoBlocks.size() << " blocks" << std::endl;
  }

  // Value-number every instruction, block by block, in reverse post-order.
  int blockIndex = 0;
  for (auto block : rpoBlocks) {
    ++blockIndex;
    if (DEBUG) {
      std::cout << " Processing block " << blockIndex << "/" << rpoBlocks.size()
                << ": " << block->getName() << std::endl;
    }

    int instIndex = 0;
    for (auto &slot : block->getInstructions()) {
      ++instIndex;
      if (DEBUG) {
        std::cout << " Processing instruction " << instIndex
                  << ": " << slot->getName() << std::endl;
      }
      visitInstruction(slot.get());
    }
  }

  if (DEBUG) {
    std::cout << " Found " << needRemove.size() << " redundant instructions to remove" << std::endl;
  }

  // Delete the redundant instructions collected above.
  int removedIndex = 0;
  for (auto victim : needRemove) {
    ++removedIndex;
    if (DEBUG) {
      std::cout << " Removing redundant instruction " << removedIndex
                << "/" << needRemove.size() << ": " << victim->getName() << std::endl;
    }
    // Detach any remaining users before deletion.
    // NOTE(review): visitInstruction has already redirected the uses to the
    // replacement, so this is expected to find nothing; if a use did remain
    // it would now point at nullptr - confirm this is intended.
    victim->replaceAllUsesWith(nullptr);
    // Alternative kept from the original: bb->removeInst(inst);
    SysYIROptUtils::usedelete(victim);
    changed = true;
  }

  if (DEBUG) {
    std::cout << " GVN analysis completed for function: " << func->getName() << std::endl;
    std::cout << " Total instructions analyzed: " << hashtable.size() << std::endl;
    std::cout << " Instructions eliminated: " << needRemove.size() << std::endl;
  }
}
/// Fills `rpoBlocks` with the blocks reachable from the entry block of
/// `func`, in reverse post-order. `rpoBlocks` and `visited` are reset first;
/// when the function has no entry block `rpoBlocks` stays empty.
void GVNContext::computeRPO(Function *func) {
  rpoBlocks.clear();
  visited.clear();

  auto entryBlock = func->getEntryBlock();
  if (!entryBlock) {
    return;
  }

  // dfs() records blocks in post-order; reversing yields reverse post-order.
  dfs(entryBlock);
  std::reverse(rpoBlocks.begin(), rpoBlocks.end());
}
/// Depth-first walk over the CFG starting at `bb`. Each block is appended to
/// `rpoBlocks` after all of its successors have been handled (post-order).
/// Null blocks and blocks already in `visited` are ignored.
void GVNContext::dfs(BasicBlock *bb) {
  // insert() reports whether the block was newly added, so one call both
  // tests for and records the visit.
  if (bb == nullptr || !visited.insert(bb).second) {
    return;
  }

  // Recurse into every successor; already-visited ones return immediately.
  for (auto next : bb->getSuccessors()) {
    dfs(next);
  }

  rpoBlocks.push_back(bb);
}
/// Returns the value number (representative value) of `value`, computing and
/// caching it in `hashtable` on first use.
///
/// Instructions are numbered through getValueNumber(); when that yields
/// nullptr, or when `value` is not an instruction, the value is its own
/// representative.
Value *GVNContext::checkHashtable(Value *value) {
  const auto cached = hashtable.find(value);
  if (cached != hashtable.end()) {
    return cached->second;
  }

  Value *representative = nullptr;
  if (auto asInst = dynamic_cast<Instruction *>(value)) {
    representative = getValueNumber(asInst);
  }
  if (representative == nullptr) {
    representative = value;
  }

  hashtable[value] = representative;
  return representative;
}
/// Dispatches to the value-numbering routine matching the dynamic type of
/// `inst`. Binary, unary, GEP, load and call instructions are handled, tested
/// in that order; any other instruction yields nullptr.
Value *GVNContext::getValueNumber(Instruction *inst) {
  if (auto asBinary = dynamic_cast<BinaryInst *>(inst)) {
    return getValueNumber(asBinary);
  }
  if (auto asUnary = dynamic_cast<UnaryInst *>(inst)) {
    return getValueNumber(asUnary);
  }
  if (auto asGep = dynamic_cast<GetElementPtrInst *>(inst)) {
    return getValueNumber(asGep);
  }
  if (auto asLoad = dynamic_cast<LoadInst *>(inst)) {
    return getValueNumber(asLoad);
  }
  if (auto asCall = dynamic_cast<CallInst *>(inst)) {
    return getValueNumber(asCall);
  }
  return nullptr;
}
/// Value-numbers a binary instruction.
///
/// Scans `hashtable` for an already-numbered binary instruction of the same
/// kind whose operand value numbers match, also in swapped order when `inst`
/// is commutative. Returns that entry's representative, or `inst` itself
/// when no match exists.
Value *GVNContext::getValueNumber(BinaryInst *inst) {
  auto lhs = checkHashtable(inst->getLhs());
  auto rhs = checkHashtable(inst->getRhs());

  if (DEBUG) {
    std::cout << " Checking binary instruction: " << inst->getName()
              << " (kind: " << static_cast<int>(inst->getKind()) << ")" << std::endl;
  }

  for (auto [key, value] : hashtable) {
    auto candidate = dynamic_cast<BinaryInst *>(key);
    if (!candidate) {
      continue;
    }

    auto candLhs = checkHashtable(candidate->getLhs());
    auto candRhs = checkHashtable(candidate->getRhs());
    if (candidate->getKind() != inst->getKind()) {
      continue;
    }

    // Operands must agree position by position, or crosswise for
    // commutative operations.
    const bool sameOrder = (lhs == candLhs && rhs == candRhs);
    const bool operandsMatch =
        sameOrder || (inst->isCommutative() && lhs == candRhs && rhs == candLhs);
    if (!operandsMatch) {
      continue;
    }

    if (DEBUG) {
      std::cout << " Found equivalent binary instruction: " << candidate->getName() << std::endl;
    }
    return value;
  }

  if (DEBUG) {
    std::cout << " No equivalent binary instruction found" << std::endl;
  }
  return inst;
}
/// Value-numbers a unary instruction: returns the representative of an
/// already-numbered unary instruction with the same kind and the same operand
/// value number, or `inst` itself when there is none.
Value *GVNContext::getValueNumber(UnaryInst *inst) {
  auto operand = checkHashtable(inst->getOperand());

  for (auto [key, value] : hashtable) {
    auto candidate = dynamic_cast<UnaryInst *>(key);
    if (!candidate) {
      continue;
    }

    auto candOperand = checkHashtable(candidate->getOperand());
    const bool equivalent = candidate->getKind() == inst->getKind() && operand == candOperand;
    if (equivalent) {
      return value;
    }
  }

  return inst;
}
/// Value-numbers a GEP instruction.
///
/// A previously numbered GEP is equivalent when its base pointer has the same
/// value number, it has the same number of indices, every index has the same
/// value number, and the result types are equal. Returns the representative
/// of the first such entry, or `inst` itself when there is none.
Value *GVNContext::getValueNumber(GetElementPtrInst *inst) {
  auto base = checkHashtable(inst->getBasePointer());

  // Value numbers of this instruction's indices, computed once up front.
  std::vector<Value *> indexNumbers;
  for (unsigned i = 0; i < inst->getNumIndices(); ++i) {
    indexNumbers.push_back(checkHashtable(inst->getIndex(i)));
  }

  for (auto [key, value] : hashtable) {
    auto candidate = dynamic_cast<GetElementPtrInst *>(key);
    if (!candidate) {
      continue;
    }

    auto candBase = checkHashtable(candidate->getBasePointer());
    if (base != candBase || candidate->getNumIndices() != inst->getNumIndices()) {
      continue;
    }

    // Find the first index whose value number differs.
    unsigned pos = 0;
    while (pos < inst->getNumIndices() && checkHashtable(candidate->getIndex(pos)) == indexNumbers[pos]) {
      ++pos;
    }
    const bool allIndicesEqual = (pos == inst->getNumIndices());

    if (allIndicesEqual && inst->getType() == candidate->getType()) {
      return value;
    }
  }

  return inst;
}
/// Value-numbers a load instruction.
///
/// Two loads of the same address only yield the same value when nothing wrote
/// to memory in between. No memory-dependence information is available here,
/// so an earlier load is reused only when all of the following hold:
///   - its pointer has the same value number and its type is equal;
///   - its representative is an instruction in the same basic block as
///     `inst` and appears before `inst`;
///   - every instruction between the two is of a kind known not to write
///     memory (binary, unary, GEP, load). Anything else - stores, calls,
///     unknown instruction kinds - is conservatively treated as a clobber.
///
/// Loads in different basic blocks are never merged.
///
/// @return the representative of an equivalent earlier load, or `inst` itself.
Value *GVNContext::getValueNumber(LoadInst *inst) {
  auto ptr = checkHashtable(inst->getPointer());

  auto block = inst->getParent();
  if (!block) {
    return inst;
  }

  for (auto [key, value] : hashtable) {
    auto load = dynamic_cast<LoadInst *>(key);
    if (!load) {
      continue;
    }

    auto loadPtr = checkHashtable(load->getPointer());
    if (ptr != loadPtr || inst->getType() != load->getType()) {
      continue;
    }

    // The value we would hand back is the representative, so that is the
    // instruction whose position relative to `inst` matters.
    auto earlier = dynamic_cast<Instruction *>(value);
    if (!earlier || earlier == inst || earlier->getParent() != block) {
      continue;
    }

    // Walk the block once: `earlier` must come first, and nothing that may
    // write memory may sit between it and `inst`.
    bool afterEarlier = false;
    bool safeToReuse = false;
    for (auto &slot : block->getInstructions()) {
      auto current = slot.get();
      if (current == inst) {
        safeToReuse = afterEarlier;
        break;
      }
      if (current == earlier) {
        afterEarlier = true;
        continue;
      }
      if (!afterEarlier) {
        continue;
      }
      const bool knownNonWriting = dynamic_cast<BinaryInst *>(current) ||
                                   dynamic_cast<UnaryInst *>(current) ||
                                   dynamic_cast<GetElementPtrInst *>(current) ||
                                   dynamic_cast<LoadInst *>(current);
      if (!knownNonWriting) {
        break; // possible memory write: the earlier load may be stale
      }
    }

    if (safeToReuse) {
      return value;
    }
  }

  return inst;
}
/// Value-numbers a call instruction.
///
/// Only calls whose callee the side-effect analysis reports as pure are
/// numbered. When the analysis is unavailable purity cannot be proven, so the
/// call is not numbered either; merging a call that has side effects would
/// drop those effects.
///
/// A previously numbered call is equivalent when it has the same callee, the
/// same operand count, and every argument has the same value number.
///
/// NOTE(review): whether isPureFunction() also excludes functions that read
/// mutable memory is not visible from here - confirm, since such calls are not
/// safe to merge across stores.
///
/// @return nullptr when the call must not be numbered, the representative of
///         an equivalent earlier call when one exists, otherwise `inst`.
Value *GVNContext::getValueNumber(CallInst *inst) {
  // Be conservative: without analysis results every call is assumed impure.
  if (!sideEffectAnalysis || !sideEffectAnalysis->isPureFunction(inst->getCallee())) {
    return nullptr;
  }

  for (auto [key, value] : hashtable) {
    auto call = dynamic_cast<CallInst *>(key);
    if (!call) {
      continue;
    }
    if (call->getCallee() != inst->getCallee() || call->getNumOperands() != inst->getNumOperands()) {
      continue;
    }

    bool argsMatch = true;
    // Operand 0 is skipped (the callee); comparison starts at the arguments.
    for (size_t i = 1; i < inst->getNumOperands(); ++i) {
      if (checkHashtable(inst->getOperand(i)) != checkHashtable(call->getOperand(i))) {
        argsMatch = false;
        break;
      }
    }
    if (argsMatch) {
      return value;
    }
  }

  return inst;
}
/// Processes one instruction during the GVN sweep.
///
/// Branch instructions are skipped. For any other instruction the value
/// number is looked up; when it is a different instruction and canReplace()
/// allows it, all uses of `inst` are redirected to that instruction and
/// `inst` is queued in `needRemove` for later deletion.
void GVNContext::visitInstruction(Instruction *inst) {
  // Branch instructions are not value-numbered.
  if (inst->isBranch()) {
    if (DEBUG) {
      std::cout << " Skipping branch instruction: " << inst->getName() << std::endl;
    }
    return;
  }

  if (DEBUG) {
    std::cout << " Visiting instruction: " << inst->getName()
              << " (kind: " << static_cast<int>(inst->getKind()) << ")" << std::endl;
  }

  auto representative = checkHashtable(inst);

  // The instruction is its own representative: nothing to eliminate.
  if (representative == inst) {
    if (DEBUG) {
      std::cout << " Instruction " << inst->getName() << " is unique" << std::endl;
    }
    return;
  }

  // Only instruction representatives are handled here.
  auto replacement = dynamic_cast<Instruction *>(representative);
  if (!replacement) {
    return;
  }

  if (!canReplace(inst, replacement)) {
    if (DEBUG) {
      std::cout << " Cannot replace instruction " << inst->getName()
                << " with " << replacement->getName() << " (dominance check failed)" << std::endl;
    }
    return;
  }

  inst->replaceAllUsesWith(replacement);
  needRemove.insert(inst);
  if (DEBUG) {
    std::cout << " GVN: Replacing redundant instruction " << inst->getName()
              << " with existing instruction " << replacement->getName() << std::endl;
  }
}
/// Decides whether every use of `original` may be redirected to
/// `replacement`.
///
/// Rules, in order:
///   - A non-instruction replacement is always accepted.
///   - A call replacement is rejected unless the side-effect analysis proves
///     its callee pure. A call with side effects must still execute, so its
///     result can never stand in for another call; without analysis results
///     purity cannot be proven.
///   - In the same basic block, the replacement must appear strictly before
///     `original`.
///   - Across blocks, the replacement's block must be in the dominator set of
///     `original`'s block; without a dominator tree the answer is false.
///
/// @return true when the replacement is considered safe.
bool GVNContext::canReplace(Instruction *original, Value *replacement) {
  auto replInst = dynamic_cast<Instruction *>(replacement);
  if (!replInst) {
    return true; // replacing with a non-instruction value is always allowed
  }

  // Calls that are not provably pure are never merged.
  if (auto callInst = dynamic_cast<CallInst *>(replInst)) {
    if (!sideEffectAnalysis || !sideEffectAnalysis->isPureFunction(callInst->getCallee())) {
      return false;
    }
  }

  auto originalBB = original->getParent();
  auto replBB = replInst->getParent();

  // Same block: whichever of the two is met first decides. Testing `original`
  // first also rejects the degenerate case replInst == original.
  if (originalBB == replBB) {
    for (auto &slot : originalBB->getInstructions()) {
      if (slot.get() == original) {
        return false; // replacement does not precede the original
      }
      if (slot.get() == replInst) {
        return true;
      }
    }
    return false; // neither instruction found in its parent block
  }

  // Different blocks: rely on the dominator tree when it is available.
  if (domTree) {
    auto dominators = domTree->getDominators(originalBB);
    if (dominators && dominators->count(replBB)) {
      return true;
    }
  }
  return false;
}
/// Builds a textual key for `inst` from its opcode kind and the value numbers
/// of its operands (streamed as pointers).
///
/// Formats:
///   binary: "binary_<kind>_<lhs>_<rhs>"
///   unary:  "unary_<kind>_<operand>"
///   GEP:    "gep_<base>" followed by "_<index>" for each index
/// Any other instruction yields an empty string.
std::string GVNContext::getCanonicalExpression(Instruction *inst) {
  std::ostringstream text;

  if (auto asBinary = dynamic_cast<BinaryInst *>(inst)) {
    text << "binary_" << static_cast<int>(asBinary->getKind()) << "_";
    // Separate statements keep the operand lookups in left-to-right order.
    text << checkHashtable(asBinary->getLhs()) << "_";
    text << checkHashtable(asBinary->getRhs());
    return text.str();
  }

  if (auto asUnary = dynamic_cast<UnaryInst *>(inst)) {
    text << "unary_" << static_cast<int>(asUnary->getKind()) << "_";
    text << checkHashtable(asUnary->getOperand());
    return text.str();
  }

  if (auto asGep = dynamic_cast<GetElementPtrInst *>(inst)) {
    text << "gep_" << checkHashtable(asGep->getBasePointer());
    for (unsigned i = 0; i < asGep->getNumIndices(); ++i) {
      text << "_" << checkHashtable(asGep->getIndex(i));
    }
    return text.str();
  }

  return text.str();
}
} // namespace sysy

View File

@@ -10,6 +10,7 @@
#include "DCE.h"
#include "Mem2Reg.h"
#include "Reg2Mem.h"
#include "GVN.h"
#include "SCCP.h"
#include "BuildCFG.h"
#include "LargeArrayToGlobal.h"
@@ -60,6 +61,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
registerOptimizationPass<BuildCFG>();
registerOptimizationPass<LargeArrayToGlobalPass>();
registerOptimizationPass<GVN>();
registerOptimizationPass<SysYDelInstAfterBrPass>();
registerOptimizationPass<SysYDelNoPreBLockPass>();
registerOptimizationPass<SysYBlockMergePass>();
@@ -129,6 +132,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
printPasses();
}
// 添加GVN优化遍
this->clearPasses();
this->addPass(&GVN::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After GVN Optimizations ===\n";
printPasses();
}
this->clearPasses();
this->addPass(&SCCP::ID);
this->run();