[midend-IVE]参考libdivide库,实现了魔数的正确求解,如果后续出错直接用API或者不要除法强度削弱了

This commit is contained in:
rain2133
2025-08-14 05:12:54 +08:00
parent 06a368db39
commit 7547d34598
5 changed files with 421 additions and 61 deletions

View File

@@ -864,6 +864,8 @@ public:
return "shl";
case kSra:
return "ashr";
case kMulh:
return "mulh";
default:
return "Unknown";
}

View File

@@ -132,7 +132,7 @@ private:
* @param divisor 除数
* @return {魔数, 移位量}
*/
std::pair<int64_t, int> computeMulhMagicNumbers(int divisor) const;
std::pair<int, int> computeMulhMagicNumbers(int divisor) const;
/**
* 生成除法替换代码

View File

@@ -779,7 +779,29 @@ void BinaryInst::print(std::ostream &os) const {
printOperand(os, getRhs());
os << "\n ";
printVarName(os, this) << " = zext i1 %" << tmpName << " to i32";
} else {
} else if(kind == kMulh){
// 模拟高位乘法先扩展为i64乘法右移32位截断为i32
static int mulhCount = 0;
mulhCount++;
std::string lhsName = getLhs()->getName();
std::string rhsName = getRhs()->getName();
std::string tmpLhs = "tmp_mulh_lhs_" + std::to_string(mulhCount) + "_" + lhsName;
std::string tmpRhs = "tmp_mulh_rhs_" + std::to_string(mulhCount) + rhsName;
std::string tmpMul = "tmp_mulh_mul_" + std::to_string(mulhCount) + getName();
std::string tmpHigh = "tmp_mulh_high_" + std::to_string(mulhCount) + getName();
// printVarName(os, this) << " = "; // 输出最终变量名
// os << "; mulh emulation\n ";
os << "%" << tmpLhs << " = sext i32 ";
printOperand(os, getLhs());
os << " to i64\n ";
os << "%" << tmpRhs << " = sext i32 ";
printOperand(os, getRhs());
os << " to i64\n ";
os << "%" << tmpMul << " = mul i64 %" << tmpLhs << ", %" << tmpRhs << "\n ";
os << "%" << tmpHigh << " = ashr i64 %" << tmpMul << ", 32\n ";
printVarName(os, this) << " = trunc i64 %" << tmpHigh << " to i32";
}else {
// 算术和逻辑指令
printVarName(os, this) << " = ";
os << getKindString() << " " << *getType() << " ";

View File

@@ -7,6 +7,8 @@
#include <iostream>
#include <algorithm>
#include <cmath>
#include <unordered_map>
#include <climits>
// 使用全局调试开关
extern int DEBUG;
@@ -104,65 +106,188 @@ bool StrengthReductionContext::analyzeInductionVariableRange(
return hasNegativePotential;
}
std::pair<int64_t, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
// 计算用于除法的魔数 (magic number) 和移位量
// 基于 "Division by Invariant Integers using Multiplication" 算法
//该实现参考了libdivide的算法
std::pair<int, int> StrengthReductionContext::computeMulhMagicNumbers(int divisor) const {
int64_t magic = 0;
int shift = 0;
bool isPowerOfTwo = (divisor & (divisor - 1)) == 0;
if (isPowerOfTwo) {
// 对于2的幂不需要魔数直接使用移位
magic = 1;
shift = __builtin_ctz(divisor); // 计算尾随零的个数
return {magic, shift};
if (DEBUG) {
std::cout << "\n[SR] ===== Computing magic numbers for divisor " << divisor << " (libdivide algorithm) =====" << std::endl;
}
// 对于非2的幂的正数除数计算魔数
// 使用32位有符号整数范围
const int bitWidth = 32;
const int64_t maxMagic = (1LL << (bitWidth - 1)) - 1;
if (divisor == 0) {
if (DEBUG) std::cout << "[SR] Error: divisor must be != 0" << std::endl;
return {-1, -1};
}
// libdivide 常数
const uint8_t LIBDIVIDE_ADD_MARKER = 0x40;
const uint8_t LIBDIVIDE_NEGATIVE_DIVISOR = 0x80;
int64_t d = divisor;
int64_t nc = (1LL << (bitWidth - 1)) - (1LL << (bitWidth - 1)) % d;
int64_t delta = d - (1LL << (bitWidth - 1)) % d;
// 辅助函数:计算前导零个数
auto count_leading_zeros32 = [](uint32_t val) -> uint32_t {
if (val == 0) return 32;
return __builtin_clz(val);
};
shift = bitWidth - 1;
// 辅助函数64位除法返回32位商和余数
auto div_64_32 = [](uint32_t high, uint32_t low, uint32_t divisor, uint32_t* rem) -> uint32_t {
uint64_t dividend = ((uint64_t)high << 32) | low;
uint32_t quotient = dividend / divisor;
*rem = dividend % divisor;
return quotient;
};
if (DEBUG) {
std::cout << "[SR] Input divisor: " << divisor << std::endl;
}
// libdivide_internal_s32_gen 算法实现
int32_t d = divisor;
uint32_t ud = (uint32_t)d;
uint32_t absD = (d < 0) ? -ud : ud;
// 找到合适的魔数和移位量
while (shift < bitWidth + 30) { // 避免无限循环
int64_t q1 = (1LL << shift) / nc;
int64_t r1 = (1LL << shift) - q1 * nc;
int64_t q2 = (1LL << shift) / delta;
int64_t r2 = (1LL << shift) - q2 * delta;
if (q1 < q2 || (q1 == q2 && r1 < r2)) {
magic = q2 + 1;
if (magic <= maxMagic) {
break;
}
if (DEBUG) {
std::cout << "[SR] absD = " << absD << std::endl;
}
uint32_t floor_log_2_d = 31 - count_leading_zeros32(absD);
if (DEBUG) {
std::cout << "[SR] floor_log_2_d = " << floor_log_2_d << std::endl;
}
// 检查 absD 是否为2的幂
if ((absD & (absD - 1)) == 0) {
if (DEBUG) {
std::cout << "[SR] " << absD << " 是2的幂使用移位方法" << std::endl;
}
shift++;
nc = 2 * nc;
delta = 2 * delta;
// 对于2的幂我们只使用移位不需要魔数
int shift = floor_log_2_d;
if (d < 0) shift |= 0x80; // 标记负数
if (DEBUG) {
std::cout << "[SR] Power of 2 result: magic=0, shift=" << shift << std::endl;
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 对于我们的目的我们将在IR生成中以不同方式处理2的幂
// 返回特殊标记
return {0, shift};
}
if (magic > maxMagic) {
// 回退到简单的魔数
magic = (1LL << bitWidth) / d + 1;
shift = bitWidth;
if (DEBUG) {
std::cout << "[SR] " << absD << " is not a power of 2, computing magic number" << std::endl;
}
// 调整移位量以移除多余的2的幂因子
shift = shift - bitWidth;
if (shift < 0) shift = 0;
// 非2的幂除数的魔数计算
uint8_t more;
uint32_t rem, proposed_m;
return {magic, shift};
// 计算 proposed_m = floor(2^(floor_log_2_d + 31) / absD)
proposed_m = div_64_32((uint32_t)1 << (floor_log_2_d - 1), 0, absD, &rem);
const uint32_t e = absD - rem;
if (DEBUG) {
std::cout << "[SR] proposed_m = " << proposed_m << ", rem = " << rem << ", e = " << e << std::endl;
}
// 确定是否需要"加法"版本
const bool branchfree = false; // 使用分支版本
if (!branchfree && e < ((uint32_t)1 << floor_log_2_d)) {
// 这个幂次有效
more = (uint8_t)(floor_log_2_d - 1);
if (DEBUG) {
std::cout << "[SR] Using basic algorithm, shift = " << (int)more << std::endl;
}
} else {
// 我们需要上升一个等级
proposed_m += proposed_m;
const uint32_t twice_rem = rem + rem;
if (twice_rem >= absD || twice_rem < rem) {
proposed_m += 1;
}
more = (uint8_t)(floor_log_2_d | LIBDIVIDE_ADD_MARKER);
if (DEBUG) {
std::cout << "[SR] Using add algorithm, proposed_m = " << proposed_m << ", more = " << (int)more << std::endl;
}
}
proposed_m += 1;
int32_t magic = (int32_t)proposed_m;
// 处理负除数
if (d < 0) {
more |= LIBDIVIDE_NEGATIVE_DIVISOR;
if (!branchfree) {
magic = -magic;
}
if (DEBUG) {
std::cout << "[SR] Negative divisor, magic = " << magic << ", more = " << (int)more << std::endl;
}
}
// 为我们的IR生成提取移位量和标志
int shift = more & 0x3F; // 移除标志保留移位量位0-5
bool need_add = (more & LIBDIVIDE_ADD_MARKER) != 0;
bool is_negative = (more & LIBDIVIDE_NEGATIVE_DIVISOR) != 0;
if (DEBUG) {
std::cout << "[SR] Final result: magic = " << magic << ", more = " << (int)more
<< " (0x" << std::hex << (int)more << std::dec << ")" << std::endl;
std::cout << "[SR] Shift = " << shift << ", need_add = " << need_add
<< ", is_negative = " << is_negative << std::endl;
// Test the magic number using the correct libdivide algorithm
std::cout << "[SR] Testing magic number (libdivide algorithm):" << std::endl;
int test_values[] = {1, 7, 37, 100, 999, -1, -7, -37, -100};
for (int test_val : test_values) {
int64_t quotient;
// 实现正确的libdivide算法
int64_t product = (int64_t)test_val * magic;
int64_t high_bits = product >> 32;
if (need_add) {
// ADD_MARKER情况移位前加上被除数
// 这是libdivide的关键洞察
high_bits += test_val;
quotient = high_bits >> shift;
} else {
// 正常情况:只是移位
quotient = high_bits >> shift;
}
// 符号修正这是libdivide有符号除法的关键部分
// 如果被除数为负商需要加1来匹配C语言的截断除法语义
if (test_val < 0) {
quotient += 1;
}
int expected = test_val / divisor;
bool correct = (quotient == expected);
std::cout << "[SR] " << test_val << " / " << divisor << " = " << quotient
<< " (expected " << expected << ") " << (correct ? "" : "") << std::endl;
}
std::cout << "[SR] ===== End magic computation =====" << std::endl;
}
// 返回魔数、移位量并在移位中编码ADD_MARKER标志
// 我们将使用移位的第6位表示ADD_MARKER第7位表示负数如果需要
int encoded_shift = shift;
if (need_add) {
encoded_shift |= 0x40; // 设置第6位表示ADD_MARKER
if (DEBUG) {
std::cout << "[SR] Encoding ADD_MARKER in shift: " << encoded_shift << std::endl;
}
}
return {magic, encoded_shift};
}
bool LoopStrengthReduction::runOnFunction(Function* F, AnalysisManager& AM) {
if (F->getBasicBlocks().empty()) {
return false; // 空函数
@@ -651,7 +776,7 @@ bool StrengthReductionContext::createNewInductionVariable(StrengthReductionCandi
// 2. 在循环头创建新的 phi 指令
builder->setPosition(header, header->begin());
candidate->newPhi = builder->createPhiInst(originalPhi->getType());
candidate->newPhi->setName(originalPhi->getName() + "_sr");
candidate->newPhi->setName("sr_" + originalPhi->getName());
// 3. 计算新归纳变量的初始值和步长
// 新IV的初始值 = 原IV初始值 * multiplier
@@ -895,14 +1020,35 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
// 使用mulh指令优化任意常数除法
auto [magic, shift] = computeMulhMagicNumbers(candidate->multiplier);
if (magic == 1 && shift > 0) {
// 特殊情况:可以直接使用移位
Value* shiftConstant = ConstantInteger::get(shift);
// 检查是否无法优化magic == -1, shift == -1 表示失败)
if (magic == -1 && shift == -1) {
if (DEBUG) {
std::cout << "[SR] Cannot optimize division by " << candidate->multiplier
<< ", keeping original division" << std::endl;
}
// 返回 nullptr 表示无法优化,调用方应该保持原始除法
return nullptr;
}
// 2的幂次方除法可以用移位优化但这不是魔数法的情况这种情况应该不会被分类到这里但是还是做一个保护措施
if ((candidate->multiplier & (candidate->multiplier - 1)) == 0 && candidate->multiplier > 0) {
// 是2的幂次方可以用移位
int shift_amount = 0;
int temp = candidate->multiplier;
while (temp > 1) {
temp >>= 1;
shift_amount++;
}
Value* shiftConstant = ConstantInteger::get(shift_amount);
if (candidate->hasNegativeValues) {
// 对于有符号除法,需要先加上除数-1然后再移位为了正确处理负数舍入
Value* divisor_minus_1 = ConstantInteger::get(candidate->multiplier - 1);
Value* adjusted = builder->createAddInst(candidate->inductionVar, divisor_minus_1);
return builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
candidate->inductionVar->getType(),
candidate->inductionVar,
adjusted,
shiftConstant
);
} else {
@@ -916,8 +1062,25 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
}
// 创建魔数常量
// 检查魔数是否能放入32位如果不能则不进行优化
if (magic > INT32_MAX || magic < INT32_MIN) {
if (DEBUG) {
std::cout << "[SR] Magic number " << magic << " exceeds 32-bit range, skipping optimization" << std::endl;
}
return nullptr; // 无法优化,保持原始除法
}
Value* magicConstant = ConstantInteger::get((int32_t)magic);
// 检查是否需要ADD_MARKER处理加法调整
bool needAdd = (shift & 0x40) != 0;
int actualShift = shift & 0x3F; // 提取真实的移位量
if (DEBUG) {
std::cout << "[SR] IR Generation: magic=" << magic << ", needAdd=" << needAdd
<< ", actualShift=" << actualShift << std::endl;
}
// 执行高位乘法mulh(x, magic)
Value* mulhResult = builder->createBinaryInst(
Instruction::Kind::kMulh, // 高位乘法
@@ -926,9 +1089,18 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
magicConstant
);
if (shift > 0) {
if (needAdd) {
// ADD_MARKER 情况:需要在移位前加上被除数
// 这对应于 libdivide 的加法调整算法
if (DEBUG) {
std::cout << "[SR] Applying ADD_MARKER: adding dividend before shift" << std::endl;
}
mulhResult = builder->createAddInst(mulhResult, candidate->inductionVar);
}
if (actualShift > 0) {
// 如果需要额外移位
Value* shiftConstant = ConstantInteger::get(shift);
Value* shiftConstant = ConstantInteger::get(actualShift);
mulhResult = builder->createBinaryInst(
Instruction::Kind::kSra, // 算术右移
candidate->inductionVar->getType(),
@@ -937,14 +1109,11 @@ Value* StrengthReductionContext::generateConstantDivisionReplacement(
);
}
// 处理负数校正 - 简化版本
if (candidate->hasNegativeValues) {
// 简化处理:添加一个常数偏移来处理负数情况
// 这是一个简化的实现,实际的负数校正会更复杂
Value* zero = ConstantInteger::get(0);
Value* isNegative = builder->createICmpLTInst(candidate->inductionVar, zero);
// 这里应该有条件逻辑但为了简化实现暂时直接返回mulhResult
}
// 标准的有符号除法符号修正如果被除数为负商需要加1
// 这对所有有符号除法都需要,不管是否可能有负数
Value* isNegative = builder->createICmpLTInst(candidate->inductionVar, ConstantInteger::get(0));
// 将i1转换为i32负数时为1非负数时为0 ICmpLTInst的结果会默认转化为32位
mulhResult = builder->createAddInst(mulhResult, isNegative);
return mulhResult;
}