[optimize]添加基础的除法指令优化,目前只对除以2的幂数生效

This commit is contained in:
2025-08-03 13:46:42 +08:00
parent e8699d6d25
commit f312792fe9
10 changed files with 419 additions and 6 deletions

View File

@@ -0,0 +1,329 @@
#include "DivStrengthReduction.h"
namespace sysy {
// Out-of-class definition of the pass's static ID member, initialised to zero.
// NOTE(review): presumably the pass framework uses this member to identify
// the pass - confirm against DivStrengthReduction.h.
char DivStrengthReduction::ID{0};
// IR-level entry point. This pass does all of its work on MachineFunctions
// (see runOnMachineFunction), so the IR hook never modifies anything and
// always reports "unchanged".
bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
  // Neither argument is consulted at this level.
  static_cast<void>(F);
  static_cast<void>(AM);
  return false;
}
// Strength-reduces signed division by an immediate constant.
//
// Every `DIV`/`DIVW dst, src, imm` found in `mfunc` is replaced by an
// equivalent shift/add/multiply sequence when one can be emitted using the
// opcodes this pass already relies on:
//   * imm ==  1            -> dst = src + zero
//   * imm == -1            -> dst = zero - src
//   * |imm| is a power of 2 -> bias-and-shift (plus a negation for imm < 0)
//   * any other imm, DIVW   -> multiply by a "magic" reciprocal
// A 64-bit DIV by a non-power-of-two is left untouched: an exact sequence
// needs the high half of a 128-bit product (MULH), which this pass does not
// emit. Division by zero and DIVW immediates that do not fit in a signed
// 32-bit value are also left untouched.
//
// @param mfunc  Machine function to rewrite in place; a null pointer is a no-op.
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
  if (!mfunc)
    return;

  bool debug = false; // Set to true for debugging
  if (debug)
    std::cout << "Running DivStrengthReduction optimization..." << std::endl;

  using InstrList = std::vector<std::unique_ptr<MachineInstr>>;

  // Allocator for scratch virtual registers.
  // NOTE(review): numbering restarts at 1000 on every call and is not
  // coordinated with registers already used by the function as far as this
  // file shows - confirm that these numbers cannot collide.
  int next_temp_reg = 1000;
  auto createTempReg = [&]() -> int { return next_temp_reg++; };

  // |v| as an unsigned value; well defined even for INT64_MIN.
  auto magnitudeOf = [](int64_t v) -> uint64_t {
    const uint64_t u = static_cast<uint64_t>(v);
    return v < 0 ? 0 - u : u;
  };

  // Returns k when m == 2^k, otherwise -1.
  auto log2IfPowerOfTwo = [](uint64_t m) -> int {
    if (m == 0 || (m & (m - 1)) != 0)
      return -1;
    int k = 0;
    while (m > 1) {
      m >>= 1;
      ++k;
    }
    return k;
  };

  // Multiplier/shift pair for signed 32-bit division by a constant.
  struct MagicInfo {
    int64_t magic; // multiplier M, always representable as a signed 32-bit value
    int shift;     // arithmetic right shift applied after the high multiply
    bool valid;    // false when no usable pair exists for the divisor
  };

  // Signed 32-bit magic-number computation (Hacker's Delight, section 10-1 /
  // Granlund-Montgomery). Requires 2 <= |divisor| <= 2^31; anything else
  // yields an invalid result. All arithmetic is done in 64-bit unsigned so no
  // intermediate can wrap.
  auto computeMagicNumber = [&](int64_t divisor) -> MagicInfo {
    const MagicInfo invalid = {0, 0, false};
    const uint64_t two31 = 1ULL << 31;
    const uint64_t ad = magnitudeOf(divisor);
    if (ad < 2 || ad > two31)
      return invalid;

    const uint64_t t = two31 + (divisor < 0 ? 1u : 0u);
    const uint64_t anc = t - 1 - t % ad; // |nc|: largest dividend magnitude to cover
    int p = 31;
    uint64_t q1 = two31 / anc; // 2^p / |nc|
    uint64_t r1 = two31 % anc;
    uint64_t q2 = two31 / ad;  // 2^p / |d|
    uint64_t r2 = two31 % ad;
    uint64_t delta = 0;
    do {
      ++p;
      q1 *= 2;
      r1 *= 2;
      if (r1 >= anc) {
        ++q1;
        r1 -= anc;
      }
      q2 *= 2;
      r2 *= 2;
      if (r2 >= ad) {
        ++q2;
        r2 -= ad;
      }
      delta = ad - r2;
      // The p < 64 bound is a safety net only; the algorithm terminates earlier.
    } while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0)));

    // Reinterpret q2 + 1 as a signed 32-bit value without relying on
    // implementation-defined narrowing.
    int64_t magic = static_cast<int64_t>(q2 + 1);
    if (magic >= (int64_t{1} << 31))
      magic -= (int64_t{1} << 32);
    if (divisor < 0)
      magic = -magic;

    const int shift = p - 32;
    const int64_t int32_min = -(int64_t{1} << 31);
    const int64_t int32_max = (int64_t{1} << 31) - 1;
    if (shift < 0 || shift > 31 || magic < int32_min || magic > int32_max)
      return invalid;
    return {magic, shift, true};
  };

  // Operand factories (same constructor calls the pass has always used).
  auto vreg = [](int r) { return std::make_unique<RegOperand>(r); };
  auto copyOf = [](const RegOperand *r) { return std::make_unique<RegOperand>(*r); };
  auto zero = []() { return std::make_unique<RegOperand>(PhysicalReg::ZERO); };
  auto imm = [](int64_t v) { return std::make_unique<ImmOperand>(v); };

  // Appends a three-operand instruction `opcode a, b, c` to `out`.
  auto emit = [](InstrList &out, auto opcode, auto a, auto b, auto c) {
    auto mi = std::make_unique<MachineInstr>(opcode);
    mi->addOperand(std::move(a));
    mi->addOperand(std::move(b));
    mi->addOperand(std::move(c));
    out.push_back(std::move(mi));
  };

  // One pending rewrite: the index of the division and its replacement.
  struct InstructionReplacement {
    size_t index;
    InstrList newInstrs;
  };

  const int64_t kInt32Min = -(int64_t{1} << 31);
  const int64_t kInt32Max = (int64_t{1} << 31) - 1;

  for (auto &mbb_uptr : mfunc->getBlocks()) {
    auto &mbb = *mbb_uptr;
    auto &instrs = mbb.getInstructions();
    std::vector<InstructionReplacement> replacements;

    for (size_t i = 0; i < instrs.size(); ++i) {
      auto *instr = instrs[i].get();
      const bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);

      // Only DIV and DIVW are candidates.
      if (!is_32bit && instr->getOpcode() != RVOpcodes::DIV)
        continue;
      if (instr->getOperands().size() != 3)
        continue;

      auto *dst_op = instr->getOperands()[0].get();
      auto *src1_op = instr->getOperands()[1].get();
      auto *src2_op = instr->getOperands()[2].get();

      // Expected shape: reg = reg / imm.
      if (dst_op->getKind() != MachineOperand::KIND_REG ||
          src1_op->getKind() != MachineOperand::KIND_REG ||
          src2_op->getKind() != MachineOperand::KIND_IMM) {
        continue;
      }

      auto *dst_reg = static_cast<RegOperand *>(dst_op);
      auto *src1_reg = static_cast<RegOperand *>(src1_op);
      auto *src2_imm = static_cast<ImmOperand *>(src2_op);
      const int64_t divisor = src2_imm->getValue();

      // Division by zero keeps the hardware-defined result.
      if (divisor == 0)
        continue;
      // A DIVW immediate outside the signed 32-bit range has no unambiguous
      // meaning here and would yield out-of-range W shift amounts; skip it.
      if (is_32bit && (divisor < kInt32Min || divisor > kInt32Max))
        continue;

      // Width-dependent opcodes and constants.
      const auto op_add = is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD;
      const auto op_sub = is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB;
      const auto op_sra = is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI;
      const auto op_srl = is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI;
      const int word_bits = is_32bit ? 32 : 64;

      const int pow2_exp = log2IfPowerOfTwo(magnitudeOf(divisor));
      InstrList newInstrs;

      if (divisor == 1) {
        // dst = src1 + zero
        emit(newInstrs, op_add, copyOf(dst_reg), copyOf(src1_reg), zero());
      } else if (divisor == -1) {
        // dst = zero - src1
        emit(newInstrs, op_sub, copyOf(dst_reg), zero(), copyOf(src1_reg));
      } else if (pow2_exp > 0) {
        // |divisor| == 2^k with 1 <= k <= word_bits - 1, so every shift
        // amount below is in range. Signed division truncates toward zero,
        // so negative dividends are biased by (2^k - 1) before shifting.
        const int k = pow2_exp;
        const int temp_reg = createTempReg();
        // temp = src1 >> (word_bits - 1)      (all ones iff src1 < 0)
        emit(newInstrs, op_sra, vreg(temp_reg), copyOf(src1_reg), imm(word_bits - 1));
        // temp = temp >>> (word_bits - k)     (2^k - 1 iff src1 < 0, else 0)
        emit(newInstrs, op_srl, vreg(temp_reg), vreg(temp_reg), imm(word_bits - k));
        // temp = src1 + temp
        emit(newInstrs, op_add, vreg(temp_reg), copyOf(src1_reg), vreg(temp_reg));
        if (divisor > 0) {
          // dst = temp >> k
          emit(newInstrs, op_sra, copyOf(dst_reg), vreg(temp_reg), imm(k));
        } else {
          // temp = temp >> k; dst = -temp
          emit(newInstrs, op_sra, vreg(temp_reg), vreg(temp_reg), imm(k));
          emit(newInstrs, op_sub, copyOf(dst_reg), zero(), vreg(temp_reg));
        }
      } else if (is_32bit) {
        // General 32-bit case: q = hi32(M * n), corrected, shifted, then
        // rounded toward zero. The high multiply is a 64-bit MUL of two
        // sign-extended 32-bit values followed by SRAI 32, so MULH is not
        // needed.
        const MagicInfo info = computeMagicNumber(divisor);
        if (!info.valid)
          continue;

        const int n_reg = createTempReg();     // sign-extended dividend
        const int magic_reg = createTempReg(); // multiplier M
        const int q_reg = createTempReg();     // running quotient
        const int sign_reg = createTempReg();  // 1 iff the quotient is negative

        // n = sext32(src1). DIVW only reads the low 32 bits of its source, so
        // the upper half is normalised explicitly before the 64-bit multiply.
        emit(newInstrs, RVOpcodes::ADDW, vreg(n_reg), copyOf(src1_reg), zero());

        // magic = M
        // NOTE(review): assumes LI can materialise any signed 32-bit
        // immediate - confirm against the LI expansion.
        auto loadInstr = std::make_unique<MachineInstr>(RVOpcodes::LI);
        loadInstr->addOperand(vreg(magic_reg));
        loadInstr->addOperand(imm(info.magic));
        newInstrs.push_back(std::move(loadInstr));

        // q = (n * M) >> 32
        emit(newInstrs, RVOpcodes::MUL, vreg(q_reg), vreg(n_reg), vreg(magic_reg));
        emit(newInstrs, RVOpcodes::SRAI, vreg(q_reg), vreg(q_reg), imm(32));

        // M wrapped to the opposite sign of the divisor: undo the 2^32 offset.
        if (divisor > 0 && info.magic < 0) {
          emit(newInstrs, RVOpcodes::ADDW, vreg(q_reg), vreg(q_reg), vreg(n_reg));
        } else if (divisor < 0 && info.magic > 0) {
          emit(newInstrs, RVOpcodes::SUBW, vreg(q_reg), vreg(q_reg), vreg(n_reg));
        }

        if (info.shift > 0) {
          emit(newInstrs, RVOpcodes::SRAIW, vreg(q_reg), vreg(q_reg), imm(info.shift));
        }

        // The shifts round toward -inf; add 1 to negative quotients so the
        // result truncates toward zero like DIVW.
        emit(newInstrs, RVOpcodes::SRLIW, vreg(sign_reg), vreg(q_reg), imm(31));
        emit(newInstrs, RVOpcodes::ADDW, copyOf(dst_reg), vreg(q_reg), vreg(sign_reg));
      }
      // Otherwise (64-bit DIV by a non-power-of-two) the original instruction
      // is kept.

      if (!newInstrs.empty()) {
        replacements.push_back({i, std::move(newInstrs)});
      }
    }

    // Apply back to front so earlier recorded indices stay valid.
    for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
      instrs.erase(instrs.begin() + it->index);
      instrs.insert(instrs.begin() + it->index,
                    std::make_move_iterator(it->newInstrs.begin()),
                    std::make_move_iterator(it->newInstrs.end()));
    }
  }
}
} // namespace sysy