diff --git a/src/IR.cpp b/src/IR.cpp index 5d2765d..5f4e0c5 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -102,30 +102,54 @@ void Value::replaceAllUsesWith(Value *value) { uses.clear(); } -ConstantValue* ConstantValue::get(int value) { - static std::map> intConstants; - auto iter = intConstants.find(value); - if (iter != intConstants.end()) { - return iter->second.get(); + +// Implementations for static members + +std::unordered_map ConstantValue::mConstantPool; +std::unordered_map UndefinedValue::UndefValues; + +ConstantValue* ConstantValue::get(Type* type, ConstantValVariant val) { + ConstantValueKey key = {type, val}; + auto it = mConstantPool.find(key); + if (it != mConstantPool.end()) { + return it->second; } - auto inst = new ConstantValue(value); - assert(inst); - auto result = intConstants.emplace(value, inst); - return result.first->second.get(); + + ConstantValue* newConstant = nullptr; + if (std::holds_alternative(val)) { + newConstant = new ConstantInteger(type, std::get(val)); + } else if (std::holds_alternative(val)) { + newConstant = new ConstantFloating(type, std::get(val)); + } else { + assert(false && "Unsupported ConstantValVariant type"); + } + + mConstantPool[key] = newConstant; + return newConstant; } -ConstantValue* ConstantValue::get(float value) { - static std::map> floatConstants; - auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) { - return iter->second.get(); - } - auto inst = new ConstantValue(value); - assert(inst); - auto result = floatConstants.emplace(value, inst); - return result.first->second.get(); +ConstantInteger* ConstantInteger::get(Type* type, int val) { + return dynamic_cast(ConstantValue::get(type, val)); } +ConstantFloating* ConstantFloating::get(Type* type, float val) { + return dynamic_cast(ConstantValue::get(type, val)); +} + +UndefinedValue* UndefinedValue::get(Type* type) { + assert(!type->isVoid() && "Cannot get UndefinedValue of void type!"); + + auto it = UndefValues.find(type); + if (it != UndefValues.end()) { + return it->second; + } + + UndefinedValue* newUndef = new UndefinedValue(type); + UndefValues[type] = newUndef; + return newUndef; +} + + auto Function::getCalleesWithNoExternalAndSelf() -> std::set { std::set result; for (auto callee : callees) { diff --git a/src/include/IR.h b/src/include/IR.h index b23689a..060bdc5 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -271,7 +271,7 @@ class ValueCounter { // --- Refactored ConstantValue and related classes start here --- -using ConstantValVariant = std::variant; +using ConstantValVariant = std::variant; // Helper for hashing std::variant struct VariantHash { @@ -320,6 +320,10 @@ struct ConstantValueEqual { * `Value`s. It's type is either `int` or `float`. * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 */ + +template struct always_false : std::false_type {}; +template constexpr bool always_false_v = always_false::value; + class ConstantValue : public Value { protected: static std::unordered_map mConstantPool; @@ -329,7 +333,7 @@ public: virtual ~ConstantValue() = default; virtual size_t hash() const = 0; - virtual ConstantValVariant getValue() const = 0; + virtual ConstantValVariant getVal() const = 0; // Static factory method to get a canonical ConstantValue from the pool static ConstantValue* get(Type* type, ConstantValVariant val); @@ -337,23 +341,23 @@ public: // Helper methods to access constant values with appropriate casting int getInt() const { assert(getType()->isInt() && "Calling getInt() on non-integer type"); - return std::get(getValue()); + return std::get(getVal()); } float getFloat() const { assert(getType()->isFloat() && "Calling getFloat() on non-float type"); - return std::get(getValue()); + return std::get(getVal()); } template - T getValue() const { - if constexpr (std::is_same_v) { - return getInt(); - } else if constexpr (std::is_same_v) { - return getFloat(); - } else { - // This ensures a compilation error if an unsupported type is used - static_assert(std::always_false_v, "Unsupported type for ConstantValue::getValue()"); - } + T getVal() const { + if constexpr (std::is_same_v) { + return getInt(); + } else if constexpr (std::is_same_v) { + return getFloat(); + } else { + // This ensures a compilation error if an unsupported type is used + static_assert(always_false_v, "Unsupported type for ConstantValue::getValue()"); + } } virtual bool isZero() const = 0; @@ -372,7 +376,7 @@ public: return typeHash ^ (valHash << 1); } int getInt() const { return constVal; } - ConstantValVariant getValue() const override { return constVal; } + ConstantValVariant getVal() const override { return constVal; } static ConstantInteger* get(Type* type, int val); static ConstantInteger* get(int val) { return get(Type::getIntType(), val); } @@ -398,7 +402,7 @@ public: return typeHash ^ (valHash << 1); } float getFloat() const { return constFVal; } - ConstantValVariant getValue() const override { return constFVal; } + ConstantValVariant getVal() const override { return constFVal; } static ConstantFloating* get(Type* type, float val); static ConstantFloating* get(float val) { return get(Type::getFloatType(), val); } @@ -429,7 +433,7 @@ public: return std::hash{}(getType()); } - ConstantValVariant getValue() const override { + ConstantValVariant getVal() const override { if (getType()->isInt()) { return 0; // Return 0 for undefined integer } else if (getType()->isFloat()) { @@ -443,52 +447,6 @@ public: bool isOne() const override { return false; } }; -// Implementations for static members (typically in .cpp, but for single-file, put here) - -std::unordered_map ConstantValue::mConstantPool; -std::unordered_map UndefinedValue::UndefValues; - -ConstantValue* ConstantValue::get(Type* type, ConstantValVariant val) { - ConstantValueKey key = {type, val}; - auto it = mConstantPool.find(key); - if (it != mConstantPool.end()) { - return it->second; - } - - ConstantValue* newConstant = nullptr; - if (std::holds_alternative(val)) { - newConstant = new ConstantInteger(type, std::get(val)); - } else if (std::holds_alternative(val)) { - newConstant = new ConstantFloating(type, std::get(val)); - } else { - assert(false && "Unsupported ConstantValVariant type"); - } - - mConstantPool[key] = newConstant; - return newConstant; -} - -ConstantInteger* ConstantInteger::get(Type* type, int val) { - return dynamic_cast(ConstantValue::get(type, val)); -} - -ConstantFloating* ConstantFloating::get(Type* type, float val) { - return dynamic_cast(ConstantValue::get(type, val)); -} - -UndefinedValue* UndefinedValue::get(Type* type) { - assert(!type->isVoid() && "Cannot get UndefinedValue of void type!"); - - auto it = UndefValues.find(type); - if (it != UndefValues.end()) { - return it->second; - } - - UndefinedValue* newUndef = new UndefinedValue(type); - UndefValues[type] = newUndef; - return newUndef; -} - // --- End of refactored ConstantValue and related classes --- @@ -941,7 +899,7 @@ class PhiInst : public Instruction { PhiInst(Type *type, const std::vector &rhs = {}, const std::vector &Blocks = {}, - BasicBlock *parent, + BasicBlock *parent = nullptr, const std::string &name = "") : Instruction(Kind::kPhi, type, parent, name), vsize(rhs.size()) { assert(rhs.size() == Blocks.size() && "PhiInst: rhs and Blocks must have the same size"); @@ -977,7 +935,6 @@ class PhiInst : public Instruction { void refreshB2VMap(); auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } - Value* getValue(unsigned index) const { return getOperand(index + 1); } };