#include "Pass/Optimize/ConstPropagation.h"
#include "IR.h"
#include "Pass.h"

#include <climits>
#include <cmath>
#include <cstdint>

namespace sysy {

namespace {

/// Stores `wide` into `out` when it is representable as `int`.
/// Returns false (leaving `out` untouched) when the value would overflow.
bool narrowToInt(std::int64_t wide, int &out) {
  if (wide < INT_MIN || wide > INT_MAX) {
    return false;
  }
  out = static_cast<int>(wide);
  return true;
}

/// Materialises the 0/1 `int` constant produced by comparisons and logical ops.
ConstantValue *makeTruthValue(bool truth) {
  return ConstantInteger::get(Type::getIntType(), truth ? 1 : 0);
}

/// Returns a float constant for `value`, or nullptr when it is inf/NaN
/// (such results are left for the runtime to compute).
ConstantValue *finiteFloatOrNull(float value) {
  if (!std::isfinite(value)) {
    return nullptr;
  }
  return ConstantFloating::get(value);
}

/// Folds an integer binary instruction whose operands are the constants
/// `l` and `r`. Returns nullptr when the operation is unsupported or its
/// result would be undefined (overflow, division by zero).
ConstantValue *foldIntBinary(BinaryInst *inst, int l, int r) {
  // Arithmetic is evaluated in 64 bits so overflow is detected without ever
  // executing an overflowing (undefined-behavior) `int` operation.
  const std::int64_t wideL = l;
  const std::int64_t wideR = r;
  int result = 0;
  bool ok = false;

  switch (inst->getKind()) {
  case Instruction::kAdd:
    ok = narrowToInt(wideL + wideR, result);
    break;
  case Instruction::kSub:
    ok = narrowToInt(wideL - wideR, result);
    break;
  case Instruction::kMul:
    // The product of two 32-bit values always fits in 64 bits.
    ok = narrowToInt(wideL * wideR, result);
    break;
  case Instruction::kDiv:
    // Division by zero is left for the runtime; INT_MIN / -1 overflows `int`
    // (UB, and a SIGFPE trap on x86) and is rejected by narrowToInt.
    ok = r != 0 && narrowToInt(wideL / wideR, result);
    break;
  case Instruction::kRem:
    // INT_MIN % -1 is undefined in C++ because the matching quotient
    // overflows, so it is treated like division by zero.
    ok = r != 0 && !(l == INT_MIN && r == -1);
    if (ok) {
      result = l % r;
    }
    break;
  case Instruction::kICmpEQ:
    return makeTruthValue(l == r);
  case Instruction::kICmpNE:
    return makeTruthValue(l != r);
  case Instruction::kICmpLT:
    return makeTruthValue(l < r);
  case Instruction::kICmpGT:
    return makeTruthValue(l > r);
  case Instruction::kICmpLE:
    return makeTruthValue(l <= r);
  case Instruction::kICmpGE:
    return makeTruthValue(l >= r);
  case Instruction::kAnd:
    return makeTruthValue(l != 0 && r != 0);
  case Instruction::kOr:
    return makeTruthValue(l != 0 || r != 0);
  default:
    return nullptr;
  }

  if (!ok) {
    return nullptr;
  }
  return ConstantInteger::get(result);
}

/// Folds a float binary instruction whose operands are the constants `l`
/// and `r`. Returns nullptr when the operation is unsupported or the result
/// is not finite.
ConstantValue *foldFloatBinary(BinaryInst *inst, float l, float r) {
  switch (inst->getKind()) {
  case Instruction::kFAdd:
    return finiteFloatOrNull(l + r);
  case Instruction::kFSub:
    return finiteFloatOrNull(l - r);
  case Instruction::kFMul:
    return finiteFloatOrNull(l * r);
  case Instruction::kFDiv:
    // Only an exact zero divisor is a problem (it also catches -0.0f); tiny
    // non-zero divisors are folded whenever the quotient stays finite.
    if (r == 0.0f) {
      return nullptr;
    }
    return finiteFloatOrNull(l / r);
  case Instruction::kFCmpEQ:
    return makeTruthValue(l == r);
  case Instruction::kFCmpNE:
    return makeTruthValue(l != r);
  case Instruction::kFCmpLT:
    return makeTruthValue(l < r);
  case Instruction::kFCmpGT:
    return makeTruthValue(l > r);
  case Instruction::kFCmpLE:
    return makeTruthValue(l <= r);
  case Instruction::kFCmpGE:
    return makeTruthValue(l >= r);
  default:
    return nullptr;
  }
}

/// Folds a binary instruction when both operands are constants of the same
/// scalar kind (both int or both float); returns nullptr otherwise.
ConstantValue *foldBinary(BinaryInst *inst) {
  auto *lhs = inst->getLhs();
  auto *rhs = inst->getRhs();
  auto *lhsConst = dynamic_cast<ConstantValue *>(lhs);
  auto *rhsConst = dynamic_cast<ConstantValue *>(rhs);
  if (lhsConst == nullptr || rhsConst == nullptr) {
    return nullptr;
  }
  if (lhs->isInt() && rhs->isInt()) {
    return foldIntBinary(inst, lhsConst->getInt(), rhsConst->getInt());
  }
  if (lhs->isFloat() && rhs->isFloat()) {
    return foldFloatBinary(inst, lhsConst->getFloat(), rhsConst->getFloat());
  }
  return nullptr;
}

/// Folds a unary instruction with a constant operand; returns nullptr when
/// the operand is not constant, the kind is unsupported, or the result
/// would overflow.
ConstantValue *foldUnary(UnaryInst *inst) {
  auto *operand = inst->getOperand();
  auto *operandConst = dynamic_cast<ConstantValue *>(operand);
  if (operandConst == nullptr) {
    return nullptr;
  }

  if (operand->isInt()) {
    const int val = operandConst->getInt();
    switch (inst->getKind()) {
    case Instruction::kNeg:
      // -INT_MIN is not representable as int; leave it for the runtime.
      if (val == INT_MIN) {
        return nullptr;
      }
      return ConstantInteger::get(-val);
    case Instruction::kNot:
      return makeTruthValue(val == 0);
    default:
      return nullptr;
    }
  }

  if (operand->isFloat()) {
    const float val = operandConst->getFloat();
    if (inst->getKind() == Instruction::kFNeg) {
      return ConstantFloating::get(-val);
    }
  }
  return nullptr;
}

/// Returns the constant that `inst` evaluates to, or nullptr when it cannot
/// (or must not) be folded at compile time.
ConstantValue *tryFold(Instruction *inst) {
  if (auto *binaryInst = dynamic_cast<BinaryInst *>(inst)) {
    return foldBinary(binaryInst);
  }
  if (auto *unaryInst = dynamic_cast<UnaryInst *>(inst)) {
    return foldUnary(unaryInst);
  }
  return nullptr;
}

} // namespace

char ConstPropagation::ID = 0;

/// Repeatedly folds binary/unary instructions whose operands are constants,
/// replacing every use with the folded constant and erasing the instruction,
/// until a full sweep over `func` makes no further change.
///
/// @param func  Function to optimise in place.
/// @param am    Analysis manager (unused by this pass).
/// @return true if at least one instruction was folded.
bool ConstPropagation::runOnFunction(Function *func, AnalysisManager &am) {
  bool changed = false;
  bool progress = true;
  while (progress) {
    progress = false;
    for (auto &bb : func->getBasicBlocks()) {
      auto &insts = bb->getInstructions();
      for (auto instIter = insts.begin(); instIter != insts.end();) {
        Instruction *inst = instIter->get();
        ConstantValue *folded = tryFold(inst);
        if (folded == nullptr) {
          ++instIter;
          continue;
        }
        inst->replaceAllUsesWith(folded);
        // erase() returns the iterator to the next instruction.
        instIter = insts.erase(instIter);
        progress = true;
      }
    }
    if (progress) {
      changed = true;
    }
  }
  return changed;
}

} // namespace sysy