Files
mysysy/src/midend/Pass/Optimize/ConstPropagation.cpp

241 lines
12 KiB
C++

#include "Pass/Optimize/ConstPropagation.h"
#include "IR.h"
#include "Pass.h"
#include <climits>
#include <cmath>
namespace sysy {
char ConstPropagation::ID = 0;
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
bool changed = false;
bool localChanged = true;
while (localChanged) {
localChanged = false;
for (auto &bb : func->getBasicBlocks()) {
for (auto instIter = bb->getInstructions().begin();
instIter != bb->getInstructions().end();) {
auto &inst = *instIter;
bool shouldAdvanceIter = true;
// 处理二元运算指令
if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst.get())) {
auto *lhs = binaryInst->getLhs();
auto *rhs = binaryInst->getRhs();
auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
if (lhsConst && rhsConst) {
ConstantValue *newConst = nullptr;
try {
if (lhs->isInt() && rhs->isInt()) {
int l = lhsConst->getInt();
int r = rhsConst->getInt();
int result;
bool validOperation = true;
switch (binaryInst->getKind()) {
case Instruction::kAdd:
// 检查加法溢出
if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) {
validOperation = false;
} else {
result = l + r;
}
break;
case Instruction::kSub:
// 检查减法溢出
if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) {
validOperation = false;
} else {
result = l - r;
}
break;
case Instruction::kMul:
// 检查乘法溢出
if (l != 0 && r != 0 &&
(std::abs(l) > INT_MAX / std::abs(r))) {
validOperation = false;
} else {
result = l * r;
}
break;
case Instruction::kDiv:
if (r == 0) {
validOperation = false;
} else {
result = l / r;
}
break;
case Instruction::kRem:
if (r == 0) {
validOperation = false;
} else {
result = l % r;
}
break;
case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break;
case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break;
case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break;
case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break;
case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break;
case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break;
case Instruction::kAnd: result = (l && r) ? 1 : 0; break;
case Instruction::kOr: result = (l || r) ? 1 : 0; break;
default:
validOperation = false;
}
if (validOperation) {
if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd ||
binaryInst->getKind() == Instruction::kOr) {
newConst = ConstantInteger::get(Type::getIntType(), result);
} else {
newConst = ConstantInteger::get(result);
}
}
} else if (lhs->isFloat() && rhs->isFloat()) {
float l = lhsConst->getFloat();
float r = rhsConst->getFloat();
bool validOperation = true;
switch (binaryInst->getKind()) {
case Instruction::kFAdd: {
float result = l + r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFSub: {
float result = l - r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFMul: {
float result = l * r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
break;
}
case Instruction::kFDiv: {
if (std::abs(r) < std::numeric_limits<float>::epsilon()) {
validOperation = false;
} else {
float result = l / r;
if (std::isfinite(result)) {
newConst = ConstantFloating::get(result);
} else {
validOperation = false;
}
}
break;
}
case Instruction::kFCmpEQ:
newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0);
break;
case Instruction::kFCmpNE:
newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0);
break;
case Instruction::kFCmpLT:
newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0);
break;
case Instruction::kFCmpGT:
newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0);
break;
case Instruction::kFCmpLE:
newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0);
break;
case Instruction::kFCmpGE:
newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0);
break;
default:
validOperation = false;
}
}
} catch (...) {
// 捕获可能的异常,跳过优化
newConst = nullptr;
}
if (newConst) {
binaryInst->replaceAllUsesWith(newConst);
instIter = bb->getInstructions().erase(instIter);
shouldAdvanceIter = false;
localChanged = true;
}
}
}
// 处理一元运算指令
else if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst.get())) {
auto *operand = unaryInst->getOperand();
auto *operandConst = dynamic_cast<ConstantValue *>(operand);
if (operandConst) {
ConstantValue *newConst = nullptr;
if (operand->isInt()) {
int val = operandConst->getInt();
switch (unaryInst->getKind()) {
case Instruction::kNeg:
if (val != INT_MIN) { // 避免溢出
newConst = ConstantInteger::get(-val);
}
break;
case Instruction::kNot:
newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0);
break;
default:
break;
}
} else if (operand->isFloat()) {
float val = operandConst->getFloat();
switch (unaryInst->getKind()) {
case Instruction::kFNeg:
newConst = ConstantFloating::get(-val);
break;
default:
break;
}
}
if (newConst) {
unaryInst->replaceAllUsesWith(newConst);
instIter = bb->getInstructions().erase(instIter);
shouldAdvanceIter = false;
localChanged = true;
}
}
}
if (shouldAdvanceIter) {
++instIter;
}
}
}
if (localChanged) {
changed = true;
}
}
return changed;
}
} // namespace sysy