diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp index e15a9c5..e8254a2 100644 --- a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -390,8 +390,8 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { return true; } - // x && 1 = x - if (isConstantInt(rhs, constVal) && constVal == 1) { + // x && -1 = x + if (isConstantInt(rhs, constVal) && constVal == -1) { if (DEBUG) { std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl; } @@ -416,15 +416,6 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { replaceWithOptimized(inst, lhs); return true; } - - // x || 1 = 1 - if (isConstantInt(rhs, constVal) && constVal == 1) { - if (DEBUG) { - std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl; - } - replaceWithOptimized(inst, getConstantInt(1)); - return true; - } // x || x = x if (lhs == rhs) { @@ -630,16 +621,50 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { // x / 2^n = x >> n (对于无符号除法或已知为正数的情况) if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) { + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); int shiftAmount = log2OfPowerOfTwo(constVal); + // 有符号除法校正:(x + (x >> 31) & mask) >> k + int maskValue = constVal - 1; + + // x >> 31 (算术右移获取符号位) + Value* signShift = ConstantInteger::get(31); + Value* signBits = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + lhs, + signShift + ); + + // (x >> 31) & mask + Value* mask = ConstantInteger::get(maskValue); + Value* correction = builder->createBinaryInst( + Instruction::Kind::kAnd, + lhs->getType(), + signBits, + mask + ); + + // x + correction + Value* corrected = builder->createAddInst(lhs, correction); + + // (x + correction) >> k + Value* divShift = ConstantInteger::get(shiftAmount); + Value* shiftInst = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + corrected, + divShift + ); + if (DEBUG) { std::cout << " StrengthReduction: " << inst->getName() - << " = x / " << constVal << " -> x >> " << shiftAmount << std::endl; + << " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl; } - builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); - Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); - Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); - Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); + // builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + // Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); + // Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); + // Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); replaceWithOptimized(inst, shiftInst); strengthReductionCount++; return true;