diff --git a/src/IR.cpp b/src/IR.cpp index 540f974..5f4e0c5 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -102,30 +102,54 @@ void Value::replaceAllUsesWith(Value *value) { uses.clear(); } -ConstantValue* ConstantValue::get(int value) { - static std::map> intConstants; - auto iter = intConstants.find(value); - if (iter != intConstants.end()) { - return iter->second.get(); + +// Implementations for static members + +std::unordered_map ConstantValue::mConstantPool; +std::unordered_map UndefinedValue::UndefValues; + +ConstantValue* ConstantValue::get(Type* type, ConstantValVariant val) { + ConstantValueKey key = {type, val}; + auto it = mConstantPool.find(key); + if (it != mConstantPool.end()) { + return it->second; } - auto inst = new ConstantValue(value); - assert(inst); - auto result = intConstants.emplace(value, inst); - return result.first->second.get(); + + ConstantValue* newConstant = nullptr; + if (std::holds_alternative(val)) { + newConstant = new ConstantInteger(type, std::get(val)); + } else if (std::holds_alternative(val)) { + newConstant = new ConstantFloating(type, std::get(val)); + } else { + assert(false && "Unsupported ConstantValVariant type"); + } + + mConstantPool[key] = newConstant; + return newConstant; } -ConstantValue* ConstantValue::get(float value) { - static std::map> floatConstants; - auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) { - return iter->second.get(); - } - auto inst = new ConstantValue(value); - assert(inst); - auto result = floatConstants.emplace(value, inst); - return result.first->second.get(); +ConstantInteger* ConstantInteger::get(Type* type, int val) { + return dynamic_cast(ConstantValue::get(type, val)); } +ConstantFloating* ConstantFloating::get(Type* type, float val) { + return dynamic_cast(ConstantValue::get(type, val)); +} + +UndefinedValue* UndefinedValue::get(Type* type) { + assert(!type->isVoid() && "Cannot get UndefinedValue of void type!"); + + auto it = UndefValues.find(type); + if (it != UndefValues.end()) { + return it->second; + } + + UndefinedValue* newUndef = new UndefinedValue(type); + UndefValues[type] = newUndef; + return newUndef; +} + + auto Function::getCalleesWithNoExternalAndSelf() -> std::set { std::set result; for (auto callee : callees) { @@ -545,6 +569,83 @@ void User::replaceOperand(unsigned index, Value *value) { value->addUse(use); } +/** + * phi相关函数 + */ + + Value* PhiInst::getvalfromBlk(BasicBlock* blk){ + refreshB2VMap(); + if( blk2val.find(blk) != blk2val.end()) { + return blk2val.at(blk); + } + return nullptr; +} + +BasicBlock* PhiInst::getBlkfromVal(Value* val){ + // 返回第一个值对应的基本块 + for(unsigned i = 0; i < vsize; i++) { + if(getValue(i) == val) { + return getBlock(i); + } + } + return nullptr; +} + +void PhiInst::delValue(Value* val){ + //根据value删除对应的基本块和值 + unsigned i = 0; + BasicBlock* blk = getBlkfromVal(val); + for(i = 0; i < vsize; i++) { + if(getValue(i) == val) { + break; + } + } + removeOperand(2 * i + 1); // 删除blk + removeOperand(2 * i); // 删除val + vsize--; + blk2val.erase(blk); // 删除blk2val映射 +} + +void PhiInst::delBlk(BasicBlock* blk){ + //根据Blk删除对应的基本块和值 + unsigned i = 0; + Value* val = getvalfromBlk(blk); + for(i = 0; i < vsize; i++) { + if(getBlock(i) == blk) { + break; + } + } + removeOperand(2 * i + 1); // 删除blk + removeOperand(2 * i); // 删除val + vsize--; + blk2val.erase(blk); // 删除blk2val映射 +} + +void PhiInst::replaceBlk(BasicBlock* newBlk, unsigned k){ + refreshB2VMap(); + Value* val = blk2val.at(getBlock(k)); + // 替换基本块 + setOperand(2 * k + 1, newBlk); + // 替换blk2val映射 + blk2val.erase(getBlock(k)); + blk2val.emplace(newBlk, val); +} + +void PhiInst::replaceold2new(BasicBlock* oldBlk, BasicBlock* newBlk){ + refreshB2VMap(); + Value* val = blk2val.at(oldBlk); + // 替换基本块 + delBlk(oldBlk); + addIncoming(val, newBlk); +} + +void PhiInst::refreshB2VMap(){ + blk2val.clear(); + for(unsigned i = 0; i < vsize; i++) { + blk2val.emplace(getBlock(i), getValue(i)); + } +} + CallInst::CallInst(Function *callee, const std::vector &args, BasicBlock *parent, const std::string &name) : Instruction(kCall, callee->getReturnType(), parent, name) { addOperand(callee); diff --git a/src/include/IR.h b/src/include/IR.h index 1b4c702..060bdc5 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -268,6 +268,51 @@ class ValueCounter { } ///< 清空ValueCounter }; + +// --- Refactored ConstantValue and related classes start here --- + +using ConstantValVariant = std::variant; + +// Helper for hashing std::variant +struct VariantHash { + template + std::size_t operator()(const T& val) const { + return std::hash{}(val); + } + std::size_t operator()(const ConstantValVariant& v) const { + return std::visit(*this, v); + } +}; + +struct ConstantValueKey { + Type* type; + ConstantValVariant val; + + bool operator==(const ConstantValueKey& other) const { + // Assuming Type objects are canonicalized, or add Type::isSame() + // If Type::isSame() is not available and Type objects are not canonicalized, + // this comparison might not be robust enough for structural equivalence of types. + return type == other.type && val == other.val; + } +}; + +struct ConstantValueHash { + std::size_t operator()(const ConstantValueKey& key) const { + std::size_t typeHash = std::hash{}(key.type); + std::size_t valHash = VariantHash{}(key.val); + // A simple way to combine hashes + return typeHash ^ (valHash << 1); + } +}; + +struct ConstantValueEqual { + bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const { + // Assuming Type objects are canonicalized (e.g., Type::getIntType() always returns same pointer) + // If not, and Type::isSame() is intended, it should be added to Type class. + return lhs.type == rhs.type && lhs.val == rhs.val; + } +}; + /*! * Static constants known at compile time. * @@ -276,45 +321,135 @@ class ValueCounter { * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 */ +template struct always_false : std::false_type {}; +template constexpr bool always_false_v = always_false::value; class ConstantValue : public Value { - protected: - /// 定义字面量类型的聚合类型 - union { - int iScalar; - float fScalar; - }; +protected: + static std::unordered_map mConstantPool; - protected: - explicit ConstantValue(int value, const std::string &name = "") : Value(Type::getIntType(), name), iScalar(value) {} - explicit ConstantValue(float value, const std::string &name = "") - : Value(Type::getFloatType(), name), fScalar(value) {} +public: + explicit ConstantValue(Type* type, const std::string& name = "") : Value(type, name) {} + virtual ~ConstantValue() = default; - public: - static ConstantValue* get(int value); ///< 获取一个int类型的ConstValue *,其值为value - static ConstantValue* get(float value); ///< 获取一个float类型的ConstValue *,其值为value + virtual size_t hash() const = 0; + virtual ConstantValVariant getVal() const = 0; - public: + // Static factory method to get a canonical ConstantValue from the pool + static ConstantValue* get(Type* type, ConstantValVariant val); + + // Helper methods to access constant values with appropriate casting int getInt() const { - assert(isInt()); - return iScalar; - } ///< 返回int类型的值 + assert(getType()->isInt() && "Calling getInt() on non-integer type"); + return std::get(getVal()); + } float getFloat() const { - assert(isFloat()); - return fScalar; - } ///< 返回float类型的值 - template - T getValue() const { - if (std::is_same::value && isInt()) { + assert(getType()->isFloat() && "Calling getFloat() on non-float type"); + return std::get(getVal()); + } + + template + T getVal() const { + if constexpr (std::is_same_v) { return getInt(); - } - if (std::is_same::value && isFloat()) { + } else if constexpr (std::is_same_v) { return getFloat(); + } else { + // This ensures a compilation error if an unsupported type is used + static_assert(always_false_v, "Unsupported type for ConstantValue::getValue()"); } - throw std::bad_cast(); // 或者其他适当的异常处理 - } ///< 返回值,getInt和getFloat统一化,整数返回整形,浮点返回浮点型 + } + + virtual bool isZero() const = 0; + virtual bool isOne() const = 0; }; +class ConstantInteger : public ConstantValue { + int constVal; +public: + explicit ConstantInteger(Type* type, int val, const std::string& name = "") + : ConstantValue(type, name), constVal(val) {} + + size_t hash() const override { + std::size_t typeHash = std::hash{}(getType()); + std::size_t valHash = std::hash{}(constVal); + return typeHash ^ (valHash << 1); + } + int getInt() const { return constVal; } + ConstantValVariant getVal() const override { return constVal; } + + static ConstantInteger* get(Type* type, int val); + static ConstantInteger* get(int val) { return get(Type::getIntType(), val); } + + ConstantInteger* getNeg() const { + assert(getType()->isInt() && "Cannot negate non-integer constant"); + return ConstantInteger::get(-constVal); + } + + bool isZero() const override { return constVal == 0; } + bool isOne() const override { return constVal == 1; } +}; + +class ConstantFloating : public ConstantValue { + float constFVal; +public: + explicit ConstantFloating(Type* type, float val, const std::string& name = "") + : ConstantValue(type, name), constFVal(val) {} + + size_t hash() const override { + std::size_t typeHash = std::hash{}(getType()); + std::size_t valHash = std::hash{}(constFVal); + return typeHash ^ (valHash << 1); + } + float getFloat() const { return constFVal; } + ConstantValVariant getVal() const override { return constFVal; } + + static ConstantFloating* get(Type* type, float val); + static ConstantFloating* get(float val) { return get(Type::getFloatType(), val); } + + ConstantFloating* getNeg() const { + assert(getType()->isFloat() && "Cannot negate non-float constant"); + return ConstantFloating::get(-constFVal); + } + + bool isZero() const override { return constFVal == 0.0f; } + bool isOne() const override { return constFVal == 1.0f; } +}; + +class UndefinedValue : public ConstantValue { +private: + static std::unordered_map UndefValues; + +protected: + explicit UndefinedValue(Type* type, const std::string& name = "") + : ConstantValue(type, name) { + assert(!type->isVoid() && "Cannot create UndefinedValue of void type!"); + } + +public: + static UndefinedValue* get(Type* type); + + size_t hash() const override { + return std::hash{}(getType()); + } + + ConstantValVariant getVal() const override { + if (getType()->isInt()) { + return 0; // Return 0 for undefined integer + } else if (getType()->isFloat()) { + return 0.0f; // Return 0.0f for undefined float + } + assert(false && "UndefinedValue has unexpected type for getValue()"); + return 0; // Should not be reached + } + + bool isZero() const override { return false; } + bool isOne() const override { return false; } +}; + +// --- End of refactored ConstantValue and related classes --- + + class Instruction; class Function; class BasicBlock; @@ -562,8 +697,8 @@ class Instruction : public User { kLa = 0x1UL << 36, kMemset = 0x1UL << 37, kGetSubArray = 0x1UL << 38, - // constant - kConstant = 0x1UL << 37, + // Constant Kind removed as Constants are now Values, not Instructions. + // kConstant = 0x1UL << 37, // Conflicts with kMemset if kept as is // phi kPhi = 0x1UL << 39, kBitItoF = 0x1UL << 40, @@ -755,24 +890,51 @@ class LaInst : public Instruction { class PhiInst : public Instruction { friend class IRBuilder; friend class Function; - friend class SysySSA; protected: - Value *map_val; // Phi的旧值 - PhiInst(Type *type, Value *lhs, const std::vector &rhs, Value *mval, BasicBlock *parent, + std::unordered_map blk2val; ///< 存储每个基本块对应的值 + unsigned vsize; ///< 存储值的数量 + + PhiInst(Type *type, + const std::vector &rhs = {}, + const std::vector &Blocks = {}, + BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(Kind::kPhi, type, parent, name) { - map_val = mval; - addOperand(lhs); - addOperands(rhs); + : Instruction(Kind::kPhi, type, parent, name), vsize(rhs.size()) { + assert(rhs.size() == Blocks.size() && "PhiInst: rhs and Blocks must have the same size"); + for(size_t i = 0; i < rhs.size(); ++i) { + addOperand(rhs[i]); + blk2val[Blocks[i]] = rhs[i]; + } } public: - Value* getMapVal() { return map_val; } - Value* getPointer() const { return getOperand(0); } + Value* getValue(unsigned k) const {return getOperand(2 * k);} ///< 获取位置为k的值 + BasicBlock* getBlock(unsigned k) const {return dynamic_cast(getOperand(2 * k + 1));} + + auto& getincomings() const {return blk2val;} ///< 获取所有的基本块和对应的值 + + Value* getvalfromBlk(BasicBlock* blk); + BasicBlock* getBlkfromVal(Value* val); + + unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量 + void addIncoming(Value *value, BasicBlock *block) { + assert(value && block && "PhiInst: value and block must not be null"); + addOperand(value); + addOperand(block); + blk2val[block] = value; + vsize++; + } ///< 添加传入值和对应的基本块 + + void delValue(Value* val); + void delBlk(BasicBlock* blk); + + void replaceBlk(BasicBlock* newBlk, unsigned k); + void replaceold2new(BasicBlock* oldBlk, BasicBlock* newBlk); + void refreshB2VMap(); + auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } - Value* getValue(unsigned index) const { return getOperand(index + 1); } }; @@ -884,7 +1046,7 @@ public: } } ///< 根据指令类型进行二元计算,eval template模板实现 static BinaryInst* create(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") { - // 后端处理数组访存操作时需要创建计算地址的指令,需要在外部构造 BinaryInst 对象,所以写了个public的方法。 + // 后端处理数组访存操作时需要创建计算地址的指令,需要在外部构造 BinaryInst 对象 return new BinaryInst(kind, type, lhs, rhs, parent, name); } }; // class BinaryInst @@ -1230,12 +1392,15 @@ protected: if (init.size() == 0) { unsigned num = 1; for (unsigned i = 0; i < numDims; i++) { - num *= dynamic_cast(dims[i])->getInt(); + // Assume dims elements are ConstantInteger and cast appropriately + auto dim_val = dynamic_cast(dims[i]); + assert(dim_val && "GlobalValue dims must be constant integers"); + num *= dim_val->getInt(); } if (dynamic_cast(type)->getBaseType() == Type::getFloatType()) { - init.push_back(ConstantValue::get(0.0F), num); + init.push_back(ConstantFloating::get(0.0F), num); // Use new constant factory } else { - init.push_back(ConstantValue::get(0), num); + init.push_back(ConstantInteger::get(0), num); // Use new constant factory } } initValues = init; @@ -1261,8 +1426,11 @@ public: Value* getByIndices(const std::vector &indices) const { int index = 0; for (size_t i = 0; i < indices.size(); i++) { - index = dynamic_cast(getDim(i))->getInt() * index + - dynamic_cast(indices[i])->getInt(); + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly + auto dim_val = dynamic_cast(getDim(i)); + auto idx_val = dynamic_cast(indices[i]); + assert(dim_val && idx_val && "Dims and indices must be constant integers"); + index = dim_val->getInt() * index + idx_val->getInt(); } return getByIndex(index); } ///< 通过多维索引indices获取初始值 @@ -1303,8 +1471,11 @@ class ConstantVariable : public User, public LVal { int index = 0; // 计算偏移量 for (size_t i = 0; i < indices.size(); i++) { - index = dynamic_cast(getDim(i))->getInt() * index + - dynamic_cast(indices[i])->getInt(); + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly + auto dim_val = dynamic_cast(getDim(i)); + auto idx_val = dynamic_cast(indices[i]); + assert(dim_val && idx_val && "Dims and indices must be constant integers"); + index = dim_val->getInt() * index + idx_val->getInt(); } return getByIndex(index); diff --git a/src/include/IRBuilder.h b/src/include/IRBuilder.h index aab9a1d..6df82e7 100644 --- a/src/include/IRBuilder.h +++ b/src/include/IRBuilder.h @@ -333,15 +333,11 @@ class IRBuilder { block->getInstructions().emplace(position, inst); return inst; } ///< 创建store指令 - PhiInst * createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") { - auto predNum = parent->getNumPredecessors(); - std::vector rhs; - for (size_t i = 0; i < predNum; i++) { - rhs.push_back(lhs); - } - auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name); + PhiInst * createPhiInst(Type *type, const std::vector &vals = {}, const std::vector &blks = {}, const std::string &name = "") { + auto predNum = block->getNumPredecessors(); + auto inst = new PhiInst(type, vals, blks, block, name); assert(inst); - parent->getInstructions().emplace(parent->begin(), inst); + block->getInstructions().emplace(block->begin(), inst); return inst; } ///< 创建Phi指令 };