add algebraic IR simplification

This commit is contained in:
2026-06-30 00:31:17 +08:00
parent 108f3d9e4b
commit 6f943b395f
4 changed files with 115 additions and 0 deletions

View File

@@ -36,6 +36,7 @@ class DominatorTree {
bool RunMem2Reg(Function* func, Context& ctx); bool RunMem2Reg(Function* func, Context& ctx);
bool RunConstProp(Function* func, Context& ctx); bool RunConstProp(Function* func, Context& ctx);
bool RunConstFold(Function* func, Context& ctx); bool RunConstFold(Function* func, Context& ctx);
bool RunAlgebraicSimplify(Function* func, Context& ctx);
bool RunDCE(Function* func); bool RunDCE(Function* func);
bool RunCFGSimplify(Function* func); bool RunCFGSimplify(Function* func);
bool RunCSE(Function* func); bool RunCSE(Function* func);

View File

@@ -0,0 +1,112 @@
#include "ir/PassManager.h"
#include <vector>
namespace ir {
namespace {
bool IsConstInt(Value* value, int expected) {
auto* constant = dynamic_cast<ConstantInt*>(value);
return constant && constant->GetValue() == expected;
}
bool IsConstFloat(Value* value, float expected) {
auto* constant = dynamic_cast<ConstantFloat*>(value);
return constant && constant->GetValue() == expected;
}
bool IsSameValue(Value* lhs, Value* rhs) {
return lhs == rhs;
}
Value* SimplifyBinary(BinaryInst* bin, Context& ctx) {
auto* lhs = bin->GetLhs();
auto* rhs = bin->GetRhs();
switch (bin->GetOpcode()) {
case Opcode::Add:
if (IsConstInt(rhs, 0)) return lhs;
if (IsConstInt(lhs, 0)) return rhs;
break;
case Opcode::Sub:
if (IsConstInt(rhs, 0)) return lhs;
if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0);
break;
case Opcode::Mul:
if (IsConstInt(rhs, 1)) return lhs;
if (IsConstInt(lhs, 1)) return rhs;
if (IsConstInt(rhs, 0) || IsConstInt(lhs, 0)) return ctx.GetConstInt(0);
break;
case Opcode::Div:
if (IsConstInt(rhs, 1)) return lhs;
if (IsConstInt(lhs, 0)) return ctx.GetConstInt(0);
break;
case Opcode::Mod:
if (IsConstInt(rhs, 1) || IsConstInt(lhs, 0)) return ctx.GetConstInt(0);
break;
case Opcode::FAdd:
if (IsConstFloat(rhs, 0.0f)) return lhs;
if (IsConstFloat(lhs, 0.0f)) return rhs;
break;
case Opcode::FSub:
if (IsConstFloat(rhs, 0.0f)) return lhs;
break;
case Opcode::FMul:
if (IsConstFloat(rhs, 1.0f)) return lhs;
if (IsConstFloat(lhs, 1.0f)) return rhs;
if (IsConstFloat(rhs, 0.0f) || IsConstFloat(lhs, 0.0f)) {
return ctx.GetConstFloat(0.0f);
}
break;
case Opcode::FDiv:
if (IsConstFloat(rhs, 1.0f)) return lhs;
break;
case Opcode::ICmpEQ:
if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(1);
break;
case Opcode::ICmpNE:
if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0);
break;
case Opcode::ICmpLE:
case Opcode::ICmpGE:
if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(1);
break;
case Opcode::ICmpLT:
case Opcode::ICmpGT:
if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0);
break;
default:
break;
}
return nullptr;
}
} // namespace
bool RunAlgebraicSimplify(Function* func, Context& ctx) {
bool changed = false;
std::vector<Instruction*> to_erase;
for (const auto& bbPtr : func->GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
auto* bin = dynamic_cast<BinaryInst*>(instPtr.get());
if (!bin) {
continue;
}
if (auto* replacement = SimplifyBinary(bin, ctx)) {
bin->ReplaceAllUsesWith(replacement);
to_erase.push_back(bin);
changed = true;
}
}
}
for (auto* inst : to_erase) {
inst->GetParent()->EraseInstruction(inst);
}
return changed;
}
} // namespace ir

View File

@@ -3,6 +3,7 @@ add_library(ir_passes STATIC
Mem2Reg.cpp Mem2Reg.cpp
ConstFold.cpp ConstFold.cpp
ConstProp.cpp ConstProp.cpp
AlgebraicSimplify.cpp
CSE.cpp CSE.cpp
DCE.cpp DCE.cpp
CFGSimplify.cpp CFGSimplify.cpp

View File

@@ -16,6 +16,7 @@ void RunFunctionOptimizationPasses(Function* func, Context& ctx) {
changed |= RunConstProp(func, ctx); changed |= RunConstProp(func, ctx);
changed |= RunConstFold(func, ctx); changed |= RunConstFold(func, ctx);
changed |= RunAlgebraicSimplify(func, ctx);
changed |= RunCSE(func); changed |= RunCSE(func);
changed |= RunLICM(func); changed |= RunLICM(func);
changed |= RunDCE(func); changed |= RunDCE(func);