[optimize]添加基础的除法指令优化,目前只对除以2的幂数生效
This commit is contained in:
329
src/backend/RISCv64/Optimize/DivStrengthReduction.cpp
Normal file
329
src/backend/RISCv64/Optimize/DivStrengthReduction.cpp
Normal file
@@ -0,0 +1,329 @@
|
||||
#include "DivStrengthReduction.h"
|
||||
|
||||
namespace sysy {
|
||||
|
||||
// Out-of-class definition of the pass's static `ID` member, initialised to 0.
// NOTE(review): presumably the pass framework uses this member to identify
// the pass — confirm against the pass-manager code.
char DivStrengthReduction::ID = 0;
|
||||
|
||||
/// IR-level entry point of the pass.
///
/// The division strength reduction is performed on machine code by
/// runOnMachineFunction(), so this hook never touches the IR.
///
/// @param F   IR function handed in by the caller (unused).
/// @param AM  Analysis manager handed in by the caller (unused).
/// @return Always false.
bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
  // Both arguments are intentionally ignored: there is no IR-level work.
  static_cast<void>(F);
  static_cast<void>(AM);

  constexpr bool kIrChanged = false;
  return kIrChanged;
}
|
||||
|
||||
/// Machine-level strength reduction for signed division by a constant.
///
/// Scans every basic block of `mfunc` for `DIV` / `DIVW` instructions whose
/// operands are (register, register, immediate) and replaces them as follows:
///
///   divisor ==  1   ->  dst = src + zero
///   divisor == -1   ->  dst = zero - src
///   divisor ==  2^k ->  bias negative dividends by (2^k - 1), then shift
///                       arithmetically right by k
///   divisor == -2^k ->  same as 2^k, followed by a negation
///
/// All other divisors keep their original division instruction. An earlier
/// "magic number" path for small divisors was removed because it could not
/// produce correct quotients: it relied on MUL/MULW (low product bits only)
/// in place of a high multiply and emitted SRAIW with a shift amount of 32,
/// which is outside the 5-bit range of the *W shift instructions.
///
/// @param mfunc Machine function rewritten in place; a null pointer is a no-op.
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
  if (!mfunc)
    return;

  bool debug = false; // Set to true for debugging
  if (debug)
    std::cout << "Running DivStrengthReduction optimization..." << std::endl;

  // Virtual register allocator for temporaries.
  // NOTE(review): numbering starts at a fixed base of 1000; nothing visible
  // here guarantees that this range is unused by the function — confirm
  // against the backend's virtual-register allocation.
  int next_temp_reg = 1000;
  auto createTempReg = [&]() -> int { return next_temp_reg++; };

  // Range of immediates that are meaningful divisors for the 32-bit DIVW.
  constexpr int64_t kInt32Min = -(int64_t{1} << 31);
  constexpr int64_t kInt32Max = (int64_t{1} << 31) - 1;

  // Small operand factories; each call yields a freshly owned operand.
  auto copyReg = [](RegOperand *reg) { return std::make_unique<RegOperand>(*reg); };
  auto zeroReg = []() { return std::make_unique<RegOperand>(PhysicalReg::ZERO); };
  auto virtReg = [](int id) { return std::make_unique<RegOperand>(id); };
  auto immediate = [](int value) { return std::make_unique<ImmOperand>(value); };

  // Builds the three-operand instruction `opcode dst, lhs, rhs`.
  auto makeInstr = [](auto opcode, auto dst, auto lhs, auto rhs) {
    auto mi = std::make_unique<MachineInstr>(opcode);
    mi->addOperand(std::move(dst));
    mi->addOperand(std::move(lhs));
    mi->addOperand(std::move(rhs));
    return mi;
  };

  // One pending rewrite: the index of the division and its replacement.
  struct InstructionReplacement {
    size_t index;
    std::vector<std::unique_ptr<MachineInstr>> newInstrs;
  };

  for (auto &mbb_uptr : mfunc->getBlocks()) {
    auto &mbb = *mbb_uptr;
    auto &instrs = mbb.getInstructions();
    std::vector<InstructionReplacement> replacements;

    for (size_t i = 0; i < instrs.size(); ++i) {
      auto *instr = instrs[i].get();

      // Only DIV and DIVW are candidates.
      const bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
      if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit)
        continue;

      if (instr->getOperands().size() != 3)
        continue;

      auto *dst_op = instr->getOperands()[0].get();
      auto *src1_op = instr->getOperands()[1].get();
      auto *src2_op = instr->getOperands()[2].get();

      // Required operand shape: reg = reg / imm.
      if (dst_op->getKind() != MachineOperand::KIND_REG ||
          src1_op->getKind() != MachineOperand::KIND_REG ||
          src2_op->getKind() != MachineOperand::KIND_IMM) {
        continue;
      }

      auto *dst_reg = static_cast<RegOperand *>(dst_op);
      auto *src1_reg = static_cast<RegOperand *>(src1_op);
      auto *src2_imm = static_cast<ImmOperand *>(src2_op);

      const int64_t divisor = src2_imm->getValue();

      // Division by zero is left to the original instruction.
      if (divisor == 0)
        continue;

      // A DIVW immediate that does not fit in 32 bits cannot be mapped onto
      // the *W shift sequence (shift amounts would reach 32); leave it alone.
      if (is_32bit && (divisor < kInt32Min || divisor > kInt32Max))
        continue;

      // |divisor| in unsigned arithmetic so the most negative value is safe.
      const bool negative = divisor < 0;
      const uint64_t magnitude = negative
                                     ? uint64_t{0} - static_cast<uint64_t>(divisor)
                                     : static_cast<uint64_t>(divisor);

      // Only 1 and other powers of two are handled.
      if ((magnitude & (magnitude - 1)) != 0)
        continue;

      const auto add_op = is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD;
      const auto sub_op = is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB;
      const auto sra_op = is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI;
      const auto srl_op = is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI;
      const int word_size = is_32bit ? 32 : 64;

      std::vector<std::unique_ptr<MachineInstr>> newInstrs;

      if (magnitude == 1) {
        if (!negative) {
          // x / 1: dst = src1 + zero (plain copy).
          newInstrs.push_back(
              makeInstr(add_op, copyReg(dst_reg), copyReg(src1_reg), zeroReg()));
        } else {
          // x / -1: dst = zero - src1.
          newInstrs.push_back(
              makeInstr(sub_op, copyReg(dst_reg), zeroReg(), copyReg(src1_reg)));
        }
      } else {
        // magnitude == 2^shift with shift >= 1.
        int shift = 0;
        for (uint64_t rest = magnitude; rest > 1; rest >>= 1)
          ++shift;

        const int temp_reg = createTempReg();

        // Signed division truncates toward zero, so a negative dividend must
        // be biased by (2^shift - 1) before the arithmetic shift.

        // temp = src1 >> (word_size - 1): all ones if negative, else zero.
        newInstrs.push_back(makeInstr(sra_op, virtReg(temp_reg),
                                      copyReg(src1_reg),
                                      immediate(word_size - 1)));

        // temp = temp >>> (word_size - shift): 2^shift - 1 if negative, else 0.
        newInstrs.push_back(makeInstr(srl_op, virtReg(temp_reg),
                                      virtReg(temp_reg),
                                      immediate(word_size - shift)));

        // temp = src1 + temp: the biased dividend.
        newInstrs.push_back(makeInstr(add_op, virtReg(temp_reg),
                                      copyReg(src1_reg), virtReg(temp_reg)));

        if (!negative) {
          // dst = temp >> shift.
          newInstrs.push_back(makeInstr(sra_op, copyReg(dst_reg),
                                        virtReg(temp_reg), immediate(shift)));
        } else {
          // temp = temp >> shift, then dst = zero - temp.
          newInstrs.push_back(makeInstr(sra_op, virtReg(temp_reg),
                                        virtReg(temp_reg), immediate(shift)));
          newInstrs.push_back(makeInstr(sub_op, copyReg(dst_reg), zeroReg(),
                                        virtReg(temp_reg)));
        }
      }

      replacements.push_back({i, std::move(newInstrs)});
    }

    // Apply back to front so earlier indices stay valid while splicing.
    for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
      instrs.erase(instrs.begin() + it->index);
      instrs.insert(instrs.begin() + it->index,
                    std::make_move_iterator(it->newInstrs.begin()),
                    std::make_move_iterator(it->newInstrs.end()));
    }
  }
}
|
||||
|
||||
} // namespace sysy
|
||||
Reference in New Issue
Block a user