diff --git a/include/ir/PassManager.h b/include/ir/PassManager.h index 062c197..a8e565d 100644 --- a/include/ir/PassManager.h +++ b/include/ir/PassManager.h @@ -36,6 +36,7 @@ class DominatorTree { bool RunMem2Reg(Function* func, Context& ctx); bool RunConstProp(Function* func, Context& ctx); bool RunConstFold(Function* func, Context& ctx); +bool RunAlgebraicSimplify(Function* func, Context& ctx); bool RunDCE(Function* func); bool RunCFGSimplify(Function* func); bool RunCSE(Function* func); diff --git a/src/ir/passes/AlgebraicSimplify.cpp b/src/ir/passes/AlgebraicSimplify.cpp new file mode 100644 index 0000000..2e3fa23 --- /dev/null +++ b/src/ir/passes/AlgebraicSimplify.cpp @@ -0,0 +1,112 @@ +#include "ir/PassManager.h" + +#include + +namespace ir { +namespace { + +bool IsConstInt(Value* value, int expected) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == expected; +} + +bool IsConstFloat(Value* value, float expected) { + auto* constant = dynamic_cast(value); + return constant && constant->GetValue() == expected; +} + +bool IsSameValue(Value* lhs, Value* rhs) { + return lhs == rhs; +} + +Value* SimplifyBinary(BinaryInst* bin, Context& ctx) { + auto* lhs = bin->GetLhs(); + auto* rhs = bin->GetRhs(); + + switch (bin->GetOpcode()) { + case Opcode::Add: + if (IsConstInt(rhs, 0)) return lhs; + if (IsConstInt(lhs, 0)) return rhs; + break; + case Opcode::Sub: + if (IsConstInt(rhs, 0)) return lhs; + if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0); + break; + case Opcode::Mul: + if (IsConstInt(rhs, 1)) return lhs; + if (IsConstInt(lhs, 1)) return rhs; + if (IsConstInt(rhs, 0) || IsConstInt(lhs, 0)) return ctx.GetConstInt(0); + break; + case Opcode::Div: + if (IsConstInt(rhs, 1)) return lhs; + if (IsConstInt(lhs, 0)) return ctx.GetConstInt(0); + break; + case Opcode::Mod: + if (IsConstInt(rhs, 1) || IsConstInt(lhs, 0)) return ctx.GetConstInt(0); + break; + case Opcode::FAdd: + if (IsConstFloat(rhs, 0.0f)) return lhs; + if (IsConstFloat(lhs, 0.0f)) return rhs; + break; + case Opcode::FSub: + if (IsConstFloat(rhs, 0.0f)) return lhs; + break; + case Opcode::FMul: + if (IsConstFloat(rhs, 1.0f)) return lhs; + if (IsConstFloat(lhs, 1.0f)) return rhs; + if (IsConstFloat(rhs, 0.0f) || IsConstFloat(lhs, 0.0f)) { + return ctx.GetConstFloat(0.0f); + } + break; + case Opcode::FDiv: + if (IsConstFloat(rhs, 1.0f)) return lhs; + break; + case Opcode::ICmpEQ: + if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(1); + break; + case Opcode::ICmpNE: + if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0); + break; + case Opcode::ICmpLE: + case Opcode::ICmpGE: + if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(1); + break; + case Opcode::ICmpLT: + case Opcode::ICmpGT: + if (IsSameValue(lhs, rhs)) return ctx.GetConstInt(0); + break; + default: + break; + } + + return nullptr; +} + +} // namespace + +bool RunAlgebraicSimplify(Function* func, Context& ctx) { + bool changed = false; + std::vector to_erase; + + for (const auto& bbPtr : func->GetBlocks()) { + for (const auto& instPtr : bbPtr->GetInstructions()) { + auto* bin = dynamic_cast(instPtr.get()); + if (!bin) { + continue; + } + if (auto* replacement = SimplifyBinary(bin, ctx)) { + bin->ReplaceAllUsesWith(replacement); + to_erase.push_back(bin); + changed = true; + } + } + } + + for (auto* inst : to_erase) { + inst->GetParent()->EraseInstruction(inst); + } + + return changed; +} + +} // namespace ir diff --git a/src/ir/passes/CMakeLists.txt b/src/ir/passes/CMakeLists.txt index d3ece9d..0a533d6 100644 --- a/src/ir/passes/CMakeLists.txt +++ b/src/ir/passes/CMakeLists.txt @@ -3,6 +3,7 @@ add_library(ir_passes STATIC Mem2Reg.cpp ConstFold.cpp ConstProp.cpp + AlgebraicSimplify.cpp CSE.cpp DCE.cpp CFGSimplify.cpp diff --git a/src/ir/passes/PassManager.cpp b/src/ir/passes/PassManager.cpp index d4d46fb..8f3009f 100644 --- a/src/ir/passes/PassManager.cpp +++ b/src/ir/passes/PassManager.cpp @@ -16,6 +16,7 @@ void RunFunctionOptimizationPasses(Function* func, Context& ctx) { changed |= RunConstProp(func, ctx); changed |= RunConstFold(func, ctx); + changed |= RunAlgebraicSimplify(func, ctx); changed |= RunCSE(func); changed |= RunLICM(func); changed |= RunDCE(func);