Compare commits
4 Commits
midend-Loo
...
midend-Fun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97d83d733e | ||
|
|
ad74e435ba | ||
|
|
5c34cbc7b8 | ||
|
|
c9a0c700e1 |
@@ -652,6 +652,13 @@ public:
|
|||||||
} ///< 移除指定位置的指令
|
} ///< 移除指定位置的指令
|
||||||
iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block);
|
iterator moveInst(iterator sourcePos, iterator targetPos, BasicBlock *block);
|
||||||
|
|
||||||
|
static void bbdfs(BasicBlock* bb, std::function<bool(BasicBlock*)> func) {
|
||||||
|
if (func(bb))
|
||||||
|
return;
|
||||||
|
for (auto succ : bb->getSuccessors())
|
||||||
|
bbdfs(succ, func);
|
||||||
|
}
|
||||||
|
|
||||||
/// 清理基本块中的所有使用关系
|
/// 清理基本块中的所有使用关系
|
||||||
void cleanup();
|
void cleanup();
|
||||||
|
|
||||||
@@ -767,7 +774,7 @@ protected:
|
|||||||
: User(type, name), kind(kind), parent(parent) {}
|
: User(type, name), kind(kind), parent(parent) {}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
virtual Instruction* copy(std::function<Value*(Value*)> getValue) const = 0;
|
||||||
public:
|
public:
|
||||||
Kind getKind() const { return kind; }
|
Kind getKind() const { return kind; }
|
||||||
std::string getKindString() const{
|
std::string getKindString() const{
|
||||||
@@ -964,6 +971,18 @@ class PhiInst : public Instruction {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
PhiInst(Type *type,
|
||||||
|
const std::vector<std::pair<BasicBlock*, Value*>> IncomingValues,
|
||||||
|
BasicBlock *parent = nullptr,
|
||||||
|
const std::string &name = "")
|
||||||
|
: Instruction(Kind::kPhi, type, parent, name), vsize(IncomingValues.size()) {
|
||||||
|
for(size_t i = 0; i < vsize; ++i) {
|
||||||
|
addOperand(IncomingValues[i].first);
|
||||||
|
addOperand(IncomingValues[i].second);
|
||||||
|
}
|
||||||
|
refreshMap(); ///< 刷新块到值的映射关系
|
||||||
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量
|
unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量
|
||||||
Value *getIncomingValue(unsigned Idx) const { return getOperand(Idx * 2); } ///< 获取指定位置的传入值
|
Value *getIncomingValue(unsigned Idx) const { return getOperand(Idx * 2); } ///< 获取指定位置的传入值
|
||||||
@@ -1013,6 +1032,7 @@ class PhiInst : public Instruction {
|
|||||||
} ///< 刷新块到值的映射关系
|
} ///< 刷新块到值的映射关系
|
||||||
auto getValues() { return make_range(std::next(operand_begin()), operand_end()); }
|
auto getValues() { return make_range(std::next(operand_begin()), operand_end()); }
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
@@ -1047,6 +1067,7 @@ protected:
|
|||||||
public:
|
public:
|
||||||
Value* getOperand() const { return User::getOperand(0); }
|
Value* getOperand() const { return User::getOperand(0); }
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class UnaryInst
|
}; // class UnaryInst
|
||||||
|
|
||||||
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
|
//! Binary instruction, e.g., arithmatic, relation, logic, etc.
|
||||||
@@ -1126,6 +1147,7 @@ public:
|
|||||||
return new BinaryInst(kind, type, lhs, rhs, parent, name);
|
return new BinaryInst(kind, type, lhs, rhs, parent, name);
|
||||||
}
|
}
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class BinaryInst
|
}; // class BinaryInst
|
||||||
|
|
||||||
//! The return statement
|
//! The return statement
|
||||||
@@ -1217,6 +1239,7 @@ public:
|
|||||||
return succs;
|
return succs;
|
||||||
}
|
}
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class CondBrInst
|
}; // class CondBrInst
|
||||||
|
|
||||||
class UnreachableInst : public Instruction {
|
class UnreachableInst : public Instruction {
|
||||||
@@ -1225,6 +1248,7 @@ public:
|
|||||||
explicit UnreachableInst(const std::string& name, BasicBlock *parent = nullptr)
|
explicit UnreachableInst(const std::string& name, BasicBlock *parent = nullptr)
|
||||||
: Instruction(kUnreachable, Type::getVoidType(), parent, "") {}
|
: Instruction(kUnreachable, Type::getVoidType(), parent, "") {}
|
||||||
void print(std::ostream& os) const { os << "unreachable"; }
|
void print(std::ostream& os) const { os << "unreachable"; }
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
//! Allocate memory for stack variables, used for non-global variable declartion
|
//! Allocate memory for stack variables, used for non-global variable declartion
|
||||||
@@ -1243,6 +1267,7 @@ public:
|
|||||||
return getType()->as<PointerType>()->getBaseType();
|
return getType()->as<PointerType>()->getBaseType();
|
||||||
} ///< 获取分配的类型
|
} ///< 获取分配的类型
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class AllocaInst
|
}; // class AllocaInst
|
||||||
|
|
||||||
|
|
||||||
@@ -1281,6 +1306,7 @@ public:
|
|||||||
return new GetElementPtrInst(resultType, basePointer, indices, parent, name);
|
return new GetElementPtrInst(resultType, basePointer, indices, parent, name);
|
||||||
}
|
}
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
//! Load a value from memory address specified by a pointer value
|
//! Load a value from memory address specified by a pointer value
|
||||||
@@ -1299,6 +1325,7 @@ protected:
|
|||||||
public:
|
public:
|
||||||
Value* getPointer() const { return getOperand(0); }
|
Value* getPointer() const { return getOperand(0); }
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class LoadInst
|
}; // class LoadInst
|
||||||
|
|
||||||
//! Store a value to memory address specified by a pointer value
|
//! Store a value to memory address specified by a pointer value
|
||||||
@@ -1318,6 +1345,7 @@ public:
|
|||||||
Value* getValue() const { return getOperand(0); }
|
Value* getValue() const { return getOperand(0); }
|
||||||
Value* getPointer() const { return getOperand(1); }
|
Value* getPointer() const { return getOperand(1); }
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
}; // class StoreInst
|
}; // class StoreInst
|
||||||
|
|
||||||
//! Memset instruction
|
//! Memset instruction
|
||||||
@@ -1348,6 +1376,7 @@ public:
|
|||||||
Value* getSize() const { return getOperand(2); }
|
Value* getSize() const { return getOperand(2); }
|
||||||
Value* getValue() const { return getOperand(3); }
|
Value* getValue() const { return getOperand(3); }
|
||||||
void print(std::ostream& os) const override;
|
void print(std::ostream& os) const override;
|
||||||
|
Instruction* copy(std::function<Value*(Value*)> getValue) const override;
|
||||||
};
|
};
|
||||||
|
|
||||||
class GlobalValue;
|
class GlobalValue;
|
||||||
@@ -1385,19 +1414,11 @@ protected:
|
|||||||
public:
|
public:
|
||||||
using block_list = std::list<std::unique_ptr<BasicBlock>>;
|
using block_list = std::list<std::unique_ptr<BasicBlock>>;
|
||||||
using arg_list = std::vector<Argument *>;
|
using arg_list = std::vector<Argument *>;
|
||||||
enum FunctionAttribute : uint64_t {
|
|
||||||
PlaceHolder = 0x0UL,
|
|
||||||
Pure = 0x1UL << 0,
|
|
||||||
SelfRecursive = 0x1UL << 1,
|
|
||||||
SideEffect = 0x1UL << 2,
|
|
||||||
NoPureCauseMemRead = 0x1UL << 3
|
|
||||||
};
|
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Module *parent; ///< 函数的父模块
|
Module *parent; ///< 函数的父模块
|
||||||
block_list blocks; ///< 函数包含的基本块列表
|
block_list blocks; ///< 函数包含的基本块列表
|
||||||
arg_list arguments; ///< 函数参数列表
|
arg_list arguments; ///< 函数参数列表
|
||||||
FunctionAttribute attribute = PlaceHolder; ///< 函数属性
|
|
||||||
std::set<Function *> callees; ///< 函数调用的函数集合
|
std::set<Function *> callees; ///< 函数调用的函数集合
|
||||||
public:
|
public:
|
||||||
static unsigned getcloneIndex() {
|
static unsigned getcloneIndex() {
|
||||||
@@ -1405,17 +1426,12 @@ protected:
|
|||||||
cloneIndex += 1;
|
cloneIndex += 1;
|
||||||
return cloneIndex - 1;
|
return cloneIndex - 1;
|
||||||
}
|
}
|
||||||
Function* clone(const std::string &suffix = "_" + std::to_string(getcloneIndex()) + "@") const;
|
Function* copy_func();
|
||||||
const std::set<Function *>& getCallees() { return callees; }
|
const std::set<Function *>& getCallees() { return callees; }
|
||||||
void addCallee(Function *callee) { callees.insert(callee); }
|
void addCallee(Function *callee) { callees.insert(callee); }
|
||||||
void removeCallee(Function *callee) { callees.erase(callee); }
|
void removeCallee(Function *callee) { callees.erase(callee); }
|
||||||
void clearCallees() { callees.clear(); }
|
void clearCallees() { callees.clear(); }
|
||||||
std::set<Function *> getCalleesWithNoExternalAndSelf();
|
std::set<Function *> getCalleesWithNoExternalAndSelf();
|
||||||
FunctionAttribute getAttribute() const { return attribute; }
|
|
||||||
void setAttribute(FunctionAttribute attr) {
|
|
||||||
attribute = static_cast<FunctionAttribute>(attribute | attr);
|
|
||||||
}
|
|
||||||
void clearAttribute() { attribute = PlaceHolder; }
|
|
||||||
Type* getReturnType() const { return getType()->as<FunctionType>()->getReturnType(); }
|
Type* getReturnType() const { return getType()->as<FunctionType>()->getReturnType(); }
|
||||||
auto getParamTypes() const { return getType()->as<FunctionType>()->getParamTypes(); }
|
auto getParamTypes() const { return getType()->as<FunctionType>()->getParamTypes(); }
|
||||||
auto getBasicBlocks() { return make_range(blocks); }
|
auto getBasicBlocks() { return make_range(blocks); }
|
||||||
|
|||||||
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
107
src/include/midend/Pass/Optimize/GlobalStrengthReduction.h
Normal file
@@ -0,0 +1,107 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "Pass.h"
|
||||||
|
#include "IR.h"
|
||||||
|
#include "SideEffectAnalysis.h"
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
#include <vector>
|
||||||
|
#include <cstdint>
|
||||||
|
|
||||||
|
namespace sysy {
|
||||||
|
|
||||||
|
// 魔数乘法结构,用于除法优化
|
||||||
|
struct MagicNumber {
|
||||||
|
uint32_t multiplier;
|
||||||
|
int shift;
|
||||||
|
bool needAdd;
|
||||||
|
|
||||||
|
MagicNumber(uint32_t m, int s, bool add = false)
|
||||||
|
: multiplier(m), shift(s), needAdd(add) {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// 全局强度削弱优化遍的核心逻辑封装类
|
||||||
|
class GlobalStrengthReductionContext {
|
||||||
|
public:
|
||||||
|
// 构造函数,接受IRBuilder参数
|
||||||
|
explicit GlobalStrengthReductionContext(IRBuilder* builder) : builder(builder) {}
|
||||||
|
|
||||||
|
// 运行优化的主要方法
|
||||||
|
void run(Function* func, AnalysisManager* AM, bool& changed);
|
||||||
|
|
||||||
|
private:
|
||||||
|
IRBuilder* builder; // IR构建器
|
||||||
|
|
||||||
|
// 分析结果
|
||||||
|
SideEffectAnalysisResult* sideEffectAnalysis = nullptr;
|
||||||
|
|
||||||
|
// 优化计数
|
||||||
|
int algebraicOptCount = 0;
|
||||||
|
int strengthReductionCount = 0;
|
||||||
|
int divisionOptCount = 0;
|
||||||
|
|
||||||
|
// 主要优化方法
|
||||||
|
bool processBasicBlock(BasicBlock* bb);
|
||||||
|
bool processInstruction(Instruction* inst);
|
||||||
|
|
||||||
|
// 代数优化方法
|
||||||
|
bool tryAlgebraicOptimization(Instruction* inst);
|
||||||
|
bool optimizeAddition(BinaryInst* inst);
|
||||||
|
bool optimizeSubtraction(BinaryInst* inst);
|
||||||
|
bool optimizeMultiplication(BinaryInst* inst);
|
||||||
|
bool optimizeDivision(BinaryInst* inst);
|
||||||
|
bool optimizeComparison(BinaryInst* inst);
|
||||||
|
bool optimizeLogical(BinaryInst* inst);
|
||||||
|
|
||||||
|
// 强度削弱方法
|
||||||
|
bool tryStrengthReduction(Instruction* inst);
|
||||||
|
bool reduceMultiplication(BinaryInst* inst);
|
||||||
|
bool reduceDivision(BinaryInst* inst);
|
||||||
|
bool reducePower(CallInst* inst);
|
||||||
|
|
||||||
|
// 复杂乘法强度削弱方法
|
||||||
|
bool tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant);
|
||||||
|
bool findOptimalShiftDecomposition(int constant, std::vector<int>& shifts);
|
||||||
|
Value* createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts);
|
||||||
|
|
||||||
|
// 魔数乘法相关方法
|
||||||
|
MagicNumber computeMagicNumber(uint32_t divisor);
|
||||||
|
std::pair<int, int> computeMulhMagicNumbers(int divisor);
|
||||||
|
Value* createMagicDivision(BinaryInst* divInst, uint32_t divisor, const MagicNumber& magic);
|
||||||
|
Value* createMagicDivisionLibdivide(BinaryInst* divInst, int divisor);
|
||||||
|
bool isPowerOfTwo(uint32_t n);
|
||||||
|
int log2OfPowerOfTwo(uint32_t n);
|
||||||
|
|
||||||
|
// 辅助方法
|
||||||
|
bool isConstantInt(Value* val, int& constVal);
|
||||||
|
bool isConstantInt(Value* val, uint32_t& constVal);
|
||||||
|
ConstantInteger* getConstantInt(int val);
|
||||||
|
bool hasOnlyLocalUses(Instruction* inst);
|
||||||
|
void replaceWithOptimized(Instruction* original, Value* replacement);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 全局强度削弱优化遍类
|
||||||
|
class GlobalStrengthReduction : public OptimizationPass {
|
||||||
|
private:
|
||||||
|
IRBuilder* builder; // IR构建器,用于创建新指令
|
||||||
|
|
||||||
|
public:
|
||||||
|
// 静态成员,作为该遍的唯一ID
|
||||||
|
static void* ID;
|
||||||
|
|
||||||
|
// 构造函数,接受IRBuilder参数
|
||||||
|
explicit GlobalStrengthReduction(IRBuilder* builder)
|
||||||
|
: OptimizationPass("GlobalStrengthReduction", Granularity::Function), builder(builder) {}
|
||||||
|
|
||||||
|
// 在函数上运行优化
|
||||||
|
bool runOnFunction(Function* func, AnalysisManager& AM) override;
|
||||||
|
|
||||||
|
// 返回该遍的唯一ID
|
||||||
|
void* getPassID() const override { return ID; }
|
||||||
|
|
||||||
|
// 声明分析依赖
|
||||||
|
void getAnalysisUsage(std::set<void*>& analysisDependencies,
|
||||||
|
std::set<void*>& analysisInvalidations) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace sysy
|
||||||
@@ -127,13 +127,6 @@ private:
|
|||||||
*/
|
*/
|
||||||
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
|
bool analyzeInductionVariableRange(const InductionVarInfo* ivInfo, Loop* loop) const;
|
||||||
|
|
||||||
/**
|
|
||||||
* 计算用于除法优化的魔数和移位量
|
|
||||||
* @param divisor 除数
|
|
||||||
* @return {魔数, 移位量}
|
|
||||||
*/
|
|
||||||
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生成除法替换代码
|
* 生成除法替换代码
|
||||||
* @param candidate 优化候选项
|
* @param candidate 优化候选项
|
||||||
|
|||||||
@@ -107,6 +107,190 @@ public:
|
|||||||
// 所以当AllocaInst的basetype是PointerType时(一维数组)或者是指向ArrayType的PointerType(多位数组)时,返回true
|
// 所以当AllocaInst的basetype是PointerType时(一维数组)或者是指向ArrayType的PointerType(多位数组)时,返回true
|
||||||
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
|
return aval && (baseType->isPointer() || baseType->as<PointerType>()->getBaseType()->isArray());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//该实现参考了libdivide的算法
|
||||||
|
static std::pair<int, int> computeMulhMagicNumbers(int divisor) {
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (divisor == 0) {
|
||||||
|
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
||||||
|
return {-1, -1};
|
||||||
|
}
|
||||||
|
|
||||||
|
// libdivide 常数
|
||||||
|
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
||||||
|
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
||||||
|
|
||||||
|
// 辅助函数:计算前导零个数
|
||||||
|
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
||||||
|
if (val == 0) return 32;
|
||||||
|
return __builtin_clz(val);
|
||||||
|
};
|
||||||
|
|
||||||
|
// 辅助函数:64位除法返回32位商和余数
|
||||||
|
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
||||||
|
uint64_t dividend = ((uint64_t)high << 32) | low;
|
||||||
|
uint32_t quotient = dividend / divisor;
|
||||||
|
*rem = dividend % divisor;
|
||||||
|
return quotient;
|
||||||
|
};
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// libdivide_internal_s32_gen 算法实现
|
||||||
|
int32_t d = divisor;
|
||||||
|
uint32_t ud = (uint32_t)d;
|
||||||
|
uint32_t absD = (d < 0) ? -ud : ud;
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] absD = " << absD << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查 absD 是否为2的幂
|
||||||
|
if ((absD & (absD - 1)) == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于2的幂,我们只使用移位,不需要魔数
|
||||||
|
int shift = floor_log_2_d;
|
||||||
|
if (d < 0) shift |= 0x80; // 标记负数
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
||||||
|
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
||||||
|
// 返回特殊标记
|
||||||
|
return {0, shift};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非2的幂除数的魔数计算
|
||||||
|
uint8_t more;
|
||||||
|
uint32_t rem, proposed_m;
|
||||||
|
|
||||||
|
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
||||||
|
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
||||||
|
const uint32_t e = absD - rem;
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确定是否需要"加法"版本
|
||||||
|
const bool branchfree = false; // 使用分支版本
|
||||||
|
|
||||||
|
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
||||||
|
// 这个幂次有效
|
||||||
|
more = (uint8_t)(floor_log_2_d - 1);
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 我们需要上升一个等级
|
||||||
|
proposed_m += proposed_m;
|
||||||
|
const uint32_t twice_rem = rem + rem;
|
||||||
|
if (twice_rem >= absD || twice_rem < rem) {
|
||||||
|
proposed_m += 1;
|
||||||
|
}
|
||||||
|
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
proposed_m += 1;
|
||||||
|
int32_t magic = (int32_t)proposed_m;
|
||||||
|
|
||||||
|
// 处理负除数
|
||||||
|
if (d < 0) {
|
||||||
|
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
||||||
|
if (!branchfree) {
|
||||||
|
magic = -magic;
|
||||||
|
}
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为我们的IR生成提取移位量和标志
|
||||||
|
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
||||||
|
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
||||||
|
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
||||||
|
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
||||||
|
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
||||||
|
<< ", is_negative = " << is_negative << std::endl;
|
||||||
|
|
||||||
|
// Test the magic number using the correct libdivide algorithm
|
||||||
|
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
||||||
|
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
||||||
|
|
||||||
|
for (int test_val : test_values) {
|
||||||
|
int64_t quotient;
|
||||||
|
|
||||||
|
// 实现正确的libdivide算法
|
||||||
|
int64_t product = (int64_t)test_val * magic;
|
||||||
|
int64_t high_bits = product >> 32;
|
||||||
|
|
||||||
|
if (need_add) {
|
||||||
|
// ADD_MARKER情况:移位前加上被除数
|
||||||
|
// 这是libdivide的关键洞察!
|
||||||
|
high_bits += test_val;
|
||||||
|
quotient = high_bits >> shift;
|
||||||
|
} else {
|
||||||
|
// 正常情况:只是移位
|
||||||
|
quotient = high_bits >> shift;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 符号修正:这是libdivide有符号除法的关键部分!
|
||||||
|
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
||||||
|
if (test_val < 0) {
|
||||||
|
quotient += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
int expected = test_val / divisor;
|
||||||
|
|
||||||
|
bool correct = (quotient == expected);
|
||||||
|
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
||||||
|
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
||||||
|
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
||||||
|
int encoded_shift = shift;
|
||||||
|
if (need_add) {
|
||||||
|
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {magic, encoded_shift};
|
||||||
|
}
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}// namespace sysy
|
}// namespace sysy
|
||||||
@@ -22,6 +22,7 @@ add_library(midend_lib STATIC
|
|||||||
Pass/Optimize/LICM.cpp
|
Pass/Optimize/LICM.cpp
|
||||||
Pass/Optimize/LoopStrengthReduction.cpp
|
Pass/Optimize/LoopStrengthReduction.cpp
|
||||||
Pass/Optimize/InductionVariableElimination.cpp
|
Pass/Optimize/InductionVariableElimination.cpp
|
||||||
|
Pass/Optimize/GlobalStrengthReduction.cpp
|
||||||
Pass/Optimize/BuildCFG.cpp
|
Pass/Optimize/BuildCFG.cpp
|
||||||
Pass/Optimize/LargeArrayToGlobal.cpp
|
Pass/Optimize/LargeArrayToGlobal.cpp
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -652,6 +652,12 @@ void BasicBlock::print(std::ostream &os) const {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void Instruction::print(std::ostream &os) const {
|
||||||
|
this->getType()->print(os);
|
||||||
|
std::cerr << " %" << getName() << " ";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
void PhiInst::print(std::ostream &os) const {
|
void PhiInst::print(std::ostream &os) const {
|
||||||
printVarName(os, this);
|
printVarName(os, this);
|
||||||
os << " = " << getKindString() << " " << *getType() << " ";
|
os << " = " << getKindString() << " " << *getType() << " ";
|
||||||
@@ -1440,4 +1446,76 @@ void Argument::cleanup() {
|
|||||||
uses.clear();
|
uses.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Function* Function::copy_func(){
|
||||||
|
std::unordered_map<Value*, Value*> valueMap;
|
||||||
|
// copy global
|
||||||
|
for (auto &gvalue : parent->getGlobals()) {
|
||||||
|
valueMap.emplace(gvalue.get(), gvalue.get());
|
||||||
|
}
|
||||||
|
// copy global const
|
||||||
|
for (auto &cvalue : parent->getConsts()) {
|
||||||
|
valueMap.emplace(cvalue.get(), cvalue.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// copy function
|
||||||
|
auto newFunc = new Function(parent, type, name + "_copy" + std::to_string(getcloneIndex()));
|
||||||
|
// copy arg
|
||||||
|
for(int i = 0; i < arguments.size(); i++) {
|
||||||
|
auto arg = arguments[i];
|
||||||
|
// Argument(paramActualTypes[i], function, i, paramNames[i]);
|
||||||
|
auto newarg = new Argument(arg->getType(), newFunc, i, arg->getName() + "_copy");
|
||||||
|
newFunc->insertArgument(newarg);
|
||||||
|
valueMap[arg] = newarg;
|
||||||
|
}
|
||||||
|
//
|
||||||
|
for (auto &bb : blocks) {
|
||||||
|
BasicBlock* copybb = newFunc->addBasicBlock();
|
||||||
|
valueMap.emplace(bb, copybb);
|
||||||
|
}
|
||||||
|
|
||||||
|
for(auto &bb : blocks) {
|
||||||
|
auto BB = bb.get();
|
||||||
|
auto copybb = dynamic_cast<BasicBlock*>(valueMap[BB]);
|
||||||
|
for(auto pred : BB->getPredecessors()) {
|
||||||
|
copybb->addPredecessor(dynamic_cast<BasicBlock*>(valueMap[pred]));
|
||||||
|
}
|
||||||
|
for(auto succ : BB->getSuccessors()) {
|
||||||
|
copybb->addSuccessor(dynamic_cast<BasicBlock*>(valueMap[succ]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// if cant find, return itself
|
||||||
|
auto getValue = [&](Value* val) -> Value* {
|
||||||
|
if (val == nullptr) {
|
||||||
|
std::cerr << "getValue(nullptr)" << std::endl;
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
if (dynamic_cast<ConstantValue*>(val)) return val;
|
||||||
|
if (auto iter = valueMap.find(val); iter != valueMap.end()) return iter->second;
|
||||||
|
return val;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::set<BasicBlock*> visitedbb;
|
||||||
|
|
||||||
|
const auto copyBlock = [&](BasicBlock* bb) -> bool {
|
||||||
|
if (visitedbb.count(bb)) return true;
|
||||||
|
visitedbb.insert(bb);
|
||||||
|
auto bbCpy = dynamic_cast<BasicBlock*>(valueMap.at(bb));
|
||||||
|
for (auto &Inst : bb->getInstructions()) {
|
||||||
|
auto inst = Inst.get();
|
||||||
|
// inst->print(std::cerr);
|
||||||
|
// std::cerr << std::endl;
|
||||||
|
auto copyinst = inst->copy(getValue);
|
||||||
|
copyinst->setParent(bbCpy);
|
||||||
|
valueMap.emplace(inst, copyinst);
|
||||||
|
bbCpy->instructions.emplace_back(copyinst);
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
//dfs基本块将指令复制到新函数中
|
||||||
|
BasicBlock::bbdfs(getEntryBlock(), copyBlock);
|
||||||
|
|
||||||
|
return newFunc;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace sysy
|
} // namespace sysy
|
||||||
|
|||||||
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
897
src/midend/Pass/Optimize/GlobalStrengthReduction.cpp
Normal file
@@ -0,0 +1,897 @@
|
|||||||
|
#include "GlobalStrengthReduction.h"
|
||||||
|
#include "SysYIROptUtils.h"
|
||||||
|
#include "IRBuilder.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <iostream>
|
||||||
|
#include <cmath>
|
||||||
|
|
||||||
|
extern int DEBUG;
|
||||||
|
|
||||||
|
namespace sysy {
|
||||||
|
|
||||||
|
// 全局强度削弱优化遍的静态 ID
|
||||||
|
void *GlobalStrengthReduction::ID = (void *)&GlobalStrengthReduction::ID;
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// GlobalStrengthReduction 类的实现
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
bool GlobalStrengthReduction::runOnFunction(Function *func, AnalysisManager &AM) {
|
||||||
|
if (func->getBasicBlocks().empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "\n=== Running GlobalStrengthReduction on function: " << func->getName() << " ===" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool changed = false;
|
||||||
|
GlobalStrengthReductionContext context(builder);
|
||||||
|
context.run(func, &AM, changed);
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
if (changed) {
|
||||||
|
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was modified" << std::endl;
|
||||||
|
} else {
|
||||||
|
std::cout << "GlobalStrengthReduction: Function " << func->getName() << " was not modified" << std::endl;
|
||||||
|
}
|
||||||
|
std::cout << "=== GlobalStrengthReduction completed for function: " << func->getName() << " ===" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
void GlobalStrengthReduction::getAnalysisUsage(std::set<void *> &analysisDependencies, std::set<void *> &analysisInvalidations) const {
|
||||||
|
// 强度削弱依赖副作用分析来判断指令是否可以安全优化
|
||||||
|
analysisDependencies.insert(&SysYSideEffectAnalysisPass::ID);
|
||||||
|
|
||||||
|
// 强度削弱不会使分析失效,因为:
|
||||||
|
// - 只替换计算指令,不改变控制流
|
||||||
|
// - 不修改内存,不影响别名分析
|
||||||
|
// - 保持程序语义不变
|
||||||
|
// analysisInvalidations 保持为空
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "GlobalStrengthReduction: Declared analysis dependencies (SideEffectAnalysis)" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// GlobalStrengthReductionContext 类的实现
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
void GlobalStrengthReductionContext::run(Function *func, AnalysisManager *AM, bool &changed) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Starting GlobalStrengthReduction analysis for function: " << func->getName() << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取分析结果
|
||||||
|
if (AM) {
|
||||||
|
sideEffectAnalysis = AM->getAnalysisResult<SideEffectAnalysisResult, SysYSideEffectAnalysisPass>();
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
if (sideEffectAnalysis) {
|
||||||
|
std::cout << " GlobalStrengthReduction: Using side effect analysis" << std::endl;
|
||||||
|
} else {
|
||||||
|
std::cout << " GlobalStrengthReduction: Warning - side effect analysis not available" << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重置计数器
|
||||||
|
algebraicOptCount = 0;
|
||||||
|
strengthReductionCount = 0;
|
||||||
|
divisionOptCount = 0;
|
||||||
|
|
||||||
|
// 遍历所有基本块进行优化
|
||||||
|
for (auto &bb_ptr : func->getBasicBlocks()) {
|
||||||
|
if (processBasicBlock(bb_ptr.get())) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " GlobalStrengthReduction completed for function: " << func->getName() << std::endl;
|
||||||
|
std::cout << " Algebraic optimizations: " << algebraicOptCount << std::endl;
|
||||||
|
std::cout << " Strength reductions: " << strengthReductionCount << std::endl;
|
||||||
|
std::cout << " Division optimizations: " << divisionOptCount << std::endl;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::processBasicBlock(BasicBlock *bb) {
|
||||||
|
bool changed = false;
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Processing block: " << bb->getName() << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 收集需要处理的指令(避免迭代器失效)
|
||||||
|
std::vector<Instruction*> instructions;
|
||||||
|
for (auto &inst_ptr : bb->getInstructions()) {
|
||||||
|
instructions.push_back(inst_ptr.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理每条指令
|
||||||
|
for (auto inst : instructions) {
|
||||||
|
if (processInstruction(inst)) {
|
||||||
|
changed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return changed;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::processInstruction(Instruction *inst) {
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Processing instruction: " << inst->getName() << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先尝试代数优化
|
||||||
|
if (tryAlgebraicOptimization(inst)) {
|
||||||
|
algebraicOptCount++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 再尝试强度削弱
|
||||||
|
if (tryStrengthReduction(inst)) {
|
||||||
|
strengthReductionCount++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// 代数优化方法
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::tryAlgebraicOptimization(Instruction *inst) {
|
||||||
|
auto binary = dynamic_cast<BinaryInst*>(inst);
|
||||||
|
if (!binary) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (binary->getKind()) {
|
||||||
|
case Instruction::kAdd:
|
||||||
|
return optimizeAddition(binary);
|
||||||
|
case Instruction::kSub:
|
||||||
|
return optimizeSubtraction(binary);
|
||||||
|
case Instruction::kMul:
|
||||||
|
return optimizeMultiplication(binary);
|
||||||
|
case Instruction::kDiv:
|
||||||
|
return optimizeDivision(binary);
|
||||||
|
case Instruction::kICmpEQ:
|
||||||
|
case Instruction::kICmpNE:
|
||||||
|
case Instruction::kICmpLT:
|
||||||
|
case Instruction::kICmpGT:
|
||||||
|
case Instruction::kICmpLE:
|
||||||
|
case Instruction::kICmpGE:
|
||||||
|
return optimizeComparison(binary);
|
||||||
|
case Instruction::kAnd:
|
||||||
|
case Instruction::kOr:
|
||||||
|
return optimizeLogical(binary);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeAddition(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
// x + 0 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x + 0 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0 + x = x
|
||||||
|
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = 0 + x -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, rhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x + (-y) = x - y
|
||||||
|
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||||
|
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x + (-y) -> x - y" << std::endl;
|
||||||
|
}
|
||||||
|
// 创建减法指令
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto subInst = builder->createSubInst(lhs, rhsInst->getOperand());
|
||||||
|
replaceWithOptimized(inst, subInst);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeSubtraction(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
// x - 0 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x - 0 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x - x = 0 (如果x没有副作用)
|
||||||
|
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x - x -> 0" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x - (-y) = x + y
|
||||||
|
if (auto rhsInst = dynamic_cast<UnaryInst*>(rhs)) {
|
||||||
|
if (rhsInst->getKind() == Instruction::kNeg) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x - (-y) -> x + y" << std::endl;
|
||||||
|
}
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto addInst = builder->createAddInst(lhs, rhsInst->getOperand());
|
||||||
|
replaceWithOptimized(inst, addInst);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeMultiplication(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
// x * 0 = 0
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x * 0 -> 0" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 0 * x = 0
|
||||||
|
if (isConstantInt(lhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = 0 * x -> 0" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x * 1 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x * 1 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 1 * x = x
|
||||||
|
if (isConstantInt(lhs, constVal) && constVal == 1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = 1 * x -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, rhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x * (-1) = -x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x * (-1) -> -x" << std::endl;
|
||||||
|
}
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto negInst = builder->createNegInst(lhs);
|
||||||
|
replaceWithOptimized(inst, negInst);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeDivision(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
// x / 1 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x / 1 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x / (-1) = -x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x / (-1) -> -x" << std::endl;
|
||||||
|
}
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto negInst = builder->createNegInst(lhs);
|
||||||
|
replaceWithOptimized(inst, negInst);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x / x = 1 (如果x != 0且没有副作用)
|
||||||
|
if (lhs == rhs && hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x / x -> 1" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(1));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeComparison(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
|
||||||
|
// x == x = true (如果x没有副作用)
|
||||||
|
if (inst->getKind() == Instruction::kICmpEQ && lhs == rhs &&
|
||||||
|
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x == x -> true" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(1));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x != x = false (如果x没有副作用)
|
||||||
|
if (inst->getKind() == Instruction::kICmpNE && lhs == rhs &&
|
||||||
|
hasOnlyLocalUses(dynamic_cast<Instruction*>(lhs))) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x != x -> false" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
if (inst->getKind() == Instruction::kAnd) {
|
||||||
|
// x && 0 = 0
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x && 0 -> 0" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, getConstantInt(0));
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x && -1 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == -1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x && x = x
|
||||||
|
if (lhs == rhs) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x && x -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} else if (inst->getKind() == Instruction::kOr) {
|
||||||
|
// x || 0 = x
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal == 0) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x || 0 -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x || x = x
|
||||||
|
if (lhs == rhs) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Algebraic: " << inst->getName() << " = x || x -> x" << std::endl;
|
||||||
|
}
|
||||||
|
replaceWithOptimized(inst, lhs);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// 强度削弱方法
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::tryStrengthReduction(Instruction *inst) {
|
||||||
|
if (auto binary = dynamic_cast<BinaryInst*>(inst)) {
|
||||||
|
switch (binary->getKind()) {
|
||||||
|
case Instruction::kMul:
|
||||||
|
return reduceMultiplication(binary);
|
||||||
|
case Instruction::kDiv:
|
||||||
|
return reduceDivision(binary);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||||
|
return reducePower(call);
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::reduceMultiplication(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
int constVal;
|
||||||
|
|
||||||
|
// 尝试右操作数为常数
|
||||||
|
Value* variable = lhs;
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal > 0) {
|
||||||
|
return tryComplexMultiplication(inst, variable, constVal);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试左操作数为常数
|
||||||
|
if (isConstantInt(lhs, constVal) && constVal > 0) {
|
||||||
|
variable = rhs;
|
||||||
|
return tryComplexMultiplication(inst, variable, constVal);
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::tryComplexMultiplication(BinaryInst* inst, Value* variable, int constant) {
|
||||||
|
// 首先检查是否为2的幂,使用简单位移
|
||||||
|
if (isPowerOfTwo(constant)) {
|
||||||
|
int shiftAmount = log2OfPowerOfTwo(constant);
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " StrengthReduction: " << inst->getName()
|
||||||
|
<< " = x * " << constant << " -> x << " << shiftAmount << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto shiftInst = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shiftAmount));
|
||||||
|
replaceWithOptimized(inst, shiftInst);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 尝试分解为位移和加法的组合
|
||||||
|
std::vector<int> shifts;
|
||||||
|
if (findOptimalShiftDecomposition(constant, shifts)) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " StrengthReduction: " << inst->getName()
|
||||||
|
<< " = x * " << constant << " -> shift decomposition with " << shifts.size() << " terms" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* result = createShiftDecomposition(inst, variable, shifts);
|
||||||
|
if (result) {
|
||||||
|
replaceWithOptimized(inst, result);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::findOptimalShiftDecomposition(int constant, std::vector<int>& shifts) {
|
||||||
|
shifts.clear();
|
||||||
|
|
||||||
|
// 常见的有效分解模式
|
||||||
|
switch (constant) {
|
||||||
|
case 3: // 3 = 2^1 + 2^0 -> (x << 1) + x
|
||||||
|
shifts = {1, 0};
|
||||||
|
return true;
|
||||||
|
case 5: // 5 = 2^2 + 2^0 -> (x << 2) + x
|
||||||
|
shifts = {2, 0};
|
||||||
|
return true;
|
||||||
|
case 6: // 6 = 2^2 + 2^1 -> (x << 2) + (x << 1)
|
||||||
|
shifts = {2, 1};
|
||||||
|
return true;
|
||||||
|
case 7: // 7 = 2^2 + 2^1 + 2^0 -> (x << 2) + (x << 1) + x
|
||||||
|
shifts = {2, 1, 0};
|
||||||
|
return true;
|
||||||
|
case 9: // 9 = 2^3 + 2^0 -> (x << 3) + x
|
||||||
|
shifts = {3, 0};
|
||||||
|
return true;
|
||||||
|
case 10: // 10 = 2^3 + 2^1 -> (x << 3) + (x << 1)
|
||||||
|
shifts = {3, 1};
|
||||||
|
return true;
|
||||||
|
case 11: // 11 = 2^3 + 2^1 + 2^0 -> (x << 3) + (x << 1) + x
|
||||||
|
shifts = {3, 1, 0};
|
||||||
|
return true;
|
||||||
|
case 12: // 12 = 2^3 + 2^2 -> (x << 3) + (x << 2)
|
||||||
|
shifts = {3, 2};
|
||||||
|
return true;
|
||||||
|
case 13: // 13 = 2^3 + 2^2 + 2^0 -> (x << 3) + (x << 2) + x
|
||||||
|
shifts = {3, 2, 0};
|
||||||
|
return true;
|
||||||
|
case 14: // 14 = 2^3 + 2^2 + 2^1 -> (x << 3) + (x << 2) + (x << 1)
|
||||||
|
shifts = {3, 2, 1};
|
||||||
|
return true;
|
||||||
|
case 15: // 15 = 2^3 + 2^2 + 2^1 + 2^0 -> (x << 3) + (x << 2) + (x << 1) + x
|
||||||
|
shifts = {3, 2, 1, 0};
|
||||||
|
return true;
|
||||||
|
case 17: // 17 = 2^4 + 2^0 -> (x << 4) + x
|
||||||
|
shifts = {4, 0};
|
||||||
|
return true;
|
||||||
|
case 18: // 18 = 2^4 + 2^1 -> (x << 4) + (x << 1)
|
||||||
|
shifts = {4, 1};
|
||||||
|
return true;
|
||||||
|
case 20: // 20 = 2^4 + 2^2 -> (x << 4) + (x << 2)
|
||||||
|
shifts = {4, 2};
|
||||||
|
return true;
|
||||||
|
case 24: // 24 = 2^4 + 2^3 -> (x << 4) + (x << 3)
|
||||||
|
shifts = {4, 3};
|
||||||
|
return true;
|
||||||
|
case 25: // 25 = 2^4 + 2^3 + 2^0 -> (x << 4) + (x << 3) + x
|
||||||
|
shifts = {4, 3, 0};
|
||||||
|
return true;
|
||||||
|
case 100: // 100 = 2^6 + 2^5 + 2^2 -> (x << 6) + (x << 5) + (x << 2)
|
||||||
|
shifts = {6, 5, 2};
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 通用二进制分解(最多4个项,避免过度复杂化)
|
||||||
|
if (constant > 0 && constant < 256) {
|
||||||
|
std::vector<int> binaryShifts;
|
||||||
|
int temp = constant;
|
||||||
|
int bit = 0;
|
||||||
|
|
||||||
|
while (temp > 0 && binaryShifts.size() < 4) {
|
||||||
|
if (temp & 1) {
|
||||||
|
binaryShifts.push_back(bit);
|
||||||
|
}
|
||||||
|
temp >>= 1;
|
||||||
|
bit++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只有当项数不超过3个时才使用二进制分解(比直接乘法更有效)
|
||||||
|
if (binaryShifts.size() <= 3 && binaryShifts.size() >= 2) {
|
||||||
|
shifts = binaryShifts;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* GlobalStrengthReductionContext::createShiftDecomposition(BinaryInst* inst, Value* variable, const std::vector<int>& shifts) {
|
||||||
|
if (shifts.empty()) return nullptr;
|
||||||
|
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
|
||||||
|
Value* result = nullptr;
|
||||||
|
|
||||||
|
for (int shift : shifts) {
|
||||||
|
Value* term;
|
||||||
|
if (shift == 0) {
|
||||||
|
// 0位移就是原变量
|
||||||
|
term = variable;
|
||||||
|
} else {
|
||||||
|
// 创建位移指令
|
||||||
|
term = builder->createBinaryInst(Instruction::kSll, Type::getIntType(), variable, getConstantInt(shift));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result == nullptr) {
|
||||||
|
result = term;
|
||||||
|
} else {
|
||||||
|
// 累加到结果中
|
||||||
|
result = builder->createAddInst(result, term);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) {
|
||||||
|
Value *lhs = inst->getLhs();
|
||||||
|
Value *rhs = inst->getRhs();
|
||||||
|
uint32_t constVal;
|
||||||
|
|
||||||
|
// x / 2^n = x >> n (对于无符号除法或已知为正数的情况)
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) {
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
int shiftAmount = log2OfPowerOfTwo(constVal);
|
||||||
|
// 有符号除法校正:(x + (x >> 31) & mask) >> k
|
||||||
|
int maskValue = constVal - 1;
|
||||||
|
|
||||||
|
// x >> 31 (算术右移获取符号位)
|
||||||
|
Value* signShift = ConstantInteger::get(31);
|
||||||
|
Value* signBits = builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kSra, // 算术右移
|
||||||
|
lhs->getType(),
|
||||||
|
lhs,
|
||||||
|
signShift
|
||||||
|
);
|
||||||
|
|
||||||
|
// (x >> 31) & mask
|
||||||
|
Value* mask = ConstantInteger::get(maskValue);
|
||||||
|
Value* correction = builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kAnd,
|
||||||
|
lhs->getType(),
|
||||||
|
signBits,
|
||||||
|
mask
|
||||||
|
);
|
||||||
|
|
||||||
|
// x + correction
|
||||||
|
Value* corrected = builder->createAddInst(lhs, correction);
|
||||||
|
|
||||||
|
// (x + correction) >> k
|
||||||
|
Value* divShift = ConstantInteger::get(shiftAmount);
|
||||||
|
Value* shiftInst = builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kSra, // 算术右移
|
||||||
|
lhs->getType(),
|
||||||
|
corrected,
|
||||||
|
divShift
|
||||||
|
);
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " StrengthReduction: " << inst->getName()
|
||||||
|
<< " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
// Value* divisor_minus_1 = ConstantInteger::get(constVal - 1);
|
||||||
|
// Value* adjusted = builder->createAddInst(lhs, divisor_minus_1);
|
||||||
|
// Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount));
|
||||||
|
replaceWithOptimized(inst, shiftInst);
|
||||||
|
strengthReductionCount++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// x / c = x * magic_number (魔数乘法优化 - 使用libdivide算法)
|
||||||
|
if (isConstantInt(rhs, constVal) && constVal > 1 && constVal != (uint32_t)(-1)) {
|
||||||
|
// auto magicPair = computeMulhMagicNumbers(static_cast<int>(constVal));
|
||||||
|
Value* magicResult = createMagicDivisionLibdivide(inst, static_cast<int>(constVal));
|
||||||
|
replaceWithOptimized(inst, magicResult);
|
||||||
|
divisionOptCount++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::reducePower(CallInst *inst) {
|
||||||
|
// 检查是否是pow函数调用
|
||||||
|
Function* callee = inst->getCallee();
|
||||||
|
if (!callee || callee->getName() != "pow") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// pow(x, 2) = x * x
|
||||||
|
if (inst->getNumOperands() >= 2) {
|
||||||
|
int exponent;
|
||||||
|
if (isConstantInt(inst->getOperand(1), exponent)) {
|
||||||
|
if (exponent == 2) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " StrengthReduction: pow(x, 2) -> x * x" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* base = inst->getOperand(0);
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
auto mulInst = builder->createMulInst(base, base);
|
||||||
|
replaceWithOptimized(inst, mulInst);
|
||||||
|
strengthReductionCount++;
|
||||||
|
return true;
|
||||||
|
} else if (exponent >= 3 && exponent <= 8) {
|
||||||
|
// 对于小的指数,展开为连续乘法
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " StrengthReduction: pow(x, " << exponent << ") -> repeated multiplication" << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* base = inst->getOperand(0);
|
||||||
|
Value* result = base;
|
||||||
|
builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst));
|
||||||
|
|
||||||
|
for (int i = 1; i < exponent; i++) {
|
||||||
|
result = builder->createMulInst(result, base);
|
||||||
|
}
|
||||||
|
|
||||||
|
replaceWithOptimized(inst, result);
|
||||||
|
strengthReductionCount++;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* GlobalStrengthReductionContext::createMagicDivisionLibdivide(BinaryInst* divInst, int divisor) {
|
||||||
|
builder->setPosition(divInst->getParent(), divInst->getParent()->findInstIterator(divInst));
|
||||||
|
// 使用mulh指令优化任意常数除法
|
||||||
|
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(divisor);
|
||||||
|
|
||||||
|
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||||
|
if (magic == -1 && shift == -1) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Cannot optimize division by " << divisor
|
||||||
|
<< ", keeping original division" << std::endl;
|
||||||
|
}
|
||||||
|
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2的幂次方除法可以用移位优化(但这不是魔数法的情况)这种情况应该不会被分类到这里但是还是做一个保护措施
|
||||||
|
if ((divisor & (divisor - 1)) == 0 && divisor > 0) {
|
||||||
|
// 是2的幂次方,可以用移位
|
||||||
|
int shift_amount = 0;
|
||||||
|
int temp = divisor;
|
||||||
|
while (temp > 1) {
|
||||||
|
temp >>= 1;
|
||||||
|
shift_amount++;
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* shiftConstant = ConstantInteger::get(shift_amount);
|
||||||
|
// 对于有符号除法,需要先加上除数-1然后再移位(为了正确处理负数舍入)
|
||||||
|
Value* divisor_minus_1 = ConstantInteger::get(divisor - 1);
|
||||||
|
Value* adjusted = builder->createAddInst(divInst->getOperand(0), divisor_minus_1);
|
||||||
|
return builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kSra, // 算术右移
|
||||||
|
divInst->getOperand(0)->getType(),
|
||||||
|
adjusted,
|
||||||
|
shiftConstant
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建魔数常量
|
||||||
|
// 检查魔数是否能放入32位,如果不能,则不进行优化
|
||||||
|
if (magic > INT32_MAX || magic < INT32_MIN) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
|
||||||
|
}
|
||||||
|
return nullptr; // 无法优化,保持原始除法
|
||||||
|
}
|
||||||
|
|
||||||
|
Value* magicConstant = ConstantInteger::get((int32_t)magic);
|
||||||
|
|
||||||
|
// 检查是否需要ADD_MARKER处理(加法调整)
|
||||||
|
bool needAdd = (shift & 0x40) != 0;
|
||||||
|
int actualShift = shift & 0x3F; // 提取真实的移位量
|
||||||
|
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
|
||||||
|
<< ", actualShift=" << actualShift << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 执行高位乘法:mulh(x, magic)
|
||||||
|
Value* mulhResult = builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kMulh, // 高位乘法
|
||||||
|
divInst->getOperand(0)->getType(),
|
||||||
|
divInst->getOperand(0),
|
||||||
|
magicConstant
|
||||||
|
);
|
||||||
|
|
||||||
|
if (needAdd) {
|
||||||
|
// ADD_MARKER 情况:需要在移位前加上被除数
|
||||||
|
// 这对应于 libdivide 的加法调整算法
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
|
||||||
|
}
|
||||||
|
mulhResult = builder->createAddInst(mulhResult, divInst->getOperand(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (actualShift > 0) {
|
||||||
|
// 如果需要额外移位
|
||||||
|
Value* shiftConstant = ConstantInteger::get(actualShift);
|
||||||
|
mulhResult = builder->createBinaryInst(
|
||||||
|
Instruction::Kind::kSra, // 算术右移
|
||||||
|
divInst->getOperand(0)->getType(),
|
||||||
|
mulhResult,
|
||||||
|
shiftConstant
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 标准的有符号除法符号修正:如果被除数为负,商需要加1
|
||||||
|
// 这对所有有符号除法都需要,不管是否可能有负数
|
||||||
|
Value* isNegative = builder->createICmpLTInst(divInst->getOperand(0), ConstantInteger::get(0));
|
||||||
|
// 将i1转换为i32:负数时为1,非负数时为0 ICmpLTInst的结果会默认转化为32位
|
||||||
|
mulhResult = builder->createAddInst(mulhResult, isNegative);
|
||||||
|
|
||||||
|
return mulhResult;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ======================================================================
|
||||||
|
// 辅助方法
|
||||||
|
// ======================================================================
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::isPowerOfTwo(uint32_t n) {
|
||||||
|
return n > 0 && (n & (n - 1)) == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int GlobalStrengthReductionContext::log2OfPowerOfTwo(uint32_t n) {
|
||||||
|
int result = 0;
|
||||||
|
while (n > 1) {
|
||||||
|
n >>= 1;
|
||||||
|
result++;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::isConstantInt(Value* val, int& constVal) {
|
||||||
|
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||||
|
constVal = std::get<int>(constInt->getVal());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::isConstantInt(Value* val, uint32_t& constVal) {
|
||||||
|
if (auto constInt = dynamic_cast<ConstantInteger*>(val)) {
|
||||||
|
int signedVal = std::get<int>(constInt->getVal());
|
||||||
|
if (signedVal >= 0) {
|
||||||
|
constVal = static_cast<uint32_t>(signedVal);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ConstantInteger* GlobalStrengthReductionContext::getConstantInt(int val) {
|
||||||
|
return ConstantInteger::get(val);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool GlobalStrengthReductionContext::hasOnlyLocalUses(Instruction* inst) {
|
||||||
|
if (!inst) return true;
|
||||||
|
|
||||||
|
// 简单检查:如果指令没有副作用,则认为是本地的
|
||||||
|
if (sideEffectAnalysis) {
|
||||||
|
auto sideEffect = sideEffectAnalysis->getInstructionSideEffect(inst);
|
||||||
|
return sideEffect.type == SideEffectType::NO_SIDE_EFFECT;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 没有副作用分析时,保守处理
|
||||||
|
return !inst->isCall() && !inst->isStore() && !inst->isLoad();
|
||||||
|
}
|
||||||
|
|
||||||
|
void GlobalStrengthReductionContext::replaceWithOptimized(Instruction* original, Value* replacement) {
|
||||||
|
if (DEBUG) {
|
||||||
|
std::cout << " Replacing " << original->getName()
|
||||||
|
<< " with " << replacement->getName() << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
original->replaceAllUsesWith(replacement);
|
||||||
|
|
||||||
|
// 如果替换值是新创建的指令,确保它有合适的名字
|
||||||
|
// if (auto replInst = dynamic_cast<Instruction*>(replacement)) {
|
||||||
|
// if (replInst->getName().empty()) {
|
||||||
|
// replInst->setName(original->getName() + "_opt");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// 删除原指令,让调用者处理
|
||||||
|
SysYIROptUtils::usedelete(original);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace sysy
|
||||||
@@ -106,187 +106,6 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
|
|||||||
return hasNegativePotential;
|
return hasNegativePotential;
|
||||||
}
|
}
|
||||||
|
|
||||||
//该实现参考了libdivide的算法
|
|
||||||
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (divisor == 0) {
|
|
||||||
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
|
|
||||||
return {-1, -1};
|
|
||||||
}
|
|
||||||
|
|
||||||
// libdivide 常数
|
|
||||||
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
|
|
||||||
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
|
|
||||||
|
|
||||||
// 辅助函数:计算前导零个数
|
|
||||||
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
|
|
||||||
if (val == 0) return 32;
|
|
||||||
return __builtin_clz(val);
|
|
||||||
};
|
|
||||||
|
|
||||||
// 辅助函数:64位除法返回32位商和余数
|
|
||||||
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
|
|
||||||
uint64_t dividend = ((uint64_t)high << 32) | low;
|
|
||||||
uint32_t quotient = dividend / divisor;
|
|
||||||
*rem = dividend % divisor;
|
|
||||||
return quotient;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Input divisor: " << divisor << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// libdivide_internal_s32_gen 算法实现
|
|
||||||
int32_t d = divisor;
|
|
||||||
uint32_t ud = (uint32_t)d;
|
|
||||||
uint32_t absD = (d < 0) ? -ud : ud;
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] absD = " << absD << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查 absD 是否为2的幂
|
|
||||||
if ((absD & (absD - 1)) == 0) {
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] " << absD << " 是2的幂,使用移位方法" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 对于2的幂,我们只使用移位,不需要魔数
|
|
||||||
int shift = floor_log_2_d;
|
|
||||||
if (d < 0) shift |= 0x80; // 标记负数
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
|
|
||||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 对于我们的目的,我们将在IR生成中以不同方式处理2的幂
|
|
||||||
// 返回特殊标记
|
|
||||||
return {0, shift};
|
|
||||||
}
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 非2的幂除数的魔数计算
|
|
||||||
uint8_t more;
|
|
||||||
uint32_t rem, proposed_m;
|
|
||||||
|
|
||||||
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
|
|
||||||
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
|
|
||||||
const uint32_t e = absD - rem;
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 确定是否需要"加法"版本
|
|
||||||
const bool branchfree = false; // 使用分支版本
|
|
||||||
|
|
||||||
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
|
|
||||||
// 这个幂次有效
|
|
||||||
more = (uint8_t)(floor_log_2_d - 1);
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// 我们需要上升一个等级
|
|
||||||
proposed_m += proposed_m;
|
|
||||||
const uint32_t twice_rem = rem + rem;
|
|
||||||
if (twice_rem >= absD || twice_rem < rem) {
|
|
||||||
proposed_m += 1;
|
|
||||||
}
|
|
||||||
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
proposed_m += 1;
|
|
||||||
int32_t magic = (int32_t)proposed_m;
|
|
||||||
|
|
||||||
// 处理负除数
|
|
||||||
if (d < 0) {
|
|
||||||
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
|
|
||||||
if (!branchfree) {
|
|
||||||
magic = -magic;
|
|
||||||
}
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 为我们的IR生成提取移位量和标志
|
|
||||||
int shift = more & 0x3F; // 移除标志,保留移位量(位0-5)
|
|
||||||
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
|
|
||||||
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
|
|
||||||
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
|
|
||||||
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
|
|
||||||
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
|
|
||||||
<< ", is_negative = " << is_negative << std::endl;
|
|
||||||
|
|
||||||
// Test the magic number using the correct libdivide algorithm
|
|
||||||
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
|
|
||||||
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
|
|
||||||
|
|
||||||
for (int test_val : test_values) {
|
|
||||||
int64_t quotient;
|
|
||||||
|
|
||||||
// 实现正确的libdivide算法
|
|
||||||
int64_t product = (int64_t)test_val * magic;
|
|
||||||
int64_t high_bits = product >> 32;
|
|
||||||
|
|
||||||
if (need_add) {
|
|
||||||
// ADD_MARKER情况:移位前加上被除数
|
|
||||||
// 这是libdivide的关键洞察!
|
|
||||||
high_bits += test_val;
|
|
||||||
quotient = high_bits >> shift;
|
|
||||||
} else {
|
|
||||||
// 正常情况:只是移位
|
|
||||||
quotient = high_bits >> shift;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 符号修正:这是libdivide有符号除法的关键部分!
|
|
||||||
// 如果被除数为负,商需要加1来匹配C语言的截断除法语义
|
|
||||||
if (test_val < 0) {
|
|
||||||
quotient += 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int expected = test_val / divisor;
|
|
||||||
|
|
||||||
bool correct = (quotient == expected);
|
|
||||||
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
|
|
||||||
<< " (expected " << expected << ") " << (correct ? "✓" : "✗") << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::cout << "[SR] ===== End magic computation =====" << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 返回魔数、移位量,并在移位中编码ADD_MARKER标志
|
|
||||||
// 我们将使用移位的第6位表示ADD_MARKER,第7位表示负数(如果需要)
|
|
||||||
int encoded_shift = shift;
|
|
||||||
if (need_add) {
|
|
||||||
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
|
|
||||||
if (DEBUG) {
|
|
||||||
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return {magic, encoded_shift};
|
|
||||||
}
|
|
||||||
|
|
||||||
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
|
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
|
||||||
if (F->getBasicBlocks().empty()) {
|
if (F->getBasicBlocks().empty()) {
|
||||||
@@ -1018,7 +837,7 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
|
|||||||
IRBuilder* builder
|
IRBuilder* builder
|
||||||
) const {
|
) const {
|
||||||
// 使用mulh指令优化任意常数除法
|
// 使用mulh指令优化任意常数除法
|
||||||
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
|
auto [magic, shift] = SysYIROptUtils::computeMulhMagicNumbers(candidate->multiplier);
|
||||||
|
|
||||||
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
// 检查是否无法优化(magic == -1, shift == -1 表示失败)
|
||||||
if (magic == -1 && shift == -1) {
|
if (magic == -1 && shift == -1) {
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
#include "LICM.h"
|
#include "LICM.h"
|
||||||
#include "LoopStrengthReduction.h"
|
#include "LoopStrengthReduction.h"
|
||||||
#include "InductionVariableElimination.h"
|
#include "InductionVariableElimination.h"
|
||||||
|
#include "GlobalStrengthReduction.h"
|
||||||
#include "Pass.h"
|
#include "Pass.h"
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <queue>
|
#include <queue>
|
||||||
@@ -77,6 +78,8 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
|||||||
registerOptimizationPass<LICM>(builderIR);
|
registerOptimizationPass<LICM>(builderIR);
|
||||||
registerOptimizationPass<LoopStrengthReduction>(builderIR);
|
registerOptimizationPass<LoopStrengthReduction>(builderIR);
|
||||||
registerOptimizationPass<InductionVariableElimination>();
|
registerOptimizationPass<InductionVariableElimination>();
|
||||||
|
|
||||||
|
registerOptimizationPass<GlobalStrengthReduction>(builderIR);
|
||||||
registerOptimizationPass<Reg2Mem>(builderIR);
|
registerOptimizationPass<Reg2Mem>(builderIR);
|
||||||
|
|
||||||
registerOptimizationPass<SCCP>(builderIR);
|
registerOptimizationPass<SCCP>(builderIR);
|
||||||
@@ -179,6 +182,16 @@ void PassManager::runOptimizationPipeline(Module* moduleIR, IRBuilder* builderIR
|
|||||||
printPasses();
|
printPasses();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 全局强度削弱优化,包括代数优化和魔数除法
|
||||||
|
this->clearPasses();
|
||||||
|
this->addPass(&GlobalStrengthReduction::ID);
|
||||||
|
this->run();
|
||||||
|
|
||||||
|
if(DEBUG) {
|
||||||
|
std::cout << "=== IR After Global Strength Reduction Optimizations ===\n";
|
||||||
|
printPasses();
|
||||||
|
}
|
||||||
|
|
||||||
// this->clearPasses();
|
// this->clearPasses();
|
||||||
// this->addPass(&Reg2Mem::ID);
|
// this->addPass(&Reg2Mem::ID);
|
||||||
// this->run();
|
// this->run();
|
||||||
|
|||||||
Reference in New Issue
Block a user