From ad74e435bad01cc633d5b05375959c54067c3dd7 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Mon, 18 Aug 2025 21:55:57 +0800 Subject: [PATCH] =?UTF-8?q?[midend-GSR]=E4=BF=AE=E5=A4=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E7=9A=84=E4=BB=A3=E6=95=B0=E7=AE=80=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Pass/Optimize/GlobalStrengthReduction.cpp | 57 +++++++++++++------ 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp index e15a9c5..e8254a2 100644 --- a/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp +++ b/src/midend/Pass/Optimize/GlobalStrengthReduction.cpp @@ -390,8 +390,8 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { return true; } - // x && 1 = x - if (isConstantInt(rhs, constVal) && constVal == 1) { + // x && -1 = x + if (isConstantInt(rhs, constVal) && constVal == -1) { if (DEBUG) { std::cout << " Algebraic: " << inst->getName() << " = x && 1 -> x" << std::endl; } @@ -416,15 +416,6 @@ bool GlobalStrengthReductionContext::optimizeLogical(BinaryInst *inst) { replaceWithOptimized(inst, lhs); return true; } - - // x || 1 = 1 - if (isConstantInt(rhs, constVal) && constVal == 1) { - if (DEBUG) { - std::cout << " Algebraic: " << inst->getName() << " = x || 1 -> 1" << std::endl; - } - replaceWithOptimized(inst, getConstantInt(1)); - return true; - } // x || x = x if (lhs == rhs) { @@ -630,16 +621,50 @@ bool GlobalStrengthReductionContext::reduceDivision(BinaryInst *inst) { // x / 2^n = x >> n (对于无符号除法或已知为正数的情况) if (isConstantInt(rhs, constVal) && constVal > 0 && isPowerOfTwo(constVal)) { + builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); int shiftAmount = log2OfPowerOfTwo(constVal); + // 有符号除法校正:(x + (x >> 31) & mask) >> k + int maskValue = constVal - 1; + + // x >> 31 (算术右移获取符号位) + Value* signShift = ConstantInteger::get(31); + Value* signBits = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + lhs, + signShift + ); + + // (x >> 31) & mask + Value* mask = ConstantInteger::get(maskValue); + Value* correction = builder->createBinaryInst( + Instruction::Kind::kAnd, + lhs->getType(), + signBits, + mask + ); + + // x + correction + Value* corrected = builder->createAddInst(lhs, correction); + + // (x + correction) >> k + Value* divShift = ConstantInteger::get(shiftAmount); + Value* shiftInst = builder->createBinaryInst( + Instruction::Kind::kSra, // 算术右移 + lhs->getType(), + corrected, + divShift + ); + if (DEBUG) { std::cout << " StrengthReduction: " << inst->getName() - << " = x / " << constVal << " -> x >> " << shiftAmount << std::endl; + << " = x / " << constVal << " -> (x + (x >> 31) & mask) >> " << shiftAmount << std::endl; } - builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); - Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); - Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); - Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); + // builder->setPosition(inst->getParent(), inst->getParent()->findInstIterator(inst)); + // Value* divisor_minus_1 = ConstantInteger::get(constVal - 1); + // Value* adjusted = builder->createAddInst(lhs, divisor_minus_1); + // Value* shiftInst = builder->createBinaryInst(Instruction::kSra, Type::getIntType(), adjusted, getConstantInt(shiftAmount)); replaceWithOptimized(inst, shiftInst); strengthReductionCount++; return true;