diff --git a/src/IR.cpp b/src/IR.cpp index 529e903..f961d53 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include using namespace std; namespace sysy { @@ -80,6 +82,15 @@ Type *Type::getFunctionType(Type *returnType, return FunctionType::get(returnType, paramTypes); } +Type *Type::getArrayType(Type *elementType, const vector &dims) { + // forward to ArrayType + return ArrayType::get(elementType, dims); +} + +ArrayType* Type::asArrayType() const { + return isArray() ? dynamic_cast(const_cast(this)) : nullptr; +} + int Type::getSize() const { switch (kind) { case kInt: @@ -177,33 +188,75 @@ bool Value::isConstant() const { return false; } -ConstantValue *ConstantValue::get(int value) { - static std::map> intConstants; - auto iter = intConstants.find(value); - if (iter != intConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = intConstants.emplace(value, constant); - return result.first->second.get(); + +// 定义静态常量池 +std::unordered_map ConstantValue::constantPool; + +// 常量池实现 +ConstantValue* ConstantValue::get(Type* type, int32_t value) { + ConstantValueKey key = {type, ConstantValVariant(value)}; + + if (auto it = constantPool.find(key); it != constantPool.end()) { + return it->second; + } + + ConstantValue* constant = new ConstantInt(type, value); + constantPool[key] = constant; + return constant; } -ConstantValue *ConstantValue::get(float value) { - static std::map> floatConstants; - auto iter = floatConstants.find(value); - if (iter != floatConstants.end()) - return iter->second.get(); - auto constant = new ConstantValue(value); - assert(constant); - auto result = floatConstants.emplace(value, constant); - return result.first->second.get(); +ConstantValue* ConstantValue::get(Type* type, float value) { + ConstantValueKey key = {type, ConstantValVariant(value)}; + + if (auto it = constantPool.find(key); it != constantPool.end()) { + return it->second; + } + + ConstantValue* constant = new ConstantFloat(type, value); + constantPool[key] = constant; + return constant; } -void ConstantValue::print(ostream &os) const { - if (isInt()) - os << getInt(); - else - os << getFloat(); +ConstantValue* ConstantValue::getInt32(int32_t value) { + return get(Type::getIntType(), value); +} + +ConstantValue* ConstantValue::getFloat32(float value) { + return get(Type::getFloatType(), value); +} + +ConstantValue* ConstantValue::getTrue() { + return get(Type::getIntType(), 1); +} + +ConstantValue* ConstantValue::getFalse() { + return get(Type::getIntType(), 0); +} + + + +void ConstantValue::print(std::ostream &os) const { + // 根据类型调用相应的打印实现 + if (auto intConst = dynamic_cast(this)) { + intConst->print(os); + } + else if (auto floatConst = dynamic_cast(this)) { + floatConst->print(os); + } + else { + os << "???"; // 未知常量类型 + } +} + +void ConstantInt::print(std::ostream &os) const { + os << value; +} +void ConstantFloat::print(std::ostream &os) const { + if (value == static_cast(value)) { + os << value << ".0"; // 确保输出带小数点 + } else { + os << std::fixed << std::setprecision(6) << value; + } } Argument::Argument(Type *type, BasicBlock *block, int index, diff --git a/src/IR.h b/src/IR.h index 1483395..7d228ad 100644 --- a/src/IR.h +++ b/src/IR.h @@ -11,6 +11,9 @@ #include #include #include +#include +#include +#include namespace sysy { @@ -33,6 +36,9 @@ namespace sysy { * include `int`, `float`, `void`, and the label type representing branch * targets */ + +class ArrayType; + class Type { public: enum Kind { @@ -58,9 +64,7 @@ public: static Type *getPointerType(Type *baseType); static Type *getFunctionType(Type *returnType, const std::vector ¶mTypes = {}); - static Type *getArrayType(Type *elementType, const std::vector &dims = {}) { - return ArrayType::get(elementType, dims); - } + static Type *getArrayType(Type *elementType, const std::vector &dims = {}); public: Kind getKind() const { return kind; } @@ -73,9 +77,9 @@ public: bool isArray() const { return kind == kArray; } bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } int getSize() const; - ArrayType *asArrayType() const { - return isArray() ? static_cast(const_cast(this)) : nullptr; - } + + ArrayType* asArrayType() const; + template std::enable_if_t, T *> as() const { return dynamic_cast(const_cast(this)); @@ -335,41 +339,114 @@ public: * `ConstantValue`s are not defined by instructions, and do not use any other * `Value`s. It's type is either `int` or `float`. */ + +class ConstantInt; +class ConstantFloat; +//常量池优化 + +using ConstantValVariant = std::variant; +using ConstantValueKey = std::pair; + class ConstantValue : public Value { protected: - union { - int iScalar; - float fScalar; - }; - -protected: - ConstantValue(int value) - : Value(kConstant, Type::getIntType(), ""), iScalar(value) {} - ConstantValue(float value) - : Value(kConstant, Type::getFloatType(), ""), fScalar(value) {} - + ConstantValue(Type* type) + : Value(kConstant, type, "") {} public: - static ConstantValue *get(int value); - static ConstantValue *get(float value); - -public: - static bool classof(const Value *value) { + struct ConstantValueHash; + struct ConstantValueEqual; + + static std::unordered_map constantPool; + + virtual ~ConstantValue() = default; + + static ConstantValue* get(Type* type, int32_t value); + static ConstantValue* get(Type* type, float value); + + static bool classof(const Value* value) { return value->getKind() == kConstant; } + + virtual int32_t getInt() const = 0; + virtual float getFloat() const = 0; + virtual bool isZero() const = 0; + virtual bool isOne() const = 0; + + + static ConstantValue* getInt32(int32_t value); + static ConstantValue* getFloat32(float value); + static ConstantValue* getTrue() ; + static ConstantValue* getFalse(); -public: - int getInt() const { - assert(isInt()); - return iScalar; - } - float getFloat() const { - assert(isFloat()); - return fScalar; - } - -public: void print(std::ostream &os) const override; -}; // class ConstantValue +}; + +struct ConstantValue::ConstantValueHash { + std::size_t operator()(const ConstantValueKey& key) const { + std::size_t typeHash = std::hash{}(key.first); + std::size_t valHash = 0; + if (key.first->isInt()) { + valHash = std::hash{}(std::get(key.second)); + } else if (key.first->isFloat()) { + // 修复5: 确保float哈希正确 + valHash = std::hash{}(std::get(key.second)); + } + return typeHash ^ (valHash << 1); + } +}; + +struct ConstantValue::ConstantValueEqual { + bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const { + if (lhs.first != rhs.first) return false; + if (lhs.first->isInt()) { + return std::get(lhs.second) == std::get(rhs.second); + } else if (lhs.first->isFloat()) { + // 修复6: 使用浮点比较容差 + const float eps = 1e-6; + return fabs(std::get(lhs.second) - std::get(rhs.second)) < eps; + } + return false; + } +}; + +class ConstantInt : public ConstantValue { + int32_t value; + friend class ConstantValue; + +protected: + ConstantInt(Type* type, int32_t value) + : ConstantValue(type), value(value) { + assert(type->isInt() && "Invalid type for ConstantInt"); + } +public: + static ConstantInt* get(Type* type, int32_t value); + + int32_t getInt() const override { return value; } + float getFloat() const override { return static_cast(value); } + bool isZero() const override { return value == 0; } + bool isOne() const override { return value == 1; } + + void print(std::ostream& os) const override ; +}; + +class ConstantFloat : public ConstantValue { + float value; + friend class ConstantValue; + +protected: + ConstantFloat(Type* type, float value) + : ConstantValue(type), value(value) { + assert(type->isFloat() && "Invalid type for ConstantFloat"); + } +public: + static ConstantFloat* get(Type* type, float value); + + int32_t getInt() const override { return static_cast(value); } + float getFloat() const override { return value; } + bool isZero() const override { return value == 0.0f; } + bool isOne() const override { return value == 1.0f; } + + void print(std::ostream& os) const override; +}; class BasicBlock; /*! diff --git a/src/LLVMIRGenerator_1.cpp b/src/LLVMIRGenerator_1.cpp index 515b5a2..965eb82 100644 --- a/src/LLVMIRGenerator_1.cpp +++ b/src/LLVMIRGenerator_1.cpp @@ -91,7 +91,7 @@ std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { if (varDef->ASSIGN()) { value = std::any_cast(varDef->initVal()->accept(this)); - if (irTmpTable.find(value) != irTmpTable.end() && isa(irTmpTable[value])) { + if (irTmpTable.find(value) != irTmpTable.end() && sysy::isa(irTmpTable[value])) { initValue = irTmpTable[value]; } } @@ -134,7 +134,7 @@ std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { try { value = std::any_cast(constDef->constInitVal()->accept(this)); - if (isa(irTmpTable[value])) { + if (sysy::isa(irTmpTable[value])) { initValue = irTmpTable[value]; } } catch (...) { @@ -310,7 +310,7 @@ std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { } else { irStream << " ret " << currentReturnType << " 0\n"; sysy::IRBuilder builder(currentIRBlock); - builder.createReturnInst(sysy::ConstantValue::get(0)); + builder.createReturnInst(sysy::ConstantValue::get(getIRType("int"),0)); } } irStream << "}\n"; @@ -524,10 +524,10 @@ std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { sysy::Value* irValue = nullptr; if (ctx->ILITERAL()) { value = ctx->ILITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stoi(value)); + irValue = sysy::ConstantValue::get(getIRType("int"), std::stoi(value)); } else if (ctx->FLITERAL()) { value = ctx->FLITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stof(value)); + irValue = sysy::ConstantValue::get(getIRType("float"), std::stof(value)); } else { value = ""; } diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index a2962c7..1844a53 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -552,10 +552,10 @@ std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { } else if (text.find("0") == 0) { base = 8; } - res = ConstantValue::get((int)std::stol(text, 0, base)); + res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base)); } else if (auto fLiteral = ctx->FLITERAL()) { const auto text = fLiteral->getText(); - res = ConstantValue::get((float)std::stof(text)); + res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text)); } cout << "number: "; res->print(cout);