引入了常量池优化,修改constvalue类并对IR生成修复,能够编译通过

This commit is contained in:
rain2133
2025-06-19 00:18:58 +08:00
parent 1aa785efc3
commit 1de8c0e7d7
4 changed files with 194 additions and 64 deletions

View File

@@ -13,6 +13,8 @@
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <variant>
#include <iomanip>
using namespace std; using namespace std;
namespace sysy { namespace sysy {
@@ -80,6 +82,15 @@ Type *Type::getFunctionType(Type *returnType,
return FunctionType::get(returnType, paramTypes); return FunctionType::get(returnType, paramTypes);
} }
Type *Type::getArrayType(Type *elementType, const vector<int> &dims) {
// forward to ArrayType
return ArrayType::get(elementType, dims);
}
ArrayType* Type::asArrayType() const {
return isArray() ? dynamic_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr;
}
int Type::getSize() const { int Type::getSize() const {
switch (kind) { switch (kind) {
case kInt: case kInt:
@@ -177,33 +188,75 @@ bool Value::isConstant() const {
return false; return false;
} }
ConstantValue *ConstantValue::get(int value) {
static std::map<int, std::unique_ptr<ConstantValue>> intConstants; // 定义静态常量池
auto iter = intConstants.find(value); std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValue::ConstantValueHash> ConstantValue::constantPool;
if (iter != intConstants.end())
return iter->second.get(); // 常量池实现
auto constant = new ConstantValue(value); ConstantValue* ConstantValue::get(Type* type, int32_t value) {
assert(constant); ConstantValueKey key = {type, ConstantValVariant(value)};
auto result = intConstants.emplace(value, constant);
return result.first->second.get(); if (auto it = constantPool.find(key); it != constantPool.end()) {
return it->second;
}
ConstantValue* constant = new ConstantInt(type, value);
constantPool[key] = constant;
return constant;
} }
ConstantValue *ConstantValue::get(float value) { ConstantValue* ConstantValue::get(Type* type, float value) {
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants; ConstantValueKey key = {type, ConstantValVariant(value)};
auto iter = floatConstants.find(value);
if (iter != floatConstants.end()) if (auto it = constantPool.find(key); it != constantPool.end()) {
return iter->second.get(); return it->second;
auto constant = new ConstantValue(value); }
assert(constant);
auto result = floatConstants.emplace(value, constant); ConstantValue* constant = new ConstantFloat(type, value);
return result.first->second.get(); constantPool[key] = constant;
return constant;
} }
void ConstantValue::print(ostream &os) const { ConstantValue* ConstantValue::getInt32(int32_t value) {
if (isInt()) return get(Type::getIntType(), value);
os << getInt(); }
else
os << getFloat(); ConstantValue* ConstantValue::getFloat32(float value) {
return get(Type::getFloatType(), value);
}
ConstantValue* ConstantValue::getTrue() {
return get(Type::getIntType(), 1);
}
ConstantValue* ConstantValue::getFalse() {
return get(Type::getIntType(), 0);
}
void ConstantValue::print(std::ostream &os) const {
// 根据类型调用相应的打印实现
if (auto intConst = dynamic_cast<const ConstantInt*>(this)) {
intConst->print(os);
}
else if (auto floatConst = dynamic_cast<const ConstantFloat*>(this)) {
floatConst->print(os);
}
else {
os << "???"; // 未知常量类型
}
}
void ConstantInt::print(std::ostream &os) const {
os << value;
}
void ConstantFloat::print(std::ostream &os) const {
if (value == static_cast<int>(value)) {
os << value << ".0"; // 确保输出带小数点
} else {
os << std::fixed << std::setprecision(6) << value;
}
} }
Argument::Argument(Type *type, BasicBlock *block, int index, Argument::Argument(Type *type, BasicBlock *block, int index,

141
src/IR.h
View File

@@ -11,6 +11,9 @@
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include <variant>
#include <unordered_map>
#include <cmath>
namespace sysy { namespace sysy {
@@ -33,6 +36,9 @@ namespace sysy {
* include `int`, `float`, `void`, and the label type representing branch * include `int`, `float`, `void`, and the label type representing branch
* targets * targets
*/ */
class ArrayType;
class Type { class Type {
public: public:
enum Kind { enum Kind {
@@ -58,9 +64,7 @@ public:
static Type *getPointerType(Type *baseType); static Type *getPointerType(Type *baseType);
static Type *getFunctionType(Type *returnType, static Type *getFunctionType(Type *returnType,
const std::vector<Type *> &paramTypes = {}); const std::vector<Type *> &paramTypes = {});
static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {}) { static Type *getArrayType(Type *elementType, const std::vector<int> &dims = {});
return ArrayType::get(elementType, dims);
}
public: public:
Kind getKind() const { return kind; } Kind getKind() const { return kind; }
@@ -73,9 +77,9 @@ public:
bool isArray() const { return kind == kArray; } bool isArray() const { return kind == kArray; }
bool isIntOrFloat() const { return kind == kInt or kind == kFloat; } bool isIntOrFloat() const { return kind == kInt or kind == kFloat; }
int getSize() const; int getSize() const;
ArrayType *asArrayType() const {
return isArray() ? static_cast<ArrayType*>(const_cast<Type*>(this)) : nullptr; ArrayType* asArrayType() const;
}
template <typename T> template <typename T>
std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const { std::enable_if_t<std::is_base_of_v<Type, T>, T *> as() const {
return dynamic_cast<T *>(const_cast<Type *>(this)); return dynamic_cast<T *>(const_cast<Type *>(this));
@@ -335,41 +339,114 @@ public:
* `ConstantValue`s are not defined by instructions, and do not use any other * `ConstantValue`s are not defined by instructions, and do not use any other
* `Value`s. It's type is either `int` or `float`. * `Value`s. It's type is either `int` or `float`.
*/ */
class ConstantInt;
class ConstantFloat;
//常量池优化
using ConstantValVariant = std::variant<int32_t, float>;
using ConstantValueKey = std::pair<Type*, ConstantValVariant>;
class ConstantValue : public Value { class ConstantValue : public Value {
protected: protected:
union { ConstantValue(Type* type)
int iScalar; : Value(kConstant, type, "") {}
float fScalar;
};
protected:
ConstantValue(int value)
: Value(kConstant, Type::getIntType(), ""), iScalar(value) {}
ConstantValue(float value)
: Value(kConstant, Type::getFloatType(), ""), fScalar(value) {}
public: public:
static ConstantValue *get(int value); struct ConstantValueHash;
static ConstantValue *get(float value); struct ConstantValueEqual;
public: static std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValueHash> constantPool;
static bool classof(const Value *value) {
virtual ~ConstantValue() = default;
static ConstantValue* get(Type* type, int32_t value);
static ConstantValue* get(Type* type, float value);
static bool classof(const Value* value) {
return value->getKind() == kConstant; return value->getKind() == kConstant;
} }
public: virtual int32_t getInt() const = 0;
int getInt() const { virtual float getFloat() const = 0;
assert(isInt()); virtual bool isZero() const = 0;
return iScalar; virtual bool isOne() const = 0;
}
float getFloat() const {
assert(isFloat()); static ConstantValue* getInt32(int32_t value);
return fScalar; static ConstantValue* getFloat32(float value);
} static ConstantValue* getTrue() ;
static ConstantValue* getFalse();
public:
void print(std::ostream &os) const override; void print(std::ostream &os) const override;
}; // class ConstantValue };
struct ConstantValue::ConstantValueHash {
std::size_t operator()(const ConstantValueKey& key) const {
std::size_t typeHash = std::hash<Type*>{}(key.first);
std::size_t valHash = 0;
if (key.first->isInt()) {
valHash = std::hash<int32_t>{}(std::get<int32_t>(key.second));
} else if (key.first->isFloat()) {
// 修复5: 确保float哈希正确
valHash = std::hash<float>{}(std::get<float>(key.second));
}
return typeHash ^ (valHash << 1);
}
};
struct ConstantValue::ConstantValueEqual {
bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const {
if (lhs.first != rhs.first) return false;
if (lhs.first->isInt()) {
return std::get<int32_t>(lhs.second) == std::get<int32_t>(rhs.second);
} else if (lhs.first->isFloat()) {
// 修复6: 使用浮点比较容差
const float eps = 1e-6;
return fabs(std::get<float>(lhs.second) - std::get<float>(rhs.second)) < eps;
}
return false;
}
};
class ConstantInt : public ConstantValue {
int32_t value;
friend class ConstantValue;
protected:
ConstantInt(Type* type, int32_t value)
: ConstantValue(type), value(value) {
assert(type->isInt() && "Invalid type for ConstantInt");
}
public:
static ConstantInt* get(Type* type, int32_t value);
int32_t getInt() const override { return value; }
float getFloat() const override { return static_cast<float>(value); }
bool isZero() const override { return value == 0; }
bool isOne() const override { return value == 1; }
void print(std::ostream& os) const override ;
};
class ConstantFloat : public ConstantValue {
float value;
friend class ConstantValue;
protected:
ConstantFloat(Type* type, float value)
: ConstantValue(type), value(value) {
assert(type->isFloat() && "Invalid type for ConstantFloat");
}
public:
static ConstantFloat* get(Type* type, float value);
int32_t getInt() const override { return static_cast<int32_t>(value); }
float getFloat() const override { return value; }
bool isZero() const override { return value == 0.0f; }
bool isOne() const override { return value == 1.0f; }
void print(std::ostream& os) const override;
};
class BasicBlock; class BasicBlock;
/*! /*!

View File

@@ -91,7 +91,7 @@ std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
if (varDef->ASSIGN()) { if (varDef->ASSIGN()) {
value = std::any_cast<std::string>(varDef->initVal()->accept(this)); value = std::any_cast<std::string>(varDef->initVal()->accept(this));
if (irTmpTable.find(value) != irTmpTable.end() && isa<sysy::ConstantValue>(irTmpTable[value])) { if (irTmpTable.find(value) != irTmpTable.end() && sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value]; initValue = irTmpTable[value];
} }
} }
@@ -134,7 +134,7 @@ std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
try { try {
value = std::any_cast<std::string>(constDef->constInitVal()->accept(this)); value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
if (isa<sysy::ConstantValue>(irTmpTable[value])) { if (sysy::isa<sysy::ConstantValue>(irTmpTable[value])) {
initValue = irTmpTable[value]; initValue = irTmpTable[value];
} }
} catch (...) { } catch (...) {
@@ -310,7 +310,7 @@ std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
} else { } else {
irStream << " ret " << currentReturnType << " 0\n"; irStream << " ret " << currentReturnType << " 0\n";
sysy::IRBuilder builder(currentIRBlock); sysy::IRBuilder builder(currentIRBlock);
builder.createReturnInst(sysy::ConstantValue::get(0)); builder.createReturnInst(sysy::ConstantValue::get(getIRType("int"),0));
} }
} }
irStream << "}\n"; irStream << "}\n";
@@ -524,10 +524,10 @@ std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
sysy::Value* irValue = nullptr; sysy::Value* irValue = nullptr;
if (ctx->ILITERAL()) { if (ctx->ILITERAL()) {
value = ctx->ILITERAL()->getText(); value = ctx->ILITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stoi(value)); irValue = sysy::ConstantValue::get(getIRType("int"), std::stoi(value));
} else if (ctx->FLITERAL()) { } else if (ctx->FLITERAL()) {
value = ctx->FLITERAL()->getText(); value = ctx->FLITERAL()->getText();
irValue = sysy::ConstantValue::get(std::stof(value)); irValue = sysy::ConstantValue::get(getIRType("float"), std::stof(value));
} else { } else {
value = ""; value = "";
} }

View File

@@ -552,10 +552,10 @@ std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
} else if (text.find("0") == 0) { } else if (text.find("0") == 0) {
base = 8; base = 8;
} }
res = ConstantValue::get((int)std::stol(text, 0, base)); res = ConstantValue::get(Type::getIntType() ,(int)std::stol(text, 0, base));
} else if (auto fLiteral = ctx->FLITERAL()) { } else if (auto fLiteral = ctx->FLITERAL()) {
const auto text = fLiteral->getText(); const auto text = fLiteral->getText();
res = ConstantValue::get((float)std::stof(text)); res = ConstantValue::get(Type::getFloatType(), (float)std::stof(text));
} }
cout << "number: "; cout << "number: ";
res->print(cout); res->print(cout);