Compare commits

...

4 Commits

Author SHA1 Message Date
rain2133
97d83d733e [midend-Funtioninline]ir增加函数复制方法,函数内联需要用 2025-08-18 23:58:39 +08:00
rain2133
ad74e435ba [midend-GSR]修复错误的代数简化 2025-08-18 21:55:57 +08:00
rain2133
5c34cbc7b8 [midend-GSR]将魔数求解移动到utils的静态方法中。 2025-08-18 20:37:20 +08:00
rain2133
c9a0c700e1 [midend]增加全局强度削弱优化遍 2025-08-18 11:30:40 +08:00
9 changed files with 1312 additions and 204 deletions

View File

@@ -652,6 +652,13 @@ public:
} ///< 移除指定位置的指令
iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block);
static void bbdfs(BasicBlock* bb, std::function<bool(BasicBlock*)> func) {
if (func(bb))
return;
for (auto succ : bb->getSuccessors())
bbdfs(succ, func);
}
/// 清理基本块中的所有使用关系
void cleanup();
@@ -767,7 +774,7 @@ protected:
: User(type, name), kind(kind), parent(parent) {}
public:
virtual Instruction* copy(std::function<Value*(Value*)> getValue) const = 0;
public:
Kind getKind() const { return kind; }
std::string getKindString() const{
@@ -964,6 +971,18 @@ class PhiInst : public Instruction {
}
}
PhiInst(Type *type,
const std::vector<std::pair<BasicBlock*, Value*>> IncomingValues,
BasicBlock *parent = nullptr,
const std::string &name = "")
: Instruction(Kind::kPhi, type, parent, name), vsize(IncomingValues.size()) {
for(size_t i = 0; i < vsize; ++i) {
addOperand(IncomingValues[i].first);
addOperand(IncomingValues[i].second);
}
refreshMap(); ///< 刷新块到值的映射关系
}
public:
unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量
Value *getIncomingValue(unsigned Idx) const { return getOperand(Idx * 2); } ///< 获取指定位置的传入值
@@ -1013,6 +1032,7 @@ class PhiInst : public Instruction {
} ///< 刷新块到值的映射关系
auto getValues() { return make_range(std::next(operand_begin()), operand_end()); }
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
};
@@ -1047,6 +1067,7 @@ protected:
public:
Value* getOperand() const { return User::getOperand(0); }
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class UnaryInst
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
@@ -1126,6 +1147,7 @@ public:
return new BinaryInst(kind, type, lhs, rhs, parent, name);
}
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class BinaryInst
//! The return statement
@@ -1217,6 +1239,7 @@ public:
return succs;
}
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class CondBrInst
class UnreachableInst : public Instruction {
@@ -1225,6 +1248,7 @@ public:
explicit UnreachableInst(const std::string& name, BasicBlock *parent = nullptr)
: Instruction(kUnreachable, Type::getVoidType(), parent, "") {}
void print(std::ostream& os) const { os << "unreachable"; }
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
};
//! Allocate memory for stack variables, used for non-global variable declartion
@@ -1243,6 +1267,7 @@ public:
return getType()->as<PointerType>()->getBaseType();
} ///< 获取分配的类型
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class AllocaInst
@@ -1281,6 +1306,7 @@ public:
return new GetElementPtrInst(resultType, basePointer, indices, parent, name);
}
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
};
//! Load a value from memory address specified by a pointer value
@@ -1299,6 +1325,7 @@ protected:
public:
Value* getPointer() const { return getOperand(0); }
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class LoadInst
//! Store a value to memory address specified by a pointer value
@@ -1318,6 +1345,7 @@ public:
Value* getValue() const { return getOperand(0); }
Value* getPointer() const { return getOperand(1); }
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
}; // class StoreInst
//! Memset instruction
@@ -1348,6 +1376,7 @@ public:
Value* getSize() const { return getOperand(2); }
Value* getValue() const { return getOperand(3); }
void print(std::ostream& os) const override;
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
};
class GlobalValue;
@@ -1385,19 +1414,11 @@ protected:
public:
using block_list = std::list<std::unique_ptr<BasicBlock>>;
using arg_list = std::vector<Argument *>;
enum FunctionAttribute : uint64_t {
PlaceHolder = 0x0UL,
Pure = 0x1UL << 0,
SelfRecursive = 0x1UL << 1,
SideEffect = 0x1UL << 2,
NoPureCauseMemRead = 0x1UL << 3
};
protected:
Module *parent; ///< 函数的父模块
block_list blocks; ///< 函数包含的基本块列表
arg_list arguments; ///< 函数参数列表
FunctionAttribute attribute = PlaceHolder; ///< 函数属性
std::set<Function *> callees; ///< 函数调用的函数集合
public:
static unsigned getcloneIndex() {
@@ -1405,17 +1426,12 @@ protected:
cloneIndex += 1;
return cloneIndex - 1;
}
Function* clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const;
Function* copy_func();
const std::set<Function *>& getCallees() { return callees; }
void addCallee(Function *callee) { callees.insert(callee); }
void removeCallee(Function *callee) { callees.erase(callee); }
void clearCallees() { callees.clear(); }
std::set<Function *> getCalleesWithNoExternalAndSelf();
FunctionAttribute getAttribute() const { return attribute; }
void setAttribute(FunctionAttribute attr) {
attribute = static_cast<FunctionAttribute>(attribute | attr);
}
void clearAttribute() { attribute = PlaceHolder; }
Type* getReturnType() const { return getType()->as<FunctionType>()->getReturnType(); }
auto getParamTypes() const { return getType()->as<FunctionType>()->getParamTypes(); }
auto getBasicBlocks() { return make_range(blocks); }

View File

@@ -0,0 +1,107 @@
#pragma once
#include "Pass.h"
#include "IR.h"
#include "SideEffectAnalysis.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cstdint>
namespace sysy {
// 魔数乘法结构,用于除法优化
struct MagicNumber {
uint32_t multiplier;
int shift;
bool needAdd;
MagicNumber(uint32_t m, int s, bool add = false)
: multiplier(m), shift(s), needAdd(add) {}
};
// 全局强度削弱优化遍的核心逻辑封装类
class GlobalStrengthReductionContext {
public:
// 构造函数接受IRBuilder参数
explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {}
// 运行优化的主要方法
void run(Function* func, AnalysisManager* AM, bool& changed);
private:
IRBuilder* builder; // IR构建器
// 分析结果
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
// 优化计数
int algebraicOptCount = 0;
int strengthReductionCount = 0;
int divisionOptCount = 0;
// 主要优化方法
bool processBasicBlock(BasicBlock* bb);
bool processInstruction(Instruction* inst);
// 代数优化方法
bool tryAlgebraicOptimization(Instruction* inst);
bool optimizeAddition(BinaryInst* inst);
bool optimizeSubtraction(BinaryInst* inst);
bool optimizeMultiplication(BinaryInst* inst);
bool optimizeDivision(BinaryInst* inst);
bool optimizeComparison(BinaryInst* inst);
bool optimizeLogical(BinaryInst* inst);
// 强度削弱方法
bool tryStrengthReduction(Instruction* inst);
bool reduceMultiplication(BinaryInst* inst);
bool reduceDivision(BinaryInst* inst);
bool reducePower(CallInst* inst);
// 复杂乘法强度削弱方法
bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant);
bool findOptimalShiftDecomposition(int constant, std::vector<int>& shifts);
Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts);
// 魔数乘法相关方法
MagicNumber computeMagicNumber(uint32_t divisor);
std::pair<int, int> computeMulhMagicNumbers(int divisor);
Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic);
Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor);
bool isPowerOfTwo(uint32_t n);
int log2OfPowerOfTwo(uint32_t n);
// 辅助方法
bool isConstantInt(Value* val, int& constVal);
bool isConstantInt(Value* val, uint32_t& constVal);
ConstantInteger* getConstantInt(int val);
bool hasOnlyLocalUses(Instruction* inst);
void replaceWithOptimized(Instruction* original, Value* replacement);
};
// 全局强度削弱优化遍类
class GlobalStrengthReduction : public OptimizationPass {
private:
IRBuilder* builder; // IR构建器用于创建新指令
public:
// 静态成员作为该遍的唯一ID
static void* ID;
// 构造函数接受IRBuilder参数
explicit GlobalStrengthReduction(IRBuilder* builder)
: OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {}
// 在函数上运行优化
bool runOnFunction(Function* func, AnalysisManager& AM) override;
// 返回该遍的唯一ID
void* getPassID() const override { return ID; }
// 声明分析依赖
void getAnalysisUsage(std::set<void*>& analysisDependencies,
std::set<void*>& analysisInvalidations) const override;
};
} // namespace sysy

View File

@@ -127,13 +127,6 @@ private:
*/
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
/**
* 计算用于除法优化的魔数和移位量
* @param divisor 除数
* @return {魔数, 移位量}
*/
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
/**
* 生成除法替换代码
* @param candidate 优化候选项

View File

@@ -107,6 +107,190 @@ public:
// 所以当AllocaInst的basetype是PointerType时一维数组或者是指向ArrayType的PointerType多位数组返回true
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
}
//该实现参考了libdivide的算法
static std::pair<int, int> computeMulhMagicNumbers(int divisor) {
if (DEBUG) {
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
}
if (divisor == 0) {
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
return {-1, -1};
}
// libdivide 常数
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
// 辅助函数:计算前导零个数
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
if (val == 0) return 32;
return __builtin_clz(val);
};
// 辅助函数64位除法返回32位商和余数
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
uint64_t dividend = ((uint64_t)high << 32) | low;
uint32_t quotient = dividend / divisor;
*rem = dividend % divisor;
return quotient;
};
if (DEBUG) {
std::cout << "[SR] Input divisor: " << divisor << std::endl;
}
// libdivide_internal_s32_gen 算法实现
int32_t d = divisor;
uint32_t ud = (uint32_t)d;
uint32_t absD = (d < 0) ? -ud : ud;
if (DEBUG) {
std::cout << "[SR] absD = " << absD << std::endl;
}
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
if (DEBUG) {
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
}
// 检查 absD 是否为2的幂
if ((absD & (absD - 1)) == 0) {
if (DEBUG) {
std::cout << "[SR] " << absD << " 是2的幂使用移位方法" << std::endl;
}
// 对于2的幂我们只使用移位不需要魔数
int shift = floor_log_2_d;
if (d < 0) shift |= 0x80; // 标记负数
if (DEBUG) {
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 对于我们的目的我们将在IR生成中以不同方式处理2的幂
// 返回特殊标记
return {0, shift};
}
if (DEBUG) {
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
}
// 非2的幂除数的魔数计算
uint8_t more;
uint32_t rem, proposed_m;
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
const uint32_t e = absD - rem;
if (DEBUG) {
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
}
// 确定是否需要"加法"版本
const bool branchfree = false; // 使用分支版本
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
// 这个幂次有效
more = (uint8_t)(floor_log_2_d - 1);
if (DEBUG) {
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
}
} else {
// 我们需要上升一个等级
proposed_m += proposed_m;
const uint32_t twice_rem = rem + rem;
if (twice_rem >= absD || twice_rem < rem) {
proposed_m += 1;
}
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
if (DEBUG) {
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
}
}
proposed_m += 1;
int32_t magic = (int32_t)proposed_m;
// 处理负除数
if (d < 0) {
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
if (!branchfree) {
magic = -magic;
}
if (DEBUG) {
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
}
}
// 为我们的IR生成提取移位量和标志
int shift = more & 0x3F; // 移除标志保留移位量位0-5
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
if (DEBUG) {
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
<< ", is_negative = " << is_negative << std::endl;
// Test the magic number using the correct libdivide algorithm
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
for (int test_val : test_values) {
int64_t quotient;
// 实现正确的libdivide算法
int64_t product = (int64_t)test_val * magic;
int64_t high_bits = product >> 32;
if (need_add) {
// ADD_MARKER情况移位前加上被除数
// 这是libdivide的关键洞察
high_bits += test_val;
quotient = high_bits >> shift;
} else {
// 正常情况:只是移位
quotient = high_bits >> shift;
}
// 符号修正这是libdivide有符号除法的关键部分
// 如果被除数为负商需要加1来匹配C语言的截断除法语义
if (test_val < 0) {
quotient += 1;
}
int expected = test_val / divisor;
bool correct = (quotient == expected);
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
<< " (expected " << expected << ") " << (correct ? "" : "") << std::endl;
}
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 返回魔数、移位量并在移位中编码ADD_MARKER标志
// 我们将使用移位的第6位表示ADD_MARKER第7位表示负数如果需要
int encoded_shift = shift;
if (need_add) {
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
if (DEBUG) {
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
}
}
return {magic, encoded_shift};
}
};
}// namespace sysy

View File

@@ -22,6 +22,7 @@ add_library(midend_lib STATIC
Pass/Optimize/LICM.cpp
Pass/Optimize/LoopStrengthReduction.cpp
Pass/Optimize/InductionVariableElimination.cpp
Pass/Optimize/GlobalStrengthReduction.cpp
Pass/Optimize/BuildCFG.cpp
Pass/Optimize/LargeArrayToGlobal.cpp
)

View File

@@ -652,6 +652,12 @@ void BasicBlock::print(std::ostream &os) const {
}
}
void Instruction::print(std::ostream &os) const {
this->getType()->print(os);
std::cerr << " %" << getName() << " ";
return;
}
void PhiInst::print(std::ostream &os) const {
printVarName(os, this);
os << " = " << getKindString() << " " << *getType() << " ";
@@ -1440,4 +1446,76 @@ void Argument::cleanup() {
uses.clear();
}
Function* Function::copy_func(){
std::unordered_map<Value*, Value*> valueMap;
// copy global
for (auto &gvalue : parent->getGlobals()) {
valueMap.emplace(gvalue.get(), gvalue.get());
}
// copy global const
for (auto &cvalue : parent->getConsts()) {
valueMap.emplace(cvalue.get(), cvalue.get());
}
// copy function
auto newFunc = new Function(parent, type, name + "_copy" + std::to_string(getcloneIndex()));
// copy arg
for(int i = 0; i < arguments.size(); i++) {
auto arg = arguments[i];
// Argument(paramActualTypes[i], function, i, paramNames[i]);
auto newarg = new Argument(arg->getType(), newFunc, i, arg->getName() + "_copy");
newFunc->insertArgument(newarg);
valueMap[arg] = newarg;
}
//
for (auto &bb : blocks) {
BasicBlock* copybb = newFunc->addBasicBlock();
valueMap.emplace(bb, copybb);
}
for(auto &bb : blocks) {
auto BB = bb.get();
auto copybb = dynamic_cast<BasicBlock*>(valueMap[BB]);
for(auto pred : BB->getPredecessors()) {
copybb->addPredecessor(dynamic_cast<BasicBlock*>(valueMap[pred]));
}
for(auto succ : BB->getSuccessors()) {
copybb->addSuccessor(dynamic_cast<BasicBlock*>(valueMap[succ]));
}
}
// if cant find, return itself
auto getValue = [&](Value* val) -> Value* {
if (val == nullptr) {
std::cerr << "getValue(nullptr)" << std::endl;
return nullptr;
}
if (dynamic_cast<ConstantValue*>(val)) return val;
if (auto iter = valueMap.find(val); iter != valueMap.end()) return iter->second;
return val;
};
std::set<BasicBlock*> visitedbb;
const auto copyBlock = [&](BasicBlock* bb) -> bool {
if (visitedbb.count(bb)) return true;
visitedbb.insert(bb);
auto bbCpy = dynamic_cast<BasicBlock*>(valueMap.at(bb));
for (auto &Inst : bb->getInstructions()) {
auto inst = Inst.get();
// inst->print(std::cerr);
// std::cerr << std::endl;
auto copyinst = inst->copy(getValue);
copyinst->setParent(bbCpy);
valueMap.emplace(inst, copyinst);
bbCpy->instructions.emplace_back(copyinst);
}
return false;
};
//dfs基本块将指令复制到新函数中
BasicBlock::bbdfs(getEntryBlock(), copyBlock);
return newFunc;
}
} // namespace sysy

View File

@@ -0,0 +1,897 @@
#include "GlobalStrengthReduction.h"
#include "SysYIROptUtils.h"
#include "IRBuilder.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <cmath>
extern int DEBUG;
namespace sysy {
// 全局强度削弱优化遍的静态 ID
void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID;
// ======================================================================
// GlobalStrengthReduction 类的实现
// ======================================================================
bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) {
if (func->getBasicBlocks().empty()) {
return false;
}
if (DEBUG) {
std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl;
}
bool changed = false;
GlobalStrengthReductionContext context(builder);
context.run(func, &AM, changed);
if (DEBUG) {
if (changed) {
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl;
} else {
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl;
}
std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl;
}
return changed;
}
void GlobalStrengthReduction::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
// 强度削弱依赖副作用分析来判断指令是否可以安全优化
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
// 强度削弱不会使分析失效,因为:
// - 只替换计算指令,不改变控制流
// - 不修改内存,不影响别名分析
// - 保持程序语义不变
// analysisInvalidations 保持为空
if (DEBUG) {
std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl;
}
}
// ======================================================================
// GlobalStrengthReductionContext 类的实现
// ======================================================================
void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) {
if (DEBUG) {
std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl;
}
// 获取分析结果
if (AM) {
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
if (DEBUG) {
if (sideEffectAnalysis) {
std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl;
} else {
std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl;
}
}
}
// 重置计数器
algebraicOptCount = 0;
strengthReductionCount = 0;
divisionOptCount = 0;
// 遍历所有基本块进行优化
for (auto &bb_ptr : func->getBasicBlocks()) {
if (processBasicBlock(bb_ptr.get())) {
changed = true;
}
}
if (DEBUG) {
std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl;
std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl;
std::cout << " Strength reductions: " << strengthReductionCount << std::endl;
std::cout << " Division optimizations: " << divisionOptCount << std::endl;
}
}
bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) {
bool changed = false;
if (DEBUG) {
std::cout << " Processing block: " << bb->getName() << std::endl;
}
// 收集需要处理的指令(避免迭代器失效)
std::vector<Instruction*> instructions;
for (auto &inst_ptr : bb->getInstructions()) {
instructions.push_back(inst_ptr.get());
}
// 处理每条指令
for (auto inst : instructions) {
if (processInstruction(inst)) {
changed = true;
}
}
return changed;
}
bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) {
if (DEBUG) {
std::cout << " Processing instruction: " << inst->getName() << std::endl;
}
// 先尝试代数优化
if (tryAlgebraicOptimization(inst)) {
algebraicOptCount++;
return true;
}
// 再尝试强度削弱
if (tryStrengthReduction(inst)) {
strengthReductionCount++;
return true;
}
return false;
}
// ======================================================================
// 代数优化方法
// ======================================================================
bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) {
auto binary = dynamic_cast<BinaryInst*>(inst);
if (!binary) {
return false;
}
switch (binary->getKind()) {
case Instruction::kAdd:
return optimizeAddition(binary);
case Instruction::kSub:
return optimizeSubtraction(binary);
case Instruction::kMul:
return optimizeMultiplication(binary);
case Instruction::kDiv:
return optimizeDivision(binary);
case Instruction::kICmpEQ:
case Instruction::kICmpNE:
case Instruction::kICmpLT:
case Instruction::kICmpGT:
case Instruction::kICmpLE:
case Instruction::kICmpGE:
return optimizeComparison(binary);
case Instruction::kAnd:
case Instruction::kOr:
return optimizeLogical(binary);
default:
return false;
}
}
bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x + 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// 0 + x = x
if (isConstantInt(lhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl;
}
replaceWithOptimized(inst, rhs);
return true;
}
// x + (-y) = x - y
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
if (rhsInst->getKind() == Instruction::kNeg) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl;
}
// 创建减法指令
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto subInst = builder->createSubInst(lhs, rhsInst->getOperand());
replaceWithOptimized(inst, subInst);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x - 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x - x = 0 (如果x没有副作用)
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x - (-y) = x + y
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
if (rhsInst->getKind() == Instruction::kNeg) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto addInst = builder->createAddInst(lhs, rhsInst->getOperand());
replaceWithOptimized(inst, addInst);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x * 0 = 0
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// 0 * x = 0
if (isConstantInt(lhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x * 1 = x
if (isConstantInt(rhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// 1 * x = x
if (isConstantInt(lhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl;
}
replaceWithOptimized(inst, rhs);
return true;
}
// x * (-1) = -x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto negInst = builder->createNegInst(lhs);
replaceWithOptimized(inst, negInst);
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// x / 1 = x
if (isConstantInt(rhs, constVal) && constVal == 1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x / (-1) = -x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto negInst = builder->createNegInst(lhs);
replaceWithOptimized(inst, negInst);
return true;
}
// x / x = 1 (如果x != 0且没有副作用)
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(1));
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
// x == x = true (如果x没有副作用)
if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs &&
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(1));
return true;
}
// x != x = false (如果x没有副作用)
if (inst->getKind() == Instruction::kICmpNE && lhs == rhs &&
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
return false;
}
bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
if (inst->getKind() == Instruction::kAnd) {
// x && 0 = 0
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl;
}
replaceWithOptimized(inst, getConstantInt(0));
return true;
}
// x && -1 = x
if (isConstantInt(rhs, constVal) && constVal == -1) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x && x = x
if (lhs == rhs) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
} else if (inst->getKind() == Instruction::kOr) {
// x || 0 = x
if (isConstantInt(rhs, constVal) && constVal == 0) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
// x || x = x
if (lhs == rhs) {
if (DEBUG) {
std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl;
}
replaceWithOptimized(inst, lhs);
return true;
}
}
return false;
}
// ======================================================================
// 强度削弱方法
// ======================================================================
bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) {
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
switch (binary->getKind()) {
case Instruction::kMul:
return reduceMultiplication(binary);
case Instruction::kDiv:
return reduceDivision(binary);
default:
return false;
}
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
return reducePower(call);
}
return false;
}
bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
int constVal;
// 尝试右操作数为常数
Value* variable = lhs;
if (isConstantInt(rhs, constVal) && constVal > 0) {
return tryComplexMultiplication(inst, variable, constVal);
}
// 尝试左操作数为常数
if (isConstantInt(lhs, constVal) && constVal > 0) {
variable = rhs;
return tryComplexMultiplication(inst, variable, constVal);
}
return false;
}
bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) {
// 首先检查是否为2的幂使用简单位移
if (isPowerOfTwo(constant)) {
int shiftAmount = log2OfPowerOfTwo(constant);
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x * " << constant << " -> x << " << shiftAmount << std::endl;
}
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount));
replaceWithOptimized(inst, shiftInst);
return true;
}
// 尝试分解为位移和加法的组合
std::vector<int> shifts;
if (findOptimalShiftDecomposition(constant, shifts)) {
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl;
}
Value* result = createShiftDecomposition(inst, variable, shifts);
if (result) {
replaceWithOptimized(inst, result);
return true;
}
}
return false;
}
bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector<int>& shifts) {
shifts.clear();
// 常见的有效分解模式
switch (constant) {
case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x
shifts = {1, 0};
return true;
case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x
shifts = {2, 0};
return true;
case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1)
shifts = {2, 1};
return true;
case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x
shifts = {2, 1, 0};
return true;
case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x
shifts = {3, 0};
return true;
case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1)
shifts = {3, 1};
return true;
case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x
shifts = {3, 1, 0};
return true;
case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2)
shifts = {3, 2};
return true;
case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x
shifts = {3, 2, 0};
return true;
case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1)
shifts = {3, 2, 1};
return true;
case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x
shifts = {3, 2, 1, 0};
return true;
case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x
shifts = {4, 0};
return true;
case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1)
shifts = {4, 1};
return true;
case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2)
shifts = {4, 2};
return true;
case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3)
shifts = {4, 3};
return true;
case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x
shifts = {4, 3, 0};
return true;
case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2)
shifts = {6, 5, 2};
return true;
}
// 通用二进制分解最多4个项避免过度复杂化
if (constant > 0 && constant < 256) {
std::vector<int> binaryShifts;
int temp = constant;
int bit = 0;
while (temp > 0 && binaryShifts.size() < 4) {
if (temp & 1) {
binaryShifts.push_back(bit);
}
temp >>= 1;
bit++;
}
// 只有当项数不超过3个时才使用二进制分解比直接乘法更有效
if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) {
shifts = binaryShifts;
return true;
}
}
return false;
}
Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts) {
if (shifts.empty()) return nullptr;
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
Value* result = nullptr;
for (int shift : shifts) {
Value* term;
if (shift == 0) {
// 0位移就是原变量
term = variable;
} else {
// 创建位移指令
term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift));
}
if (result == nullptr) {
result = term;
} else {
// 累加到结果中
result = builder->createAddInst(result, term);
}
}
return result;
}
bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
Value *lhs = inst->getLhs();
Value *rhs = inst->getRhs();
uint32_t constVal;
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
int shiftAmount = log2OfPowerOfTwo(constVal);
// 有符号除法校正:(x + (x >> 31) & mask) >> k
int maskValue = constVal - 1;
// x >> 31 (算术右移获取符号位)
Value* signShift = ConstantInteger::get(31);
Value* signBits = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
lhs,
signShift
);
// (x >> 31) & mask
Value* mask = ConstantInteger::get(maskValue);
Value* correction = builder->createBinaryInst(
Instruction::Kind::kAnd,
lhs->getType(),
signBits,
mask
);
// x + correction
Value* corrected = builder->createAddInst(lhs, correction);
// (x + correction) >> k
Value* divShift = ConstantInteger::get(shiftAmount);
Value* shiftInst = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
lhs->getType(),
corrected,
divShift
);
if (DEBUG) {
std::cout << " StrengthReduction: " << inst->getName()
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
}
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
replaceWithOptimized(inst, shiftInst);
strengthReductionCount++;
return true;
}
// x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法)
if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) {
// auto magicPair = computeMulhMagicNumbers(static_cast<int>(constVal));
Value* magicResult = createMagicDivisionLibdivide(inst, static_cast<int>(constVal));
replaceWithOptimized(inst, magicResult);
divisionOptCount++;
return true;
}
return false;
}
bool GlobalStrengthReductionContext::reducePower(CallInst *inst) {
// 检查是否是pow函数调用
Function* callee = inst->getCallee();
if (!callee || callee->getName() != "pow") {
return false;
}
// pow(x, 2) = x * x
if (inst->getNumOperands() >= 2) {
int exponent;
if (isConstantInt(inst->getOperand(1), exponent)) {
if (exponent == 2) {
if (DEBUG) {
std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl;
}
Value* base = inst->getOperand(0);
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
auto mulInst = builder->createMulInst(base, base);
replaceWithOptimized(inst, mulInst);
strengthReductionCount++;
return true;
} else if (exponent >= 3 && exponent <= 8) {
// 对于小的指数,展开为连续乘法
if (DEBUG) {
std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl;
}
Value* base = inst->getOperand(0);
Value* result = base;
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
for (int i = 1; i < exponent; i++) {
result = builder->createMulInst(result, base);
}
replaceWithOptimized(inst, result);
strengthReductionCount++;
return true;
}
}
}
return false;
}
Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) {
builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst));
// 使用mulh指令优化任意常数除法
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor);
// 检查是否无法优化magic == -1, shift == -1 表示失败)
if (magic == -1 && shift == -1) {
if (DEBUG) {
std::cout << "[SR] Cannot optimize division by " << divisor
<< ", keeping original division" << std::endl;
}
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
return nullptr;
}
// 2的幂次方除法可以用移位优化但这不是魔数法的情况这种情况应该不会被分类到这里但是还是做一个保护措施
if ((divisor & (divisor - 1)) == 0 && divisor > 0) {
// 是2的幂次方可以用移位
int shift_amount = 0;
int temp = divisor;
while (temp > 1) {
temp >>= 1;
shift_amount++;
}
Value* shiftConstant = ConstantInteger::get(shift_amount);
// 对于有符号除法,需要先加上除数-1然后再移位为了正确处理负数舍入
Value* divisor_minus_1 = ConstantInteger::get(divisor - 1);
Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1);
return builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
divInst->getOperand(0)->getType(),
adjusted,
shiftConstant
);
}
// 创建魔数常量
// 检查魔数是否能放入32位如果不能则不进行优化
if (magic > INT32_MAX || magic < INT32_MIN) {
if (DEBUG) {
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
}
return nullptr; // 无法优化,保持原始除法
}
Value* magicConstant = ConstantInteger::get((int32_t)magic);
// 检查是否需要ADD_MARKER处理加法调整
bool needAdd = (shift & 0x40) != 0;
int actualShift = shift & 0x3F; // 提取真实的移位量
if (DEBUG) {
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
<< ", actualShift=" << actualShift << std::endl;
}
// 执行高位乘法mulh(x, magic)
Value* mulhResult = builder->createBinaryInst(
Instruction::Kind::kMulh, // 高位乘法
divInst->getOperand(0)->getType(),
divInst->getOperand(0),
magicConstant
);
if (needAdd) {
// ADD_MARKER 情况:需要在移位前加上被除数
// 这对应于 libdivide 的加法调整算法
if (DEBUG) {
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
}
mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0));
}
if (actualShift > 0) {
// 如果需要额外移位
Value* shiftConstant = ConstantInteger::get(actualShift);
mulhResult = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
divInst->getOperand(0)->getType(),
mulhResult,
shiftConstant
);
}
// 标准的有符号除法符号修正如果被除数为负商需要加1
// 这对所有有符号除法都需要,不管是否可能有负数
Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0));
// 将i1转换为i32负数时为1非负数时为0 ICmpLTInst的结果会默认转化为32位
mulhResult = builder->createAddInst(mulhResult, isNegative);
return mulhResult;
}
// ======================================================================
// 辅助方法
// ======================================================================
bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) {
return n > 0 && (n & (n - 1)) == 0;
}
int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) {
int result = 0;
while (n > 1) {
n >>= 1;
result++;
}
return result;
}
bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) {
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
constVal = std::get<int>(constInt->getVal());
return true;
}
return false;
}
bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) {
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
int signedVal = std::get<int>(constInt->getVal());
if (signedVal >= 0) {
constVal = static_cast<uint32_t>(signedVal);
return true;
}
}
return false;
}
ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) {
return ConstantInteger::get(val);
}
bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) {
if (!inst) return true;
// 简单检查:如果指令没有副作用,则认为是本地的
if (sideEffectAnalysis) {
auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst);
return sideEffect.type == SideEffectType::NO_SIDE_EFFECT;
}
// 没有副作用分析时,保守处理
return !inst->isCall() && !inst->isStore() && !inst->isLoad();
}
void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) {
if (DEBUG) {
std::cout << " Replacing " << original->getName()
<< " with " << replacement->getName() << std::endl;
}
original->replaceAllUsesWith(replacement);
// 如果替换值是新创建的指令,确保它有合适的名字
// if (auto replInst = dynamic_cast<Instruction*>(replacement)) {
// if (replInst->getName().empty()) {
// replInst->setName(original->getName() + "_opt");
// }
// }
// 删除原指令,让调用者处理
SysYIROptUtils::usedelete(original);
}
} // namespace sysy

View File

@@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
return hasNegativePotential;
}
//该实现参考了libdivide的算法
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
if (DEBUG) {
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
}
if (divisor == 0) {
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
return {-1, -1};
}
// libdivide 常数
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
// 辅助函数:计算前导零个数
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
if (val == 0) return 32;
return __builtin_clz(val);
};
// 辅助函数64位除法返回32位商和余数
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
uint64_t dividend = ((uint64_t)high << 32) | low;
uint32_t quotient = dividend / divisor;
*rem = dividend % divisor;
return quotient;
};
if (DEBUG) {
std::cout << "[SR] Input divisor: " << divisor << std::endl;
}
// libdivide_internal_s32_gen 算法实现
int32_t d = divisor;
uint32_t ud = (uint32_t)d;
uint32_t absD = (d < 0) ? -ud : ud;
if (DEBUG) {
std::cout << "[SR] absD = " << absD << std::endl;
}
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
if (DEBUG) {
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
}
// 检查 absD 是否为2的幂
if ((absD & (absD - 1)) == 0) {
if (DEBUG) {
std::cout << "[SR] " << absD << " 是2的幂使用移位方法" << std::endl;
}
// 对于2的幂我们只使用移位不需要魔数
int shift = floor_log_2_d;
if (d < 0) shift |= 0x80; // 标记负数
if (DEBUG) {
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 对于我们的目的我们将在IR生成中以不同方式处理2的幂
// 返回特殊标记
return {0, shift};
}
if (DEBUG) {
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
}
// 非2的幂除数的魔数计算
uint8_t more;
uint32_t rem, proposed_m;
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
const uint32_t e = absD - rem;
if (DEBUG) {
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
}
// 确定是否需要"加法"版本
const bool branchfree = false; // 使用分支版本
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
// 这个幂次有效
more = (uint8_t)(floor_log_2_d - 1);
if (DEBUG) {
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
}
} else {
// 我们需要上升一个等级
proposed_m += proposed_m;
const uint32_t twice_rem = rem + rem;
if (twice_rem >= absD || twice_rem < rem) {
proposed_m += 1;
}
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
if (DEBUG) {
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
}
}
proposed_m += 1;
int32_t magic = (int32_t)proposed_m;
// 处理负除数
if (d < 0) {
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
if (!branchfree) {
magic = -magic;
}
if (DEBUG) {
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
}
}
// 为我们的IR生成提取移位量和标志
int shift = more & 0x3F; // 移除标志保留移位量位0-5
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
if (DEBUG) {
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
<< ", is_negative = " << is_negative << std::endl;
// Test the magic number using the correct libdivide algorithm
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
for (int test_val : test_values) {
int64_t quotient;
// 实现正确的libdivide算法
int64_t product = (int64_t)test_val * magic;
int64_t high_bits = product >> 32;
if (need_add) {
// ADD_MARKER情况移位前加上被除数
// 这是libdivide的关键洞察
high_bits += test_val;
quotient = high_bits >> shift;
} else {
// 正常情况:只是移位
quotient = high_bits >> shift;
}
// 符号修正这是libdivide有符号除法的关键部分
// 如果被除数为负商需要加1来匹配C语言的截断除法语义
if (test_val < 0) {
quotient += 1;
}
int expected = test_val / divisor;
bool correct = (quotient == expected);
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
<< " (expected " << expected << ") " << (correct ? "" : "") << std::endl;
}
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 返回魔数、移位量并在移位中编码ADD_MARKER标志
// 我们将使用移位的第6位表示ADD_MARKER第7位表示负数如果需要
int encoded_shift = shift;
if (need_add) {
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
if (DEBUG) {
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
}
}
return {magic, encoded_shift};
}
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
if (F->getBasicBlocks().empty()) {
@@ -1018,7 +837,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
IRBuilder* builder
) const {
// 使用mulh指令优化任意常数除法
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier);
// 检查是否无法优化magic == -1, shift == -1 表示失败)
if (magic == -1 && shift == -1) {

View File

@@ -18,6 +18,7 @@
#include "LICM.h"
#include "LoopStrengthReduction.h"
#include "InductionVariableElimination.h"
#include "GlobalStrengthReduction.h"
#include "Pass.h"
#include <iostream>
#include <queue>
@@ -77,6 +78,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
registerOptimizationPass<LICM>(builderIR);
registerOptimizationPass<LoopStrengthReduction>(builderIR);
registerOptimizationPass<InductionVariableElimination>();
registerOptimizationPass<GlobalStrengthReduction>(builderIR);
registerOptimizationPass<Reg2Mem>(builderIR);
registerOptimizationPass<SCCP>(builderIR);
@@ -179,6 +182,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
printPasses();
}
// 全局强度削弱优化,包括代数优化和魔数除法
this->clearPasses();
this->addPass(&GlobalStrengthReduction::ID);
this->run();
if(DEBUG) {
std::cout << "=== IR After Global Strength Reduction Optimizations ===\n";
printPasses();
}
// this->clearPasses();
// this->addPass(&Reg2Mem::ID);
// this->run();