From fdc946c1b528ac4659be665bb6bf63a51fdc13d1 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Fri, 18 Jul 2025 16:40:16 +0800 Subject: [PATCH] =?UTF-8?q?[IR]=E9=87=8D=E6=9E=84=E5=B8=B8=E9=87=8F?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=EF=BC=8C=E5=BC=95=E5=85=A5undefvalue?= =?UTF-8?q?=E5=AE=9A=E4=B9=89=EF=BC=8C=E4=BF=AE=E6=94=B9=E5=B8=B8=E9=87=8F?= =?UTF-8?q?=E6=96=B9=E6=B3=95=E4=BD=BF=E7=94=A8=E5=B0=BD=E9=87=8F=E9=80=82?= =?UTF-8?q?=E9=85=8D=E6=97=A7=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysYIRGenerator.cpp | 101 +++++++-------- src/include/IR.h | 266 ++++++++++++++++++++++++++++++++++------ 2 files changed, 277 insertions(+), 90 deletions(-) diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 7520891..afaf24b 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -145,8 +145,8 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { for (size_t i = 0; i < counterNumbers.size(); i++) { builder.createMemsetInst( - alloca, ConstantValue::get(static_cast(begin)), - ConstantValue::get(static_cast(counterNumbers[i])), + alloca, ConstantInteger::get(begin), + ConstantInteger::get(static_cast(counterNumbers[i])), counterValues[i]); begin += counterNumbers[i]; } @@ -218,7 +218,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ paramNames.push_back(param->Ident()->getText()); std::vector dims = {}; if (!param->LBRACK().empty()) { - dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 + dims.push_back(ConstantInteger::get(-1)); // 第一个维度不确定 for (const auto &exp : param->exp()) { dims.push_back(std::any_cast(visitExp(exp))); } @@ -247,9 +247,9 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ if(HasReturnInst == false) { // 如果没有return语句,则默认返回0 if (returnType != Type::getVoidType()) { - Value* returnValue = ConstantValue::get(0); + Value* returnValue = ConstantInteger::get(0); if (returnType == Type::getFloatType()) { - returnValue = ConstantValue::get(0.0f); + returnValue = ConstantFloating::get(0.0f); } builder.createReturnInst(returnValue); } else { @@ -286,9 +286,9 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { ConstantValue * constValue = dynamic_cast(value); if (constValue != nullptr) { if (variableType == Type::getFloatType()) { - value = ConstantValue::get(static_cast(constValue->getInt())); + value = ConstantInteger::get(static_cast(constValue->getInt())); } else { - value = ConstantValue::get(static_cast(constValue->getFloat())); + value = ConstantFloating::get(static_cast(constValue->getFloat())); } } else { if (variableType == Type::getFloatType()) { @@ -478,9 +478,9 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { ConstantValue * constValue = dynamic_cast(returnValue); if (constValue != nullptr) { if (funcType == Type::getFloatType()) { - returnValue = ConstantValue::get(static_cast(constValue->getInt())); + returnValue = ConstantInteger::get(static_cast(constValue->getInt())); } else { - returnValue = ConstantValue::get(static_cast(constValue->getFloat())); + returnValue = ConstantFloating::get(static_cast(constValue->getFloat())); } } else { if (funcType == Type::getFloatType()) { @@ -560,10 +560,10 @@ std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { if (ctx->ILITERAL() != nullptr) { int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); - return static_cast(ConstantValue::get(value)); + return static_cast(ConstantInteger::get(value)); } else if (ctx->FLITERAL() != nullptr) { float value = std::stof(ctx->FLITERAL()->getText()); - return static_cast(ConstantValue::get(value)); + return static_cast(ConstantFloating::get(value)); } throw std::runtime_error("Unknown number type."); return std::any(); // 不会到达这里 @@ -599,9 +599,9 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { ConstantValue * constValue = dynamic_cast(args[i]); if (constValue != nullptr) { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { - args[i] = ConstantValue::get(static_cast(constValue->getInt())); + args[i] = ConstantInteger::get(static_cast(constValue->getInt())); } else { - args[i] = ConstantValue::get(static_cast(constValue->getFloat())); + args[i] = ConstantFloating::get(static_cast(constValue->getFloat())); } } else { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { @@ -629,9 +629,9 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { ConstantValue * constValue = dynamic_cast(value); if (constValue != nullptr) { if (constValue->isFloat()) { - result = ConstantValue::get(-constValue->getFloat()); + result = ConstantFloating::get(-constValue->getFloat()); } else { - result = ConstantValue::get(-constValue->getInt()); + result = ConstantInteger::get(-constValue->getInt()); } } else if (value != nullptr) { if (value->getType() == Type::getIntType()) { @@ -648,9 +648,9 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { if (constValue != nullptr) { if (constValue->isFloat()) { result = - ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); + ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); } else { - result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); + result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0)); } } else if (value != nullptr) { if (value->getType() == Type::getIntType()) { @@ -692,13 +692,13 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { if (operandType != floatType) { ConstantValue * constValue = dynamic_cast(operand); if (constValue != nullptr) - operand = ConstantValue::get(static_cast(constValue->getInt())); + operand = ConstantFloating::get(static_cast(constValue->getInt())); else operand = builder.createIToFInst(operand); } else if (resultType != floatType) { ConstantValue* constResult = dynamic_cast(result); if (constResult != nullptr) - result = ConstantValue::get(static_cast(constResult->getInt())); + result = ConstantFloating::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } @@ -707,14 +707,14 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { ConstantValue* constOperand = dynamic_cast(operand); if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantValue::get(constResult->getFloat() * + result = ConstantFloating::get(constResult->getFloat() * constOperand->getFloat()); } else { result = builder.createFMulInst(result, operand); } } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantValue::get(constResult->getFloat() / + result = ConstantFloating::get(constResult->getFloat() / constOperand->getFloat()); } else { result = builder.createFDivInst(result, operand); @@ -729,17 +729,17 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); + result = ConstantInteger::get(constResult->getInt() * constOperand->getInt()); else result = builder.createMulInst(result, operand); } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); + result = ConstantInteger::get(constResult->getInt() / constOperand->getInt()); else result = builder.createDivInst(result, operand); } else { if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); + result = ConstantInteger::get(constResult->getInt() % constOperand->getInt()); else result = builder.createRemInst(result, operand); } @@ -767,13 +767,13 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { if (operandType != floatType) { ConstantValue * constOperand = dynamic_cast(operand); if (constOperand != nullptr) - operand = ConstantValue::get(static_cast(constOperand->getInt())); + operand = ConstantFloating::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); } else if (resultType != floatType) { ConstantValue * constResult = dynamic_cast(result); if (constResult != nullptr) - result = ConstantValue::get(static_cast(constResult->getInt())); + result = ConstantFloating::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } @@ -782,12 +782,12 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); + result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat()); else result = builder.createFAddInst(result, operand); } else { if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); + result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat()); else result = builder.createFSubInst(result, operand); } @@ -796,12 +796,12 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); + result = ConstantInteger::get(constResult->getInt() + constOperand->getInt()); else result = builder.createAddInst(result, operand); } else { if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); + result = ConstantInteger::get(constResult->getInt() - constOperand->getInt()); else result = builder.createSubInst(result, operand); } @@ -833,10 +833,10 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); - if (opType == SysYParser::LT) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); - else if (opType == SysYParser::GT) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); - else if (opType == SysYParser::LE) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); - else if (opType == SysYParser::GE) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); + if (opType == SysYParser::LT) result = ConstantInteger::get(operand1 < operand2 ? 1 : 0); + else if (opType == SysYParser::GT) result = ConstantInteger::get(operand1 > operand2 ? 1 : 0); + else if (opType == SysYParser::LE) result = ConstantInteger::get(operand1 <= operand2 ? 1 : 0); + else if (opType == SysYParser::GE) result = ConstantInteger::get(operand1 >= operand2 ? 1 : 0); else assert(false); } else { @@ -848,14 +848,14 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr) - result = ConstantValue::get(static_cast(constResult->getInt())); + result = ConstantFloating::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) - operand = ConstantValue::get(static_cast(constOperand->getInt())); + operand = ConstantFloating::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); @@ -901,8 +901,8 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); - if (opType == SysYParser::EQ) result = ConstantValue::get(operand1 == operand2 ? 1 : 0); - else if (opType == SysYParser::NE) result = ConstantValue::get(operand1 != operand2 ? 1 : 0); + if (opType == SysYParser::EQ) result = ConstantInteger::get(operand1 == operand2 ? 1 : 0); + else if (opType == SysYParser::NE) result = ConstantInteger::get(operand1 != operand2 ? 1 : 0); else assert(false); } else { @@ -913,13 +913,13 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr) - result = ConstantValue::get(static_cast(constResult->getInt())); + result = ConstantFloating::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) - operand = ConstantValue::get(static_cast(constOperand->getInt())); + operand = ConstantFloating::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); } @@ -943,9 +943,9 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { // 如果只有一个关系表达式,则将结果转换为0或1 if (constResult != nullptr) { if (constResult->isFloat()) - result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0); + result = ConstantInteger::get(constResult->getFloat() != 0.0F ? 1 : 0); else - result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0); + result = ConstantInteger::get(constResult->getInt() != 0 ? 1 : 0); } } @@ -1013,6 +1013,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, ValueCounter &result, IRBuilder *builder) { Value* value = root->getValue(); auto &children = root->getChildren(); + // 类型转换 if (value != nullptr) { if (type == value->getType()) { result.push_back(value); @@ -1020,14 +1021,14 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, if (type == Type::getFloatType()) { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr) - result.push_back(ConstantValue::get(static_cast(constValue->getInt()))); + result.push_back(ConstantFloating::get(static_cast(constValue->getInt()))); else result.push_back(builder->createIToFInst(value)); } else { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr) - result.push_back(ConstantValue::get(static_cast(constValue->getFloat()))); + result.push_back(ConstantInteger::get(static_cast(constValue->getFloat()))); else result.push_back(builder->createFtoIInst(value)); @@ -1061,9 +1062,9 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, int num = blockSize - afterSize + beforeSize; if (num > 0) { if (type == Type::getFloatType()) - result.push_back(ConstantValue::get(0.0F), num); + result.push_back(ConstantFloating::get(0.0F), num); else - result.push_back(ConstantValue::get(0), num); + result.push_back(ConstantInteger::get(0), num); } } @@ -1101,7 +1102,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { funcName, pModule, pBuilder); paramTypes.push_back(Type::getIntType()); paramNames.emplace_back("x"); - paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramDims.push_back(std::vector{ConstantInteger::get(-1)}); funcName = "getarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); @@ -1117,7 +1118,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { returnType = Type::getIntType(); paramTypes.push_back(Type::getFloatType()); paramNames.emplace_back("x"); - paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramDims.push_back(std::vector{ConstantInteger::get(-1)}); funcName = "getfarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); @@ -1141,7 +1142,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); - paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramDims.push_back(std::vector{ConstantInteger::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); @@ -1164,7 +1165,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { paramTypes.push_back(Type::getFloatType()); paramDims.clear(); paramDims.emplace_back(); - paramDims.push_back(std::vector{ConstantValue::get(-1)}); + paramDims.push_back(std::vector{ConstantInteger::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); diff --git a/src/include/IR.h b/src/include/IR.h index 1154d4e..b23689a 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -268,6 +268,51 @@ class ValueCounter { } ///< 清空ValueCounter }; + +// --- Refactored ConstantValue and related classes start here --- + +using ConstantValVariant = std::variant; + +// Helper for hashing std::variant +struct VariantHash { + template + std::size_t operator()(const T& val) const { + return std::hash{}(val); + } + std::size_t operator()(const ConstantValVariant& v) const { + return std::visit(*this, v); + } +}; + +struct ConstantValueKey { + Type* type; + ConstantValVariant val; + + bool operator==(const ConstantValueKey& other) const { + // Assuming Type objects are canonicalized, or add Type::isSame() + // If Type::isSame() is not available and Type objects are not canonicalized, + // this comparison might not be robust enough for structural equivalence of types. + return type == other.type && val == other.val; + } +}; + +struct ConstantValueHash { + std::size_t operator()(const ConstantValueKey& key) const { + std::size_t typeHash = std::hash{}(key.type); + std::size_t valHash = VariantHash{}(key.val); + // A simple way to combine hashes + return typeHash ^ (valHash << 1); + } +}; + +struct ConstantValueEqual { + bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const { + // Assuming Type objects are canonicalized (e.g., Type::getIntType() always returns same pointer) + // If not, and Type::isSame() is intended, it should be added to Type class. + return lhs.type == rhs.type && lhs.val == rhs.val; + } +}; + /*! * Static constants known at compile time. * @@ -275,46 +320,178 @@ class ValueCounter { * `Value`s. It's type is either `int` or `float`. * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 */ - - class ConstantValue : public Value { - protected: - /// 定义字面量类型的聚合类型 - union { - int iScalar; - float fScalar; - }; +protected: + static std::unordered_map mConstantPool; - protected: - explicit ConstantValue(int value, const std::string &name = "") : Value(Type::getIntType(), name), iScalar(value) {} - explicit ConstantValue(float value, const std::string &name = "") - : Value(Type::getFloatType(), name), fScalar(value) {} +public: + explicit ConstantValue(Type* type, const std::string& name = "") : Value(type, name) {} + virtual ~ConstantValue() = default; - public: - static ConstantValue* get(int value); ///< 获取一个int类型的ConstValue *,其值为value - static ConstantValue* get(float value); ///< 获取一个float类型的ConstValue *,其值为value + virtual size_t hash() const = 0; + virtual ConstantValVariant getValue() const = 0; - public: + // Static factory method to get a canonical ConstantValue from the pool + static ConstantValue* get(Type* type, ConstantValVariant val); + + // Helper methods to access constant values with appropriate casting int getInt() const { - assert(isInt()); - return iScalar; - } ///< 返回int类型的值 + assert(getType()->isInt() && "Calling getInt() on non-integer type"); + return std::get(getValue()); + } float getFloat() const { - assert(isFloat()); - return fScalar; - } ///< 返回float类型的值 - template + assert(getType()->isFloat() && "Calling getFloat() on non-float type"); + return std::get(getValue()); + } + + template T getValue() const { - if (std::is_same::value && isInt()) { - return getInt(); - } - if (std::is_same::value && isFloat()) { - return getFloat(); - } - throw std::bad_cast(); // 或者其他适当的异常处理 - } ///< 返回值,getInt和getFloat统一化,整数返回整形,浮点返回浮点型 + if constexpr (std::is_same_v) { + return getInt(); + } else if constexpr (std::is_same_v) { + return getFloat(); + } else { + // This ensures a compilation error if an unsupported type is used + static_assert(std::always_false_v, "Unsupported type for ConstantValue::getValue()"); + } + } + + virtual bool isZero() const = 0; + virtual bool isOne() const = 0; }; +class ConstantInteger : public ConstantValue { + int constVal; +public: + explicit ConstantInteger(Type* type, int val, const std::string& name = "") + : ConstantValue(type, name), constVal(val) {} + + size_t hash() const override { + std::size_t typeHash = std::hash{}(getType()); + std::size_t valHash = std::hash{}(constVal); + return typeHash ^ (valHash << 1); + } + int getInt() const { return constVal; } + ConstantValVariant getValue() const override { return constVal; } + + static ConstantInteger* get(Type* type, int val); + static ConstantInteger* get(int val) { return get(Type::getIntType(), val); } + + ConstantInteger* getNeg() const { + assert(getType()->isInt() && "Cannot negate non-integer constant"); + return ConstantInteger::get(-constVal); + } + + bool isZero() const override { return constVal == 0; } + bool isOne() const override { return constVal == 1; } +}; + +class ConstantFloating : public ConstantValue { + float constFVal; +public: + explicit ConstantFloating(Type* type, float val, const std::string& name = "") + : ConstantValue(type, name), constFVal(val) {} + + size_t hash() const override { + std::size_t typeHash = std::hash{}(getType()); + std::size_t valHash = std::hash{}(constFVal); + return typeHash ^ (valHash << 1); + } + float getFloat() const { return constFVal; } + ConstantValVariant getValue() const override { return constFVal; } + + static ConstantFloating* get(Type* type, float val); + static ConstantFloating* get(float val) { return get(Type::getFloatType(), val); } + + ConstantFloating* getNeg() const { + assert(getType()->isFloat() && "Cannot negate non-float constant"); + return ConstantFloating::get(-constFVal); + } + + bool isZero() const override { return constFVal == 0.0f; } + bool isOne() const override { return constFVal == 1.0f; } +}; + +class UndefinedValue : public ConstantValue { +private: + static std::unordered_map UndefValues; + +protected: + explicit UndefinedValue(Type* type, const std::string& name = "") + : ConstantValue(type, name) { + assert(!type->isVoid() && "Cannot create UndefinedValue of void type!"); + } + +public: + static UndefinedValue* get(Type* type); + + size_t hash() const override { + return std::hash{}(getType()); + } + + ConstantValVariant getValue() const override { + if (getType()->isInt()) { + return 0; // Return 0 for undefined integer + } else if (getType()->isFloat()) { + return 0.0f; // Return 0.0f for undefined float + } + assert(false && "UndefinedValue has unexpected type for getValue()"); + return 0; // Should not be reached + } + + bool isZero() const override { return false; } + bool isOne() const override { return false; } +}; + +// Implementations for static members (typically in .cpp, but for single-file, put here) + +std::unordered_map ConstantValue::mConstantPool; +std::unordered_map UndefinedValue::UndefValues; + +ConstantValue* ConstantValue::get(Type* type, ConstantValVariant val) { + ConstantValueKey key = {type, val}; + auto it = mConstantPool.find(key); + if (it != mConstantPool.end()) { + return it->second; + } + + ConstantValue* newConstant = nullptr; + if (std::holds_alternative(val)) { + newConstant = new ConstantInteger(type, std::get(val)); + } else if (std::holds_alternative(val)) { + newConstant = new ConstantFloating(type, std::get(val)); + } else { + assert(false && "Unsupported ConstantValVariant type"); + } + + mConstantPool[key] = newConstant; + return newConstant; +} + +ConstantInteger* ConstantInteger::get(Type* type, int val) { + return dynamic_cast(ConstantValue::get(type, val)); +} + +ConstantFloating* ConstantFloating::get(Type* type, float val) { + return dynamic_cast(ConstantValue::get(type, val)); +} + +UndefinedValue* UndefinedValue::get(Type* type) { + assert(!type->isVoid() && "Cannot get UndefinedValue of void type!"); + + auto it = UndefValues.find(type); + if (it != UndefValues.end()) { + return it->second; + } + + UndefinedValue* newUndef = new UndefinedValue(type); + UndefValues[type] = newUndef; + return newUndef; +} + +// --- End of refactored ConstantValue and related classes --- + + class Instruction; class Function; class BasicBlock; @@ -562,8 +739,8 @@ class Instruction : public User { kLa = 0x1UL << 36, kMemset = 0x1UL << 37, kGetSubArray = 0x1UL << 38, - // constant - kConstant = 0x1UL << 37, + // Constant Kind removed as Constants are now Values, not Instructions. + // kConstant = 0x1UL << 37, // Conflicts with kMemset if kept as is // phi kPhi = 0x1UL << 39, kBitItoF = 0x1UL << 40, @@ -1258,12 +1435,15 @@ protected: if (init.size() == 0) { unsigned num = 1; for (unsigned i = 0; i < numDims; i++) { - num *= dynamic_cast(dims[i])->getInt(); + // Assume dims elements are ConstantInteger and cast appropriately + auto dim_val = dynamic_cast(dims[i]); + assert(dim_val && "GlobalValue dims must be constant integers"); + num *= dim_val->getInt(); } if (dynamic_cast(type)->getBaseType() == Type::getFloatType()) { - init.push_back(ConstantValue::get(0.0F), num); + init.push_back(ConstantFloating::get(0.0F), num); // Use new constant factory } else { - init.push_back(ConstantValue::get(0), num); + init.push_back(ConstantInteger::get(0), num); // Use new constant factory } } initValues = init; @@ -1289,8 +1469,11 @@ public: Value* getByIndices(const std::vector &indices) const { int index = 0; for (size_t i = 0; i < indices.size(); i++) { - index = dynamic_cast(getDim(i))->getInt() * index + - dynamic_cast(indices[i])->getInt(); + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly + auto dim_val = dynamic_cast(getDim(i)); + auto idx_val = dynamic_cast(indices[i]); + assert(dim_val && idx_val && "Dims and indices must be constant integers"); + index = dim_val->getInt() * index + idx_val->getInt(); } return getByIndex(index); } ///< 通过多维索引indices获取初始值 @@ -1331,8 +1514,11 @@ class ConstantVariable : public User, public LVal { int index = 0; // 计算偏移量 for (size_t i = 0; i < indices.size(); i++) { - index = dynamic_cast(getDim(i))->getInt() * index + - dynamic_cast(indices[i])->getInt(); + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly + auto dim_val = dynamic_cast(getDim(i)); + auto idx_val = dynamic_cast(indices[i]); + assert(dim_val && idx_val && "Dims and indices must be constant integers"); + index = dim_val->getInt() * index + idx_val->getInt(); } return getByIndex(index);