241 lines
12 KiB
C++
241 lines
12 KiB
C++
#include "Pass/Optimize/ConstPropagation.h"
|
|
#include "IR.h"
|
|
#include "Pass.h"
|
|
#include <climits>
|
|
#include <cmath>
|
|
|
|
namespace sysy {
|
|
|
|
char ConstPropagation::ID = 0;
|
|
|
|
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
|
|
bool changed = false;
|
|
bool localChanged = true;
|
|
|
|
while (localChanged) {
|
|
localChanged = false;
|
|
|
|
for (auto &bb : func->getBasicBlocks()) {
|
|
for (auto instIter = bb->getInstructions().begin();
|
|
instIter != bb->getInstructions().end();) {
|
|
auto &inst = *instIter;
|
|
bool shouldAdvanceIter = true;
|
|
|
|
// 处理二元运算指令
|
|
if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst.get())) {
|
|
auto *lhs = binaryInst->getLhs();
|
|
auto *rhs = binaryInst->getRhs();
|
|
|
|
auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
|
|
auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
|
|
|
|
if (lhsConst && rhsConst) {
|
|
ConstantValue *newConst = nullptr;
|
|
|
|
try {
|
|
if (lhs->isInt() && rhs->isInt()) {
|
|
int l = lhsConst->getInt();
|
|
int r = rhsConst->getInt();
|
|
int result;
|
|
bool validOperation = true;
|
|
|
|
switch (binaryInst->getKind()) {
|
|
case Instruction::kAdd:
|
|
// 检查加法溢出
|
|
if ((r > 0 && l > INT_MAX - r) || (r < 0 && l < INT_MIN - r)) {
|
|
validOperation = false;
|
|
} else {
|
|
result = l + r;
|
|
}
|
|
break;
|
|
case Instruction::kSub:
|
|
// 检查减法溢出
|
|
if ((r < 0 && l > INT_MAX + r) || (r > 0 && l < INT_MIN + r)) {
|
|
validOperation = false;
|
|
} else {
|
|
result = l - r;
|
|
}
|
|
break;
|
|
case Instruction::kMul:
|
|
// 检查乘法溢出
|
|
if (l != 0 && r != 0 &&
|
|
(std::abs(l) > INT_MAX / std::abs(r))) {
|
|
validOperation = false;
|
|
} else {
|
|
result = l * r;
|
|
}
|
|
break;
|
|
case Instruction::kDiv:
|
|
if (r == 0) {
|
|
validOperation = false;
|
|
} else {
|
|
result = l / r;
|
|
}
|
|
break;
|
|
case Instruction::kRem:
|
|
if (r == 0) {
|
|
validOperation = false;
|
|
} else {
|
|
result = l % r;
|
|
}
|
|
break;
|
|
case Instruction::kICmpEQ: result = (l == r) ? 1 : 0; break;
|
|
case Instruction::kICmpNE: result = (l != r) ? 1 : 0; break;
|
|
case Instruction::kICmpLT: result = (l < r) ? 1 : 0; break;
|
|
case Instruction::kICmpGT: result = (l > r) ? 1 : 0; break;
|
|
case Instruction::kICmpLE: result = (l <= r) ? 1 : 0; break;
|
|
case Instruction::kICmpGE: result = (l >= r) ? 1 : 0; break;
|
|
case Instruction::kAnd: result = (l && r) ? 1 : 0; break;
|
|
case Instruction::kOr: result = (l || r) ? 1 : 0; break;
|
|
default:
|
|
validOperation = false;
|
|
}
|
|
|
|
if (validOperation) {
|
|
if (binaryInst->isCmp() || binaryInst->getKind() == Instruction::kAnd ||
|
|
binaryInst->getKind() == Instruction::kOr) {
|
|
newConst = ConstantInteger::get(Type::getIntType(), result);
|
|
} else {
|
|
newConst = ConstantInteger::get(result);
|
|
}
|
|
}
|
|
} else if (lhs->isFloat() && rhs->isFloat()) {
|
|
float l = lhsConst->getFloat();
|
|
float r = rhsConst->getFloat();
|
|
bool validOperation = true;
|
|
|
|
switch (binaryInst->getKind()) {
|
|
case Instruction::kFAdd: {
|
|
float result = l + r;
|
|
if (std::isfinite(result)) {
|
|
newConst = ConstantFloating::get(result);
|
|
} else {
|
|
validOperation = false;
|
|
}
|
|
break;
|
|
}
|
|
case Instruction::kFSub: {
|
|
float result = l - r;
|
|
if (std::isfinite(result)) {
|
|
newConst = ConstantFloating::get(result);
|
|
} else {
|
|
validOperation = false;
|
|
}
|
|
break;
|
|
}
|
|
case Instruction::kFMul: {
|
|
float result = l * r;
|
|
if (std::isfinite(result)) {
|
|
newConst = ConstantFloating::get(result);
|
|
} else {
|
|
validOperation = false;
|
|
}
|
|
break;
|
|
}
|
|
case Instruction::kFDiv: {
|
|
if (std::abs(r) < std::numeric_limits<float>::epsilon()) {
|
|
validOperation = false;
|
|
} else {
|
|
float result = l / r;
|
|
if (std::isfinite(result)) {
|
|
newConst = ConstantFloating::get(result);
|
|
} else {
|
|
validOperation = false;
|
|
}
|
|
}
|
|
break;
|
|
}
|
|
case Instruction::kFCmpEQ:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l == r) ? 1 : 0);
|
|
break;
|
|
case Instruction::kFCmpNE:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l != r) ? 1 : 0);
|
|
break;
|
|
case Instruction::kFCmpLT:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l < r) ? 1 : 0);
|
|
break;
|
|
case Instruction::kFCmpGT:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l > r) ? 1 : 0);
|
|
break;
|
|
case Instruction::kFCmpLE:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l <= r) ? 1 : 0);
|
|
break;
|
|
case Instruction::kFCmpGE:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (l >= r) ? 1 : 0);
|
|
break;
|
|
default:
|
|
validOperation = false;
|
|
}
|
|
}
|
|
} catch (...) {
|
|
// 捕获可能的异常,跳过优化
|
|
newConst = nullptr;
|
|
}
|
|
|
|
if (newConst) {
|
|
binaryInst->replaceAllUsesWith(newConst);
|
|
instIter = bb->getInstructions().erase(instIter);
|
|
shouldAdvanceIter = false;
|
|
localChanged = true;
|
|
}
|
|
}
|
|
}
|
|
// 处理一元运算指令
|
|
else if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst.get())) {
|
|
auto *operand = unaryInst->getOperand();
|
|
auto *operandConst = dynamic_cast<ConstantValue *>(operand);
|
|
|
|
if (operandConst) {
|
|
ConstantValue *newConst = nullptr;
|
|
|
|
if (operand->isInt()) {
|
|
int val = operandConst->getInt();
|
|
|
|
switch (unaryInst->getKind()) {
|
|
case Instruction::kNeg:
|
|
if (val != INT_MIN) { // 避免溢出
|
|
newConst = ConstantInteger::get(-val);
|
|
}
|
|
break;
|
|
case Instruction::kNot:
|
|
newConst = ConstantInteger::get(Type::getIntType(), (!val) ? 1 : 0);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
} else if (operand->isFloat()) {
|
|
float val = operandConst->getFloat();
|
|
|
|
switch (unaryInst->getKind()) {
|
|
case Instruction::kFNeg:
|
|
newConst = ConstantFloating::get(-val);
|
|
break;
|
|
default:
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (newConst) {
|
|
unaryInst->replaceAllUsesWith(newConst);
|
|
instIter = bb->getInstructions().erase(instIter);
|
|
shouldAdvanceIter = false;
|
|
localChanged = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (shouldAdvanceIter) {
|
|
++instIter;
|
|
}
|
|
}
|
|
}
|
|
|
|
if (localChanged) {
|
|
changed = true;
|
|
}
|
|
}
|
|
|
|
return changed;
|
|
}
|
|
|
|
} // namespace sysy
|