From a5318a2c5c30a81e562e0d8f037d9c4d5deb38ba Mon Sep 17 00:00:00 2001 From: CGH0S7 <776459475@qq.com> Date: Thu, 31 Jul 2025 20:46:35 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=BA=E4=B8=AD=E7=AB=AF=E5=8A=A0=E5=85=A5?= =?UTF-8?q?=E5=B8=B8=E9=87=8F=E4=BC=A0=E6=92=ADPass?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../midend/Pass/Optimize/ConstPropagation.h | 14 + src/midend/CMakeLists.txt | 1 + src/midend/Pass/Optimize/ConstPropagation.cpp | 241 ++++++++++++++++++ src/midend/Pass/Pass.cpp | 2 + 4 files changed, 258 insertions(+) create mode 100644 src/include/midend/Pass/Optimize/ConstPropagation.h create mode 100644 src/midend/Pass/Optimize/ConstPropagation.cpp diff --git a/src/include/midend/Pass/Optimize/ConstPropagation.h b/src/include/midend/Pass/Optimize/ConstPropagation.h new file mode 100644 index 0000000..605bc20 --- /dev/null +++ b/src/include/midend/Pass/Optimize/ConstPropagation.h @@ -0,0 +1,14 @@ +#pragma once + +#include "Pass.h" + +namespace sysy { + +class ConstPropagation : public OptimizationPass { +public: + ConstPropagation() : OptimizationPass("ConstPropagation", Granularity::Function) {} + bool runOnFunction(Function *F, AnalysisManager& AM) override; + static char ID; +}; + +} // namespace sysy \ No newline at end of file diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index 4830893..07253eb 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -10,6 +10,7 @@ add_library(midend_lib STATIC Pass/Optimize/Mem2Reg.cpp Pass/Optimize/Reg2Mem.cpp Pass/Optimize/SysYIRCFGOpt.cpp + Pass/Optimize/ConstPropagation.cpp ) # 包含中端模块所需的头文件路径 diff --git a/src/midend/Pass/Optimize/ConstPropagation.cpp b/src/midend/Pass/Optimize/ConstPropagation.cpp new file mode 100644 index 0000000..86f65b9 --- /dev/null +++ b/src/midend/Pass/Optimize/ConstPropagation.cpp @@ -0,0 +1,241 @@ +#include "Pass/Optimize/ConstPropagation.h" +#include "IR.h" +#include "Pass.h" +#include +#include + +namespace sysy { + +char ConstPropagation::ID = 0; + +bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) { + bool changed = false; + bool localChanged = true; + + while (localChanged) { + localChanged = false; + + for (auto &bb : func->getBasicBlocks()) { + for (auto instIter = bb->getInstructions().begin(); + instIter != bb->getInstructions().end();) { + auto &inst = *instIter; + bool shouldAdvanceIter = true; + + // 处理二元运算指令 + if (auto *binaryInst = dynamic_cast(inst.get())) { + auto *lhs = binaryInst->getLhs(); + auto *rhs = binaryInst->getRhs(); + + auto *lhsConst = dynamic_cast(lhs); + auto *rhsConst = dynamic_cast(rhs); + + if (lhsConst && rhsConst) { + ConstantValue *newConst = nullptr; + + try { + if (lhs->isInt() && rhs->isInt()) { + int l = lhsConst->getInt(); + int r = rhsConst->getInt(); + int result; + bool validOperation = true; + + switch (binaryInst->getKind()) { + case Instruction::kAdd: + // 检查加法溢出 + if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) { + validOperation = false; + } else { + result = l + r; + } + break; + case Instruction::kSub: + // 检查减法溢出 + if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) { + validOperation = false; + } else { + result = l - r; + } + break; + case Instruction::kMul: + // 检查乘法溢出 + if (l != 0 && r != 0 && + (std::abs(l) > INT_MAX / std::abs(r))) { + validOperation = false; + } else { + result = l * r; + } + break; + case Instruction::kDiv: + if (r == 0) { + validOperation = false; + } else { + result = l / r; + } + break; + case Instruction::kRem: + if (r == 0) { + validOperation = false; + } else { + result = l % r; + } + break; + case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break; + case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break; + case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break; + case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break; + case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break; + case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break; + case Instruction::kAnd: result = (l && r) ? 1 : 0; break; + case Instruction::kOr: result = (l || r) ? 1 : 0; break; + default: + validOperation = false; + } + + if (validOperation) { + if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd || + binaryInst->getKind() == Instruction::kOr) { + newConst = ConstantInteger::get(Type::getIntType(), result); + } else { + newConst = ConstantInteger::get(result); + } + } + } else if (lhs->isFloat() && rhs->isFloat()) { + float l = lhsConst->getFloat(); + float r = rhsConst->getFloat(); + bool validOperation = true; + + switch (binaryInst->getKind()) { + case Instruction::kFAdd: { + float result = l + r; + if (std::isfinite(result)) { + newConst = ConstantFloating::get(result); + } else { + validOperation = false; + } + break; + } + case Instruction::kFSub: { + float result = l - r; + if (std::isfinite(result)) { + newConst = ConstantFloating::get(result); + } else { + validOperation = false; + } + break; + } + case Instruction::kFMul: { + float result = l * r; + if (std::isfinite(result)) { + newConst = ConstantFloating::get(result); + } else { + validOperation = false; + } + break; + } + case Instruction::kFDiv: { + if (std::abs(r) < std::numeric_limits::epsilon()) { + validOperation = false; + } else { + float result = l / r; + if (std::isfinite(result)) { + newConst = ConstantFloating::get(result); + } else { + validOperation = false; + } + } + break; + } + case Instruction::kFCmpEQ: + newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0); + break; + case Instruction::kFCmpNE: + newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0); + break; + case Instruction::kFCmpLT: + newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0); + break; + case Instruction::kFCmpGT: + newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0); + break; + case Instruction::kFCmpLE: + newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0); + break; + case Instruction::kFCmpGE: + newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0); + break; + default: + validOperation = false; + } + } + } catch (...) { + // 捕获可能的异常,跳过优化 + newConst = nullptr; + } + + if (newConst) { + binaryInst->replaceAllUsesWith(newConst); + instIter = bb->getInstructions().erase(instIter); + shouldAdvanceIter = false; + localChanged = true; + } + } + } + // 处理一元运算指令 + else if (auto *unaryInst = dynamic_cast(inst.get())) { + auto *operand = unaryInst->getOperand(); + auto *operandConst = dynamic_cast(operand); + + if (operandConst) { + ConstantValue *newConst = nullptr; + + if (operand->isInt()) { + int val = operandConst->getInt(); + + switch (unaryInst->getKind()) { + case Instruction::kNeg: + if (val != INT_MIN) { // 避免溢出 + newConst = ConstantInteger::get(-val); + } + break; + case Instruction::kNot: + newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0); + break; + default: + break; + } + } else if (operand->isFloat()) { + float val = operandConst->getFloat(); + + switch (unaryInst->getKind()) { + case Instruction::kFNeg: + newConst = ConstantFloating::get(-val); + break; + default: + break; + } + } + + if (newConst) { + unaryInst->replaceAllUsesWith(newConst); + instIter = bb->getInstructions().erase(instIter); + shouldAdvanceIter = false; + localChanged = true; + } + } + } + + if (shouldAdvanceIter) { + ++instIter; + } + } + } + + if (localChanged) { + changed = true; + } + } + + return changed; +} + +} // namespace sysy \ No newline at end of file diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index d6087fe..643efe9 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -5,6 +5,7 @@ #include "DCE.h" #include "Mem2Reg.h" #include "Reg2Mem.h" +#include "ConstPropagation.h" #include "Pass.h" #include #include @@ -80,6 +81,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR this->clearPasses(); this->addPass(&Mem2Reg::ID); + this->addPass(&ConstPropagation::ID); this->run(); if(DEBUG) {