#include "DivStrengthReduction.h"

#include <cmath>
#include <cstdint>
#include <iostream>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

// NOTE(review): the class names MachineInstr, RegOperand and ImmOperand used
// below were reconstructed from how the objects are used (opcode-constructed
// instruction, register operand with isVirtual()/getVRegNum(), immediate
// operand with getValue()). Confirm them against the backend headers.

namespace sysy {

char DivStrengthReduction::ID = 0;

namespace {

// Multiplier and shift amount for one "divide by constant" sequence.
//  - 32-bit: `magic` is the full (up to 33-bit) multiplier and `shift` is the
//    total right shift applied to the 64-bit product.
//  - 64-bit: `magic` is (multiplier - 2^64) reinterpreted as signed and
//    `shift` is the right shift applied after MULH and the corrective ADD.
struct MagicInfo {
  int64_t magic;
  int shift;
};

// |v| as an unsigned value; well defined even for INT64_MIN.
uint64_t magnitudeOf(int64_t v) {
  const auto bits = static_cast<uint64_t>(v);
  return v < 0 ? (0 - bits) : bits;
}

// True when exactly one bit of v is set.
bool isPowerOfTwo(uint64_t v) { return v != 0 && (v & (v - 1)) == 0; }

// floor(log2(v)) for v >= 1, computed with integers only (no rounding issues
// for values that do not fit a double's mantissa).
int floorLog2(uint64_t v) {
  int result = 0;
  while (v > 1) {
    v >>= 1;
    ++result;
  }
  return result;
}

// Computes the Granlund-Montgomery multiplier for signed division by a
// divisor of magnitude `abs_divisor`.
//
// Preconditions: abs_divisor >= 3 and is not a power of two (those cases are
// handled with shifts by the caller); for the 32-bit variant abs_divisor is
// below 2^31.
//
// With N the word size and s = floor(log2(abs_divisor)), the multiplier is
// m = floor(2^(N+s) / abs_divisor) + 1, which satisfies 2^(N-1) < m < 2^N and
// makes floor(m*n / 2^(N+s)) exact for every N-bit signed n >= 0 (and one
// less than the truncated quotient for n < 0).
MagicInfo computeMagic(uint64_t abs_divisor, bool is_32bit) {
  const int word_bits = is_32bit ? 32 : 64;
  const int post_shift = floorLog2(abs_divisor);
  const __uint128_t one = 1;
  const __uint128_t multiplier =
      (one << (word_bits + post_shift)) / abs_divisor + 1;

  if (is_32bit) {
    // m < 2^32, so m * n fits a signed 64-bit product; shift it all at once.
    return {static_cast<int64_t>(multiplier), word_bits + post_shift};
  }
  // m does not fit a signed 64-bit register. Keep its low 64 bits, i.e.
  // m - 2^64 as a signed value; the caller adds the dividend back after MULH.
  return {static_cast<int64_t>(static_cast<uint64_t>(multiplier)), post_shift};
}

// Small operand factories; they return the concrete operand type so that
// addOperand() sees exactly what a direct std::make_unique call would give.
std::unique_ptr<RegOperand> vregOperand(int vreg) {
  return std::make_unique<RegOperand>(vreg);
}

std::unique_ptr<RegOperand> copyOperand(RegOperand &reg) {
  return std::make_unique<RegOperand>(reg);
}

std::unique_ptr<RegOperand> zeroOperand() {
  return std::make_unique<RegOperand>(PhysicalReg::ZERO);
}

std::unique_ptr<ImmOperand> immOperand(int64_t value) {
  return std::make_unique<ImmOperand>(value);
}

// Appends a two-operand instruction to `out`.
template <typename A, typename B>
void emit(std::vector<std::unique_ptr<MachineInstr>> &out, RVOpcodes opcode,
          A a, B b) {
  auto mi = std::make_unique<MachineInstr>(opcode);
  mi->addOperand(std::move(a));
  mi->addOperand(std::move(b));
  out.push_back(std::move(mi));
}

// Appends a three-operand instruction to `out`.
template <typename A, typename B, typename C>
void emit(std::vector<std::unique_ptr<MachineInstr>> &out, RVOpcodes opcode,
          A a, B b, C c) {
  auto mi = std::make_unique<MachineInstr>(opcode);
  mi->addOperand(std::move(a));
  mi->addOperand(std::move(b));
  mi->addOperand(std::move(c));
  out.push_back(std::move(mi));
}

// One pending edit of a block's instruction list: erase `count_to_erase`
// instructions starting at `index`, then insert `newInstrs` there.
struct InstructionReplacement {
  size_t index;
  size_t count_to_erase;
  std::vector<std::unique_ptr<MachineInstr>> newInstrs;
};

} // namespace

bool DivStrengthReduction::runOnFunction(Function *F, AnalysisManager& AM) {
  // This pass works on MachineFunction level, not IR level
  return false;
}

// Replaces DIV/DIVW by a known constant with cheaper instruction sequences:
//   divisor  1          -> copy
//   divisor -1          -> negate
//   divisor +/- 2^k     -> bias by (2^k - 1) for negative dividends, then SRA
//   any other divisor   -> multiply by a precomputed reciprocal, shift, and
//                          add 1 for negative dividends
// The divisor is recognised either as an immediate third operand or as a
// virtual register loaded by an LI that immediately precedes the division.
// Division by zero is left untouched.
void DivStrengthReduction::runOnMachineFunction(MachineFunction *mfunc) {
  if (!mfunc) return;

  bool debug = false; // Set to true for debugging
  if (debug) std::cout << "Running DivStrengthReduction optimization..." << std::endl;

  // Pre-pass over every register operand in the function:
  //  * count how often each virtual register is mentioned, so a constant's LI
  //    is only deleted when nothing else refers to its register;
  //  * find the highest virtual register number so fresh temporaries cannot
  //    collide with existing ones (never starting below 1000).
  // NOTE(review): registers nested inside non-register operand kinds are not
  // visible to this scan - confirm whether such operands can hold vregs.
  int next_temp_reg = 1000;
  std::unordered_map<int, int> vreg_mentions;
  for (auto &block_ptr : mfunc->getBlocks()) {
    for (auto &mi : (*block_ptr).getInstructions()) {
      for (auto &operand : mi->getOperands()) {
        if (operand->getKind() != MachineOperand::KIND_REG) continue;
        auto *reg = static_cast<RegOperand *>(operand.get());
        if (!reg->isVirtual()) continue;
        const int number = static_cast<int>(reg->getVRegNum());
        ++vreg_mentions[number];
        if (number >= next_temp_reg) next_temp_reg = number + 1;
      }
    }
  }
  auto createTempReg = [&next_temp_reg]() -> int { return next_temp_reg++; };

  for (auto &mbb_uptr : mfunc->getBlocks()) {
    auto &mbb = *mbb_uptr;
    auto &instrs = mbb.getInstructions();
    std::vector<InstructionReplacement> replacements;

    for (size_t i = 0; i < instrs.size(); ++i) {
      auto *instr = instrs[i].get();
      const bool is_32bit = (instr->getOpcode() == RVOpcodes::DIVW);
      if (instr->getOpcode() != RVOpcodes::DIV && !is_32bit) continue;
      if (instr->getOperands().size() != 3) continue;

      auto *dst_op = instr->getOperands()[0].get();
      auto *src1_op = instr->getOperands()[1].get();
      auto *src2_op = instr->getOperands()[2].get();

      // Destination and dividend are downcast to register operands below, so
      // make sure that is what they are.
      if (dst_op->getKind() != MachineOperand::KIND_REG ||
          src1_op->getKind() != MachineOperand::KIND_REG) {
        continue;
      }
      auto *dst_reg = static_cast<RegOperand *>(dst_op);
      auto *src1_reg = static_cast<RegOperand *>(src1_op);

      int64_t divisor = 0;
      bool const_divisor_found = false;
      bool erase_preceding_li = false;

      if (src2_op->getKind() == MachineOperand::KIND_IMM) {
        divisor = static_cast<ImmOperand *>(src2_op)->getValue();
        const_divisor_found = true;
      } else if (src2_op->getKind() == MachineOperand::KIND_REG && i > 0) {
        auto *prev_instr = instrs[i - 1].get();
        if (prev_instr->getOpcode() == RVOpcodes::LI &&
            prev_instr->getOperands().size() == 2) {
          auto *li_dst_op = prev_instr->getOperands()[0].get();
          auto *li_imm_op = prev_instr->getOperands()[1].get();
          if (li_dst_op->getKind() == MachineOperand::KIND_REG &&
              li_imm_op->getKind() == MachineOperand::KIND_IMM) {
            auto *div_reg_op = static_cast<RegOperand *>(src2_op);
            auto *li_dst_reg_op = static_cast<RegOperand *>(li_dst_op);
            if (div_reg_op->isVirtual() && li_dst_reg_op->isVirtual() &&
                div_reg_op->getVRegNum() == li_dst_reg_op->getVRegNum()) {
              divisor = static_cast<ImmOperand *>(li_imm_op)->getValue();
              const_divisor_found = true;

              // The LI may only disappear if its register is mentioned by
              // nothing but the LI itself, this division's divisor operand
              // and (possibly) this division's destination. A dividend that
              // is the same register counts as an extra mention and keeps it.
              int expected_mentions = 2;
              if (dst_reg->isVirtual() &&
                  dst_reg->getVRegNum() == div_reg_op->getVRegNum()) {
                ++expected_mentions;
              }
              const auto found =
                  vreg_mentions.find(static_cast<int>(div_reg_op->getVRegNum()));
              erase_preceding_li = found != vreg_mentions.end() &&
                                   found->second == expected_mentions;
            }
          }
        }
      }

      if (!const_divisor_found) continue;

      // DIVW only looks at the low 32 bits of its divisor.
      if (is_32bit) divisor = static_cast<int32_t>(divisor);
      if (divisor == 0) continue;

      const RVOpcodes add_op = is_32bit ? RVOpcodes::ADDW : RVOpcodes::ADD;
      const RVOpcodes sub_op = is_32bit ? RVOpcodes::SUBW : RVOpcodes::SUB;
      const RVOpcodes sra_op = is_32bit ? RVOpcodes::SRAIW : RVOpcodes::SRAI;
      const RVOpcodes srl_op = is_32bit ? RVOpcodes::SRLIW : RVOpcodes::SRLI;
      const int word_bits = is_32bit ? 32 : 64;
      const uint64_t abs_divisor = magnitudeOf(divisor);

      std::vector<std::unique_ptr<MachineInstr>> newInstrs;

      // Writes the final quotient to the destination, negating it when the
      // divisor is negative (the sequences below divide by |divisor|).
      auto writeResult = [&](int quotient_reg) {
        if (divisor < 0) {
          emit(newInstrs, sub_op, copyOperand(*dst_reg), zeroOperand(),
               vregOperand(quotient_reg));
        } else {
          emit(newInstrs, add_op, copyOperand(*dst_reg),
               vregOperand(quotient_reg), zeroOperand());
        }
      };

      if (divisor == 1) {
        // x / 1 == x
        emit(newInstrs, add_op, copyOperand(*dst_reg), copyOperand(*src1_reg),
             zeroOperand());
      } else if (divisor == -1) {
        // x / -1 == -x (wraps for the minimum value, like the hardware does)
        emit(newInstrs, sub_op, copyOperand(*dst_reg), zeroOperand(),
             copyOperand(*src1_reg));
      } else if (isPowerOfTwo(abs_divisor)) {
        const int shift = floorLog2(abs_divisor);
        const int temp_reg = createTempReg();

        // temp = (x < 0) ? 2^shift - 1 : 0, built from the sign bits so the
        // arithmetic shift below rounds toward zero instead of down.
        emit(newInstrs, sra_op, vregOperand(temp_reg), copyOperand(*src1_reg),
             immOperand(word_bits - 1));
        emit(newInstrs, srl_op, vregOperand(temp_reg), vregOperand(temp_reg),
             immOperand(word_bits - shift));
        emit(newInstrs, add_op, vregOperand(temp_reg), copyOperand(*src1_reg),
             vregOperand(temp_reg));
        emit(newInstrs, sra_op, vregOperand(temp_reg), vregOperand(temp_reg),
             immOperand(shift));
        writeResult(temp_reg);
      } else {
        const MagicInfo magic_info = computeMagic(abs_divisor, is_32bit);
        const int magic_reg = createTempReg();
        const int temp_reg = createTempReg();

        emit(newInstrs, RVOpcodes::LI, vregOperand(magic_reg),
             immOperand(magic_info.magic));

        if (is_32bit) {
          // Full 64-bit product of the dividend and the multiplier, then one
          // arithmetic shift by 32 + floor(log2 |d|).
          // NOTE(review): assumes the dividend register holds a sign-extended
          // 32-bit value, as the original sequence did - confirm.
          emit(newInstrs, RVOpcodes::MUL, vregOperand(temp_reg),
               copyOperand(*src1_reg), vregOperand(magic_reg));
          emit(newInstrs, RVOpcodes::SRAI, vregOperand(temp_reg),
               vregOperand(temp_reg), immOperand(magic_info.shift));
        } else {
          // MULH yields floor((m - 2^64) * x / 2^64); adding x back gives
          // floor(m * x / 2^64). The remaining shift is floor(log2 |d|).
          emit(newInstrs, RVOpcodes::MULH, vregOperand(temp_reg),
               copyOperand(*src1_reg), vregOperand(magic_reg));
          emit(newInstrs, RVOpcodes::ADD, vregOperand(temp_reg),
               vregOperand(temp_reg), copyOperand(*src1_reg));
          if (magic_info.shift > 0) {
            emit(newInstrs, RVOpcodes::SRAI, vregOperand(temp_reg),
                 vregOperand(temp_reg), immOperand(magic_info.shift));
          }
        }

        // The shifted product is floor(x / |d|); subtracting the sign mask
        // (-1 for negative x) adds one and turns it into truncation.
        const int sign_reg = createTempReg();
        emit(newInstrs, sra_op, vregOperand(sign_reg), copyOperand(*src1_reg),
             immOperand(word_bits - 1));
        emit(newInstrs, sub_op, vregOperand(temp_reg), vregOperand(temp_reg),
             vregOperand(sign_reg));
        writeResult(temp_reg);
      }

      if (!newInstrs.empty()) {
        const size_t start_index = erase_preceding_li ? i - 1 : i;
        const size_t erase_count = erase_preceding_li ? 2 : 1;
        replacements.push_back({start_index, erase_count, std::move(newInstrs)});
      }
    }

    // Apply back to front so the recorded indices of earlier edits stay valid.
    for (auto it = replacements.rbegin(); it != replacements.rend(); ++it) {
      instrs.erase(instrs.begin() + it->index,
                   instrs.begin() + it->index + it->count_to_erase);
      instrs.insert(instrs.begin() + it->index,
                    std::make_move_iterator(it->newInstrs.begin()),
                    std::make_move_iterator(it->newInstrs.end()));
    }
  }
}

} // namespace sysy