From ea944f6ba007716545cde08f4d382a3030097027 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Wed, 13 Aug 2025 01:13:01 +0800 Subject: [PATCH] =?UTF-8?q?[midend-Loop-InductionVarStrengthReduction]?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=BE=AA=E7=8E=AF=E8=A7=84=E7=BA=A6=E5=8F=98?= =?UTF-8?q?=E9=87=8F=E5=BC=BA=E5=BA=A6=E5=89=8A=E5=BC=B1=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/LoopStrengthReduction.h | 178 ++++++ src/midend/CMakeLists.txt | 1 + .../Pass/Analysis/LoopCharacteristics.cpp | 164 +++++- .../Pass/Optimize/LoopStrengthReduction.cpp | 535 ++++++++++++++++++ src/midend/Pass/Pass.cpp | 11 +- src/sysyc.cpp | 2 +- 6 files changed, 883 insertions(+), 8 deletions(-) create mode 100644 src/include/midend/Pass/Optimize/LoopStrengthReduction.h create mode 100644 src/midend/Pass/Optimize/LoopStrengthReduction.cpp diff --git a/src/include/midend/Pass/Optimize/LoopStrengthReduction.h b/src/include/midend/Pass/Optimize/LoopStrengthReduction.h new file mode 100644 index 0000000..31397b1 --- /dev/null +++ b/src/include/midend/Pass/Optimize/LoopStrengthReduction.h @@ -0,0 +1,178 @@ +#pragma once + +#include "Pass.h" +#include "IR.h" +#include "LoopCharacteristics.h" +#include "Loop.h" +#include "Dom.h" +#include +#include +#include +#include + +namespace sysy { + +// 前向声明 +class LoopCharacteristicsResult; +class LoopAnalysisResult; + +/** + * @brief 强度削弱候选项信息 + * 记录一个可以进行强度削弱的表达式信息 + */ +struct StrengthReductionCandidate { + Instruction* originalInst; // 原始指令 (如 i*4) + Value* inductionVar; // 归纳变量 (如 i) + int multiplier; // 乘数 (如 4) + int offset; // 偏移量 (如常数项) + BasicBlock* containingBlock; // 所在基本块 + Loop* containingLoop; // 所在循环 + + // 强度削弱后的新变量 + PhiInst* newPhi = nullptr; // 新的 phi 指令 + Value* newInductionVar = nullptr; // 新的归纳变量 (递增 multiplier) + + StrengthReductionCandidate(Instruction* inst, Value* iv, int mult, int off, + BasicBlock* bb, Loop* loop) + : originalInst(inst), inductionVar(iv), multiplier(mult), offset(off), + containingBlock(bb), containingLoop(loop) {} +}; + +/** + * @brief 强度削弱上下文类 + * 封装强度削弱优化的核心逻辑和状态 + */ +class StrengthReductionContext { +public: + StrengthReductionContext(IRBuilder* builder) : builder(builder) {} + + /** + * 运行强度削弱优化 + * @param F 目标函数 + * @param AM 分析管理器 + * @return 是否修改了IR + */ + bool run(Function* F, AnalysisManager& AM); + +private: + IRBuilder* builder; + + // 分析结果缓存 + LoopAnalysisResult* loopAnalysis = nullptr; + LoopCharacteristicsResult* loopCharacteristics = nullptr; + DominatorTree* dominatorTree = nullptr; + + // 候选项存储 + std::vector> candidates; + std::unordered_map> loopToCandidates; + + // ========== 核心分析和优化阶段 ========== + + /** + * 阶段1:识别强度削弱候选项 + * 扫描所有循环中的乘法指令,找出可以优化的模式 + */ + void identifyStrengthReductionCandidates(Function* F); + + /** + * 阶段2:分析候选项的优化潜力 + * 评估每个候选项的收益,过滤掉不值得优化的情况 + */ + void analyzeOptimizationPotential(); + + /** + * 阶段3:执行强度削弱变换 + * 对选中的候选项执行实际的强度削弱优化 + */ + bool performStrengthReduction(); + + // ========== 辅助方法 ========== + + /** + * 检查指令是否为强度削弱候选项 + * @param inst 要检查的指令 + * @param loop 所在循环 + * @return 如果是候选项返回候选项信息,否则返回nullptr + */ + std::unique_ptr + isStrengthReductionCandidate(Instruction* inst, Loop* loop); + + /** + * 检查值是否为循环的归纳变量 + * @param val 要检查的值 + * @param loop 循环 + * @param characteristics 循环特征信息 + * @return 如果是归纳变量返回归纳变量信息,否则返回nullptr + */ + const InductionVarInfo* + getInductionVarInfo(Value* val, Loop* loop, const LoopCharacteristics* characteristics); + + /** + * 为候选项创建新的归纳变量 + * @param candidate 候选项 + * @return 是否成功创建 + */ + bool createNewInductionVariable(StrengthReductionCandidate* candidate); + + /** + * 替换原始指令的所有使用 + * @param candidate 候选项 + * @return 是否成功替换 + */ + bool replaceOriginalInstruction(StrengthReductionCandidate* candidate); + + /** + * 估算优化收益 + * 计算强度削弱后的性能提升 + * @param candidate 候选项 + * @return 估算的收益分数 + */ + double estimateOptimizationBenefit(const StrengthReductionCandidate* candidate); + + /** + * 检查优化的合法性 + * @param candidate 候选项 + * @return 是否可以安全地进行优化 + */ + bool isOptimizationLegal(const StrengthReductionCandidate* candidate); + + /** + * 打印调试信息 + */ + void printDebugInfo(); +}; + +/** + * @brief 循环强度削弱优化遍 + * 将循环中的乘法运算转换为更高效的加法运算 + */ +class LoopStrengthReduction : public OptimizationPass { +public: + // 唯一的 Pass ID + static void *ID; + + LoopStrengthReduction(IRBuilder* builder) + : OptimizationPass("LoopStrengthReduction", Granularity::Function), + builder(builder) {} + + /** + * 在函数上运行强度削弱优化 + * @param F 目标函数 + * @param AM 分析管理器 + * @return 是否修改了IR + */ + bool runOnFunction(Function* F, AnalysisManager& AM) override; + + /** + * 声明分析依赖和失效信息 + */ + void getAnalysisUsage(std::set& analysisDependencies, + std::set& analysisInvalidations) const override; + + void* getPassID() const override { return &ID; } + +private: + IRBuilder* builder; +}; + +} // namespace sysy diff --git a/src/midend/CMakeLists.txt b/src/midend/CMakeLists.txt index 870e965..93d1468 100644 --- a/src/midend/CMakeLists.txt +++ b/src/midend/CMakeLists.txt @@ -19,6 +19,7 @@ add_library(midend_lib STATIC Pass/Optimize/SCCP.cpp Pass/Optimize/LoopNormalization.cpp Pass/Optimize/LICM.cpp + Pass/Optimize/LoopStrengthReduction.cpp Pass/Optimize/BuildCFG.cpp Pass/Optimize/LargeArrayToGlobal.cpp ) diff --git a/src/midend/Pass/Analysis/LoopCharacteristics.cpp b/src/midend/Pass/Analysis/LoopCharacteristics.cpp index c6f6d2c..87c4549 100644 --- a/src/midend/Pass/Analysis/LoopCharacteristics.cpp +++ b/src/midend/Pass/Analysis/LoopCharacteristics.cpp @@ -306,22 +306,74 @@ void LoopCharacteristicsPass::identifyBasicInductionVariables( BasicBlock* header = loop->getHeader(); std::vector> ivs; + if (DEBUG) { + std::cout << " === Identifying Induction Variables for Loop: " << loop->getName() << " ===" << std::endl; + std::cout << " Loop header: " << header->getName() << std::endl; + std::cout << " Loop blocks: "; + for (auto* bb : loop->getBlocks()) { + std::cout << bb->getName() << " "; + } + std::cout << std::endl; + } + // 1. 识别所有BIV for (auto& inst : header->getInstructions()) { auto* phi = dynamic_cast(inst.get()); if (!phi) continue; if (isBasicInductionVariable(phi, loop)) { ivs.push_back(InductionVarInfo::createBasicBIV(phi, Instruction::Kind::kPhi)); - if (DEBUG) std::cout << " Found basic induction variable: " << phi->getName() << std::endl; + if (DEBUG) { + std::cout << " [BIV] Found basic induction variable: " << phi->getName() << std::endl; + std::cout << " Incoming values: "; + for (auto& [incomingBB, incomingVal] : phi->getIncomingValues()) { + std::cout << "{" << incomingBB->getName() << ": " << incomingVal->getName() << "} "; + } + std::cout << std::endl; + } } } + if (DEBUG) { + std::cout << " Found " << ivs.size() << " basic induction variables" << std::endl; + } + // 2. 递归识别所有派生DIV std::set visited; + size_t initialSize = ivs.size(); for (const auto& biv : ivs) { + if (DEBUG) { + std::cout << " Searching for derived IVs from BIV: " << biv->div->getName() << std::endl; + } findDerivedInductionVars(biv->div, biv->base, loop, ivs, visited); } + if (DEBUG) { + size_t derivedCount = ivs.size() - initialSize; + std::cout << " Found " << derivedCount << " derived induction variables" << std::endl; + + // 打印所有归纳变量的详细信息 + std::cout << " === Final Induction Variables Summary ===" << std::endl; + for (size_t i = 0; i < ivs.size(); ++i) { + const auto& iv = ivs[i]; + std::cout << " [" << i << "] " << iv->div->getName() + << " (kind: " << (iv->ivkind == IVKind::kBasic ? "Basic" : + iv->ivkind == IVKind::kLinear ? "Linear" : "Complex") << ")" << std::endl; + std::cout << " Operation: " << static_cast(iv->Instkind) << std::endl; + if (iv->base) { + std::cout << " Base: " << iv->base->getName() << std::endl; + } + if (iv->Multibase.first || iv->Multibase.second) { + std::cout << " Multi-base: "; + if (iv->Multibase.first) std::cout << iv->Multibase.first->getName() << " "; + if (iv->Multibase.second) std::cout << iv->Multibase.second->getName() << " "; + std::cout << std::endl; + } + std::cout << " Factor: " << iv->factor << ", Offset: " << iv->offset << std::endl; + std::cout << " Valid: " << (iv->valid ? "Yes" : "No") << std::endl; + } + std::cout << " =============================================" << std::endl; + } + characteristics->InductionVars = std::move(ivs); } @@ -342,52 +394,97 @@ static LinearExpr analyzeLinearExpr(Value* val, Loop* loop, std::vector= 2) { // 更详细的调试级别 + if (auto* inst = dynamic_cast(val)) { + std::cout << " Analyzing linear expression for: " << val->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } else { + std::cout << " Analyzing linear expression for value: " << val->getName() << std::endl; + } + } + // 基本变量:常数 if (auto* cint = dynamic_cast(val)) { + if (DEBUG >= 2) { + std::cout << " -> Constant: " << cint->getInt() << std::endl; + } return {nullptr, nullptr, 0, 0, cint->getInt(), true, false}; } + // 基本变量:BIV或派生IV for (auto& iv : ivs) { if (iv->div == val) { if (iv->ivkind == IVKind::kBasic || iv->ivkind == IVKind::kLinear) { + if (DEBUG >= 2) { + std::cout << " -> Found " << (iv->ivkind == IVKind::kBasic ? "Basic" : "Linear") + << " IV with base: " << (iv->base ? iv->base->getName() : "null") + << ", factor: " << iv->factor << ", offset: " << iv->offset << std::endl; + } return {iv->base, nullptr, iv->factor, 0, iv->offset, true, true}; } // 复杂归纳变量 if (iv->ivkind == IVKind::kCmplx) { + if (DEBUG >= 2) { + std::cout << " -> Found Complex IV with multi-base" << std::endl; + } return {iv->Multibase.first, iv->Multibase.second, 1, 1, 0, true, false}; } } } + // 一元负号 if (auto* inst = dynamic_cast(val)) { auto kind = inst->getKind(); if (kind == Instruction::Kind::kNeg) { + if (DEBUG >= 2) { + std::cout << " -> Analyzing negation" << std::endl; + } auto expr = analyzeLinearExpr(inst->getOperand(0), loop, ivs); if (!expr.valid) return expr; expr.factor1 = -expr.factor1; expr.factor2 = -expr.factor2; expr.offset = -expr.offset; expr.isSimple = (expr.base2 == nullptr); + if (DEBUG >= 2) { + std::cout << " -> Negation result: valid=" << expr.valid << ", simple=" << expr.isSimple << std::endl; + } return expr; } + // 二元加减乘 if (kind == Instruction::Kind::kAdd || kind == Instruction::Kind::kSub) { + if (DEBUG >= 2) { + std::cout << " -> Analyzing " << (kind == Instruction::Kind::kAdd ? "addition" : "subtraction") << std::endl; + } auto expr0 = analyzeLinearExpr(inst->getOperand(0), loop, ivs); auto expr1 = analyzeLinearExpr(inst->getOperand(1), loop, ivs); - if (!expr0.valid || !expr1.valid) return {nullptr, nullptr, 0, 0, 0, false, false}; + if (!expr0.valid || !expr1.valid) { + if (DEBUG >= 2) { + std::cout << " -> Failed: operand not linear (expr0.valid=" << expr0.valid << ", expr1.valid=" << expr1.valid << ")" << std::endl; + } + return {nullptr, nullptr, 0, 0, 0, false, false}; + } + // 合并:若BIV相同或有一个是常数 // 单BIV+常数 if (expr0.base1 && !expr1.base1 && !expr1.base2) { int sign = (kind == Instruction::Kind::kAdd ? 1 : -1); + if (DEBUG >= 2) { + std::cout << " -> Single BIV + constant pattern" << std::endl; + } return {expr0.base1, nullptr, expr0.factor1, 0, expr0.offset + sign * expr1.offset, true, expr0.isSimple}; } if (!expr0.base1 && !expr0.base2 && expr1.base1) { int sign = (kind == Instruction::Kind::kAdd ? 1 : -1); int f = sign * expr1.factor1; int off = expr0.offset + sign * expr1.offset; + if (DEBUG >= 2) { + std::cout << " -> Constant + single BIV pattern" << std::endl; + } return {expr1.base1, nullptr, f, 0, off, true, expr1.isSimple}; } + // 双BIV线性组合 if (expr0.base1 && expr1.base1 && expr0.base1 != expr1.base1 && !expr0.base2 && !expr1.base2) { int sign = (kind == Instruction::Kind::kAdd ? 1 : -1); @@ -396,31 +493,56 @@ static LinearExpr analyzeLinearExpr(Value* val, Loop* loop, std::vector= 2) { + std::cout << " -> Double BIV linear combination" << std::endl; + } return {base1, base2, f1, f2, off, true, false}; } + // 同BIV合并 if (expr0.base1 && expr1.base1 && expr0.base1 == expr1.base1 && !expr0.base2 && !expr1.base2) { int sign = (kind == Instruction::Kind::kAdd ? 1 : -1); int f = expr0.factor1 + sign * expr1.factor1; int off = expr0.offset + sign * expr1.offset; + if (DEBUG >= 2) { + std::cout << " -> Same BIV combination" << std::endl; + } return {expr0.base1, nullptr, f, 0, off, true, true}; } } + // 乘法:BIV*const 或 const*BIV if (kind == Instruction::Kind::kMul) { + if (DEBUG >= 2) { + std::cout << " -> Analyzing multiplication" << std::endl; + } auto expr0 = analyzeLinearExpr(inst->getOperand(0), loop, ivs); auto expr1 = analyzeLinearExpr(inst->getOperand(1), loop, ivs); + // 只允许一侧为常数 if (expr0.base1 && !expr1.base1 && !expr1.base2 && expr1.offset) { + if (DEBUG >= 2) { + std::cout << " -> BIV * constant pattern" << std::endl; + } return {expr0.base1, nullptr, expr0.factor1 * expr1.offset, 0, expr0.offset * expr1.offset, true, true}; } if (!expr0.base1 && !expr0.base2 && expr0.offset && expr1.base1) { + if (DEBUG >= 2) { + std::cout << " -> Constant * BIV pattern" << std::endl; + } return {expr1.base1, nullptr, expr1.factor1 * expr0.offset, 0, expr1.offset * expr0.offset, true, true}; } // 双BIV乘法不支持 + if (DEBUG >= 2) { + std::cout << " -> Multiplication pattern not supported" << std::endl; + } } } + // 其它情况 + if (DEBUG >= 2) { + std::cout << " -> Other case: not linear" << std::endl; + } return {nullptr, nullptr, 0, 0, 0, false, false}; } @@ -506,19 +628,40 @@ void LoopCharacteristicsPass::findDerivedInductionVars( if (visited.count(root)) return; visited.insert(root); + if (DEBUG) { + std::cout << " Analyzing uses of: " << root->getName() << std::endl; + } + for (auto use : root->getUses()) { auto user = use->getUser(); Instruction* inst = dynamic_cast(user); if (!inst) continue; - if (!loop->contains(inst->getParent())) continue; + if (!loop->contains(inst->getParent())) { + if (DEBUG) { + std::cout << " Skipping user outside loop: " << inst->getName() << std::endl; + } + continue; + } + + if (DEBUG) { + std::cout << " Checking instruction: " << inst->getName() + << " (kind: " << static_cast(inst->getKind()) << ")" << std::endl; + } // 下面是一个例子:假设你有线性归约分析(可用analyzeLinearExpr等递归辅助) auto expr = analyzeLinearExpr(inst, loop, ivs); if (!expr.valid) { + if (DEBUG) { + std::cout << " Linear expression analysis failed for: " << inst->getName() << std::endl; + } // 复杂非线性归纳变量,作为kCmplx记录(假如你想追踪) // 这里假设expr.base1、base2都有效才记录double if (expr.base1 && expr.base2) { + if (DEBUG) { + std::cout << " [DIV-COMPLEX] Creating complex derived IV: " << inst->getName() + << " with bases: " << expr.base1->getName() << ", " << expr.base2->getName() << std::endl; + } ivs.push_back(InductionVarInfo::createDoubleDIV(inst, inst->getKind(), expr.base1, expr.base2, 0, expr.offset)); } continue; @@ -526,15 +669,30 @@ void LoopCharacteristicsPass::findDerivedInductionVars( // 单BIV线性 if (expr.base1 && !expr.base2) { + if (DEBUG) { + std::cout << " [DIV-LINEAR] Creating single-base derived IV: " << inst->getName() + << " with base: " << expr.base1->getName() + << ", factor: " << expr.factor1 + << ", offset: " << expr.offset << std::endl; + } ivs.push_back(InductionVarInfo::createSingleDIV(inst, inst->getKind(), expr.base1, expr.factor1, expr.offset)); findDerivedInductionVars(inst, expr.base1, loop, ivs, visited); } // 双BIV线性 else if (expr.base1 && expr.base2) { + if (DEBUG) { + std::cout << " [DIV-COMPLEX] Creating double-base derived IV: " << inst->getName() + << " with bases: " << expr.base1->getName() << ", " << expr.base2->getName() + << ", offset: " << expr.offset << std::endl; + } ivs.push_back(InductionVarInfo::createDoubleDIV(inst, inst->getKind(), expr.base1, expr.base2, 0, expr.offset)); // 双BIV情形一般不再递归下游 } } + + if (DEBUG) { + std::cout << " Finished analyzing uses of: " << root->getName() << std::endl; + } } // 递归/推进式判定 diff --git a/src/midend/Pass/Optimize/LoopStrengthReduction.cpp b/src/midend/Pass/Optimize/LoopStrengthReduction.cpp new file mode 100644 index 0000000..f1c6266 --- /dev/null +++ b/src/midend/Pass/Optimize/LoopStrengthReduction.cpp @@ -0,0 +1,535 @@ +#include "LoopStrengthReduction.h" +#include "LoopCharacteristics.h" +#include "Loop.h" +#include "Dom.h" +#include "IRBuilder.h" +#include "SysYIROptUtils.h" +#include +#include +#include + +// 使用全局调试开关 +extern int DEBUG; + +namespace sysy { + +// 定义 Pass 的唯一 ID +void *LoopStrengthReduction::ID = (void *)&LoopStrengthReduction::ID; + +bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) { + if (F->getBasicBlocks().empty()) { + return false; // 空函数 + } + + if (DEBUG) { + std::cout << "Running LoopStrengthReduction on function: " << F->getName() << std::endl; + } + + // 创建优化上下文并运行 + StrengthReductionContext context(builder); + bool modified = context.run(F, AM); + + if (DEBUG) { + std::cout << "LoopStrengthReduction " << (modified ? "modified" : "did not modify") + << " function: " << F->getName() << std::endl; + } + + return modified; +} + +void LoopStrengthReduction::getAnalysisUsage(std::set& analysisDependencies, + std::set& analysisInvalidations) const { + // 依赖的分析 + analysisDependencies.insert(&LoopAnalysisPass::ID); + analysisDependencies.insert(&LoopCharacteristicsPass::ID); + analysisDependencies.insert(&DominatorTreeAnalysisPass::ID); + + // 会使失效的分析(强度削弱会修改IR结构) + analysisInvalidations.insert(&LoopCharacteristicsPass::ID); + // 注意:支配树分析通常不会因为强度削弱而失效,因为我们不改变控制流 +} + +// ========== StrengthReductionContext 实现 ========== + +bool StrengthReductionContext::run(Function* F, AnalysisManager& AM) { + if (DEBUG) { + std::cout << " Starting strength reduction analysis..." << std::endl; + } + + // 获取必要的分析结果 + loopAnalysis = AM.getAnalysisResult(F); + if (!loopAnalysis || !loopAnalysis->hasLoops()) { + if (DEBUG) { + std::cout << " No loops found, skipping strength reduction" << std::endl; + } + return false; + } + + loopCharacteristics = AM.getAnalysisResult(F); + if (!loopCharacteristics) { + if (DEBUG) { + std::cout << " LoopCharacteristics analysis not available" << std::endl; + } + return false; + } + + dominatorTree = AM.getAnalysisResult(F); + if (!dominatorTree) { + if (DEBUG) { + std::cout << " DominatorTree analysis not available" << std::endl; + } + return false; + } + + // 执行三个阶段的优化 + + // 阶段1:识别候选项 + identifyStrengthReductionCandidates(F); + + if (candidates.empty()) { + if (DEBUG) { + std::cout << " No strength reduction candidates found" << std::endl; + } + return false; + } + + if (DEBUG) { + std::cout << " Found " << candidates.size() << " potential candidates" << std::endl; + } + + // 阶段2:分析优化潜力 + analyzeOptimizationPotential(); + + // 阶段3:执行优化 + bool modified = performStrengthReduction(); + + if (DEBUG) { + printDebugInfo(); + } + + return modified; +} + +void StrengthReductionContext::identifyStrengthReductionCandidates(Function* F) { + if (DEBUG) { + std::cout << " === Phase 1: Identifying Strength Reduction Candidates ===" << std::endl; + } + + // 遍历所有循环 + for (const auto& loop_ptr : loopAnalysis->getAllLoops()) { + Loop* loop = loop_ptr.get(); + + if (DEBUG) { + std::cout << " Analyzing loop: " << loop->getName() << std::endl; + } + + // 获取循环特征 + const LoopCharacteristics* characteristics = loopCharacteristics->getCharacteristics(loop); + if (!characteristics) { + if (DEBUG) { + std::cout << " No characteristics available for loop" << std::endl; + } + continue; + } + + if (characteristics->InductionVars.empty()) { + if (DEBUG) { + std::cout << " No induction variables found in loop" << std::endl; + } + continue; + } + + // 遍历循环中的所有指令 + for (BasicBlock* bb : loop->getBlocks()) { + for (auto& inst_ptr : bb->getInstructions()) { + Instruction* inst = inst_ptr.get(); + + // 检查是否为强度削弱候选项 + auto candidate = isStrengthReductionCandidate(inst, loop); + if (candidate) { + if (DEBUG) { + std::cout << " Found candidate: %" << inst->getName() + << " (IV: %" << candidate->inductionVar->getName() + << ", multiplier: " << candidate->multiplier + << ", offset: " << candidate->offset << ")" << std::endl; + } + + // 添加到候选项列表 + loopToCandidates[loop].push_back(candidate.get()); + candidates.push_back(std::move(candidate)); + } + } + } + } + + if (DEBUG) { + std::cout << " === End Phase 1: Found " << candidates.size() << " candidates ===" << std::endl; + } +} + +std::unique_ptr +StrengthReductionContext::isStrengthReductionCandidate(Instruction* inst, Loop* loop) { + // 只考虑乘法指令 + if (inst->getKind() != Instruction::Kind::kMul) { + return nullptr; + } + + auto* mulInst = dynamic_cast(inst); + if (!mulInst) { + return nullptr; + } + + Value* op0 = mulInst->getOperand(0); + Value* op1 = mulInst->getOperand(1); + + // 检查模式:归纳变量 * 常数 或 常数 * 归纳变量 + Value* inductionVar = nullptr; + int multiplier = 0; + + // 获取循环特征信息 + const LoopCharacteristics* characteristics = loopCharacteristics->getCharacteristics(loop); + if (!characteristics) { + return nullptr; + } + + // 模式1: IV * const + const InductionVarInfo* ivInfo = getInductionVarInfo(op0, loop, characteristics); + if (ivInfo && dynamic_cast(op1)) { + inductionVar = op0; + multiplier = dynamic_cast(op1)->getInt(); + } + // 模式2: const * IV + else { + ivInfo = getInductionVarInfo(op1, loop, characteristics); + if (ivInfo && dynamic_cast(op0)) { + inductionVar = op1; + multiplier = dynamic_cast(op0)->getInt(); + } + } + + if (!inductionVar || multiplier <= 1) { + return nullptr; // 不是有效的候选项 + } + + // 创建候选项 + return std::make_unique( + inst, inductionVar, multiplier, 0, inst->getParent(), loop + ); +} + +const InductionVarInfo* +StrengthReductionContext::getInductionVarInfo(Value* val, Loop* loop, + const LoopCharacteristics* characteristics) { + for (const auto& iv : characteristics->InductionVars) { + if (iv->div == val) { + return iv.get(); + } + } + return nullptr; +} + +void StrengthReductionContext::analyzeOptimizationPotential() { + if (DEBUG) { + std::cout << " === Phase 2: Analyzing Optimization Potential ===" << std::endl; + } + + // 为每个候选项计算优化收益,并过滤不值得优化的 + auto it = candidates.begin(); + while (it != candidates.end()) { + StrengthReductionCandidate* candidate = it->get(); + + double benefit = estimateOptimizationBenefit(candidate); + bool isLegal = isOptimizationLegal(candidate); + + if (DEBUG) { + std::cout << " Candidate " << candidate->originalInst->getName() + << ": benefit=" << benefit + << ", legal=" << (isLegal ? "yes" : "no") << std::endl; + } + + // 如果收益太小或不合法,移除候选项 + if (benefit < 1.0 || !isLegal) { + // 从 loopToCandidates 中移除 + auto& loopCandidates = loopToCandidates[candidate->containingLoop]; + loopCandidates.erase( + std::remove(loopCandidates.begin(), loopCandidates.end(), candidate), + loopCandidates.end() + ); + + it = candidates.erase(it); + } else { + ++it; + } + } + + if (DEBUG) { + std::cout << " === End Phase 2: " << candidates.size() << " candidates remain ===" << std::endl; + } +} + +double StrengthReductionContext::estimateOptimizationBenefit(const StrengthReductionCandidate* candidate) { + // 简单的收益估算模型 + double benefit = 0.0; + + // 基础收益:乘法变加法的性能提升 + benefit += 2.0; // 假设乘法比加法慢2倍 + + // 乘数因子:乘数越大,收益越高 + if (candidate->multiplier >= 4) { + benefit += 1.0; + } + if (candidate->multiplier >= 8) { + benefit += 1.0; + } + + // 循环热度因子 + Loop* loop = candidate->containingLoop; + double hotness = loop->getLoopHotness(); + benefit *= (1.0 + hotness / 100.0); + + // 使用次数因子 + size_t useCount = candidate->originalInst->getUses().size(); + if (useCount > 1) { + benefit *= (1.0 + useCount * 0.2); + } + + return benefit; +} + +bool StrengthReductionContext::isOptimizationLegal(const StrengthReductionCandidate* candidate) { + // 检查优化的合法性 + + // 1. 确保归纳变量在循环头有 phi 指令 + auto* phiInst = dynamic_cast(candidate->inductionVar); + if (!phiInst || phiInst->getParent() != candidate->containingLoop->getHeader()) { + if (DEBUG >= 2) { + std::cout << " Illegal: induction variable is not a phi in loop header" << std::endl; + } + return false; + } + + // 2. 确保乘法指令在循环内 + if (!candidate->containingLoop->contains(candidate->containingBlock)) { + if (DEBUG >= 2) { + std::cout << " Illegal: instruction not in loop" << std::endl; + } + return false; + } + + // 3. 检查是否有溢出风险(简化检查) + if (candidate->multiplier > 1000) { + if (DEBUG >= 2) { + std::cout << " Illegal: multiplier too large (overflow risk)" << std::endl; + } + return false; + } + + // 4. 确保该指令不在循环的退出条件中(避免影响循环语义) + for (BasicBlock* exitingBB : candidate->containingLoop->getExitingBlocks()) { + auto terminatorIt = exitingBB->terminator(); + if (terminatorIt != exitingBB->end()) { + Instruction* terminator = terminatorIt->get(); + if (terminator && (terminator->getOperand(0) == candidate->originalInst || + (terminator->getNumOperands() > 1 && terminator->getOperand(1) == candidate->originalInst))) { + if (DEBUG >= 2) { + std::cout << " Illegal: instruction used in loop exit condition" << std::endl; + } + return false; + } + } + } + + return true; +} + +bool StrengthReductionContext::performStrengthReduction() { + if (DEBUG) { + std::cout << " === Phase 3: Performing Strength Reduction ===" << std::endl; + } + + bool modified = false; + + for (auto& candidate : candidates) { + if (DEBUG) { + std::cout << " Processing candidate: " << candidate->originalInst->getName() << std::endl; + } + + // 创建新的归纳变量 + if (!createNewInductionVariable(candidate.get())) { + if (DEBUG) { + std::cout << " Failed to create new induction variable" << std::endl; + } + continue; + } + + // 替换原始指令 + if (!replaceOriginalInstruction(candidate.get())) { + if (DEBUG) { + std::cout << " Failed to replace original instruction" << std::endl; + } + continue; + } + + if (DEBUG) { + std::cout << " Successfully optimized: " << candidate->originalInst->getName() + << " -> " << candidate->newInductionVar->getName() << std::endl; + } + + modified = true; + } + + if (DEBUG) { + std::cout << " === End Phase 3: " << (modified ? "Optimizations applied" : "No optimizations") << " ===" << std::endl; + } + + return modified; +} + +bool StrengthReductionContext::createNewInductionVariable(StrengthReductionCandidate* candidate) { + Loop* loop = candidate->containingLoop; + BasicBlock* header = loop->getHeader(); + BasicBlock* preheader = loop->getPreHeader(); + + if (!preheader) { + if (DEBUG) { + std::cout << " No preheader found for loop" << std::endl; + } + return false; + } + + // 获取原始归纳变量的 phi 指令 + auto* originalPhi = dynamic_cast(candidate->inductionVar); + if (!originalPhi) { + return false; + } + + // 1. 在循环头创建新的 phi 指令 + builder->setPosition(header, header->begin()); + candidate->newPhi = builder->createPhiInst(originalPhi->getType()); + candidate->newPhi->setName(originalPhi->getName() + "_sr"); + + // 2. 找到原始归纳变量的初始值和步长 + Value* initialValue = nullptr; + Value* stepValue = nullptr; + BasicBlock* latchBlock = nullptr; + + for (auto& [incomingBB, incomingVal] : originalPhi->getIncomingValues()) { + if (!loop->contains(incomingBB)) { + // 来自循环外的初始值 + initialValue = incomingVal; + } else { + // 来自循环内的递增值 + latchBlock = incomingBB; + // 尝试找到步长 + if (auto* addInst = dynamic_cast(incomingVal)) { + if (addInst->getKind() == Instruction::Kind::kAdd) { + if (addInst->getOperand(0) == originalPhi) { + stepValue = addInst->getOperand(1); + } else if (addInst->getOperand(1) == originalPhi) { + stepValue = addInst->getOperand(0); + } + } + } + } + } + + if (!initialValue || !stepValue || !latchBlock) { + if (DEBUG) { + std::cout << " Failed to find initial value, step, or latch block" << std::endl; + } + return false; + } + + // 3. 计算新归纳变量的初始值和步长 + // 新IV的初始值 = 原IV初始值 * multiplier + Value* newInitialValue; + if (auto* constInt = dynamic_cast(initialValue)) { + newInitialValue = ConstantInteger::get(constInt->getInt() * candidate->multiplier); + } else { + // 如果初始值不是常数,需要在preheader中插入乘法 + builder->setPosition(preheader, preheader->terminator()); + newInitialValue = builder->createMulInst(initialValue, + ConstantInteger::get(candidate->multiplier)); + } + + // 新IV的步长 = 原IV步长 * multiplier + Value* newStepValue; + if (auto* constInt = dynamic_cast(stepValue)) { + newStepValue = ConstantInteger::get(constInt->getInt() * candidate->multiplier); + } else { + builder->setPosition(latchBlock, latchBlock->terminator()); + newStepValue = builder->createMulInst(stepValue, + ConstantInteger::get(candidate->multiplier)); + } + + // 4. 创建新归纳变量的递增指令 + builder->setPosition(latchBlock, latchBlock->terminator()); + Value* newIncrementedValue = builder->createAddInst(candidate->newPhi, newStepValue); + + // 5. 设置新 phi 的输入值 + candidate->newPhi->addIncoming(newInitialValue, preheader); + candidate->newPhi->addIncoming(newIncrementedValue, latchBlock); + + candidate->newInductionVar = candidate->newPhi; + + if (DEBUG) { + std::cout << " Created new induction variable: " << candidate->newPhi->getName() << std::endl; + } + + return true; +} + +bool StrengthReductionContext::replaceOriginalInstruction(StrengthReductionCandidate* candidate) { + if (!candidate->newInductionVar) { + return false; + } + + // 处理偏移量 + Value* replacementValue = candidate->newInductionVar; + if (candidate->offset != 0) { + builder->setPosition(candidate->containingBlock, + candidate->containingBlock->findInstIterator(candidate->originalInst)); + replacementValue = builder->createAddInst( + candidate->newInductionVar, + ConstantInteger::get(candidate->offset) + ); + } + + // 替换所有使用 + candidate->originalInst->replaceAllUsesWith(replacementValue); + + // 从基本块中移除原始指令 + auto* bb = candidate->originalInst->getParent(); + auto it = bb->findInstIterator(candidate->originalInst); + if (it != bb->end()) { + bb->getInstructions().erase(it); + } + + if (DEBUG) { + std::cout << " Replaced and removed original instruction" << std::endl; + } + + return true; +} + +void StrengthReductionContext::printDebugInfo() { + if (!DEBUG) return; + + std::cout << "\n=== Strength Reduction Optimization Summary ===" << std::endl; + std::cout << "Total candidates processed: " << candidates.size() << std::endl; + + for (auto& [loop, loopCandidates] : loopToCandidates) { + if (!loopCandidates.empty()) { + std::cout << "Loop " << loop->getName() << ": " << loopCandidates.size() << " optimizations" << std::endl; + for (auto* candidate : loopCandidates) { + if (candidate->newInductionVar) { + std::cout << " " << candidate->inductionVar->getName() << " * " << candidate->multiplier + << " -> " << candidate->newInductionVar->getName() << std::endl; + } + } + } + } + std::cout << "===============================================" << std::endl; +} + +} // namespace sysy diff --git a/src/midend/Pass/Pass.cpp b/src/midend/Pass/Pass.cpp index 74c1a67..5b9db22 100644 --- a/src/midend/Pass/Pass.cpp +++ b/src/midend/Pass/Pass.cpp @@ -15,6 +15,7 @@ #include "LargeArrayToGlobal.h" #include "LoopNormalization.h" #include "LICM.h" +#include "LoopStrengthReduction.h" #include "Pass.h" #include #include @@ -70,6 +71,7 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); + registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); registerOptimizationPass(builderIR); @@ -137,16 +139,17 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR this->clearPasses(); this->addPass(&LoopNormalizationPass::ID); this->addPass(&LICM::ID); + this->addPass(&LoopStrengthReduction::ID); this->run(); if(DEBUG) { - std::cout << "=== IR After Loop Normalization and LICM Optimizations ===\n"; + std::cout << "=== IR After Loop Normalization, LICM, and Strength Reduction Optimizations ===\n"; printPasses(); } - this->clearPasses(); - this->addPass(&Reg2Mem::ID); - this->run(); + // this->clearPasses(); + // this->addPass(&Reg2Mem::ID); + // this->run(); if(DEBUG) { std::cout << "=== IR After Reg2Mem Optimizations ===\n"; diff --git a/src/sysyc.cpp b/src/sysyc.cpp index cbb553c..78930a0 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -35,7 +35,7 @@ void usage(int code) { "Supported options:\n" " -h \tprint help message and exit\n" " -f \tpretty-format the input file\n" - " -s {ast,ir,asm,llvmir,asmd,ird}\tstop after generating AST/IR/Assembly\n" + " -s {ast,ir,asm,asmd,ird}\tstop after generating AST/IR/Assembly\n" " -S \tcompile to assembly (.s file)\n" " -o \tplace the output into \n" " -O\tenable optimization at (e.g., -O0, -O1)\n";