diff --git a/src/IR.cpp b/src/IR.cpp index da4292b..faa9aed 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -647,7 +647,7 @@ Function * CallInst::getCallee() const { return dynamic_cast(getOper /** * 获取变量指针 */ -auto SymbolTable::getVariable(const std::string &name) const -> User * { +auto SymbolTable::getVariable(const std::string &name) const -> Value * { auto node = curNode; while (node != nullptr) { auto iter = node->varList.find(name); @@ -662,8 +662,8 @@ auto SymbolTable::getVariable(const std::string &name) const -> User * { /** * 添加变量到符号表 */ -auto SymbolTable::addVariable(const std::string &name, User *variable) -> User * { - User *result = nullptr; +auto SymbolTable::addVariable(const std::string &name, Value *variable) -> Value * { + Value *result = nullptr; if (curNode != nullptr) { std::stringstream ss; auto iter = variableIndex.find(name); diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 9f0f012..08332ec 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -450,44 +450,79 @@ std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { auto lVal = ctx->lValue(); std::string name = lVal->Ident()->getText(); - std::vector dims; - for (const auto &exp : lVal->exp()) { - dims.push_back(std::any_cast(visitExp(exp))); + Value* LValue = nullptr; + Value* variable = module->getVariable(name); // 左值 + + vector indices; + if (lVal->exp().size() > 0) { + // 如果有下标,访问表达式获取下标值 + for (const auto &exp : lVal->exp()) { + Value* indexValue = std::any_cast(visitExp(exp)); + indices.push_back(indexValue); + } + } + if (indices.empty()) { + // variable 本身就是指向标量的指针 (e.g., int* %a) + if (dynamic_cast(variable) || dynamic_cast(variable)) { + LValue = variable; + } + } + else { + // 对于数组或多维数组的左值处理 + // 需要获取 GEP 地址 + Value* gepBasePointer = nullptr; + std::vector gepIndices; + if (AllocaInst *alloc = dynamic_cast(variable)) { + Type* allocatedType = alloc->getType()->as()->getBaseType(); + if (allocatedType->isPointer()) { + gepBasePointer = builder.createLoadInst(alloc); + gepIndices = indices; + } else { + gepBasePointer = alloc; + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } + } else if (GlobalValue *glob = dynamic_cast(variable)) { + // 情况 B: 全局变量 (GlobalValue) + gepBasePointer = glob; + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } else if (ConstantVariable *constV = dynamic_cast(variable)) { + gepBasePointer = constV; + gepIndices.push_back(ConstantInteger::get(0)); + gepIndices.insert(gepIndices.end(), indices.begin(), indices.end()); + } + // 左值为地址 + LValue = getGEPAddressInst(gepBasePointer, gepIndices); } - - auto variable = module->getVariable(name); // 获取 AllocaInst 或 GlobalValue - Value* value = std::any_cast(visitExp(ctx->exp())); // 右值 - Type* targetElementType = variable->getType(); // 从基指针指向的类型开始 + Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 - //根据 dims 确定最终元素的类型 - targetElementType = builder.getIndexedType(targetElementType, dims); + // 先推断 LValue 的类型 + // 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型 + // 如果 LValue 是标量,则直接使用其类型 + // 注意:LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断 + Type* LType = builder.getIndexedType(variable->getType(), indices); + Type* RType = RValue->getType(); - // 左值右值类型不同处理:根据最终元素类型进行转换 - if (targetElementType != value->getType()) { - ConstantValue * constValue = dynamic_cast(value); + if (LType != RType) { + ConstantValue * constValue = dynamic_cast(RValue); if (constValue != nullptr) { - if (targetElementType == Type::getFloatType()) { - value = ConstantFloating::get(static_cast(constValue->getFloat())); + if (LType == Type::getFloatType()) { + RValue = ConstantFloating::get(static_cast(constValue->getFloat())); } else { // 假设如果不是浮点型,就是整型 - value = ConstantInteger::get(static_cast(constValue->getInt())); + RValue = ConstantInteger::get(static_cast(constValue->getInt())); } } else { - if (targetElementType == Type::getFloatType()) { - value = builder.createIToFInst(value); + if (LType == Type::getFloatType()) { + RValue = builder.createIToFInst(RValue); } else { // 假设如果不是浮点型,就是整型 - value = builder.createFtoIInst(value); + RValue = builder.createFtoIInst(RValue); } } } - // 计算目标地址:如果 dims 为空,就是变量本身地址;否则通过 GEP 计算 - Value* targetAddress = variable; - if (!dims.empty()) { - targetAddress = getGEPAddressInst(variable, dims); - } - - builder.createStoreInst(value, targetAddress); + builder.createStoreInst(RValue, LValue); return std::any(); } @@ -711,7 +746,7 @@ unsigned SysYIRGenerator::countArrayDimensions(Type* type) { std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::string name = ctx->Ident()->getText(); - User* variable = module->getVariable(name); + Value* variable = module->getVariable(name); Value* value = nullptr; diff --git a/src/include/IR.h b/src/include/IR.h index 99bf003..8f0103d 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -521,6 +521,7 @@ public: Function* getParent() const { return parent; } void setParent(Function *func) { parent = func; } inst_list& getInstructions() { return instructions; } + auto getInstructions_Range() const { return make_range(instructions); } arg_list& getArguments() { return arguments; } block_list& getPredecessors() { return predecessors; } void clearPredecessors() { predecessors.clear(); } @@ -1404,7 +1405,7 @@ class ConstantVariable : public User { using SymbolTableNode = struct SymbolTableNode { SymbolTableNode *pNode; ///< 父节点 std::vector children; ///< 子节点列表 - std::map varList; ///< 变量列表 + std::map varList; ///< 变量列表 }; @@ -1419,8 +1420,8 @@ class SymbolTable { public: SymbolTable() = default; - User* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 - User* addVariable(const std::string &name, User *variable); ///< 添加变量 + Value* getVariable(const std::string &name) const; ///< 根据名字name以及当前作用域获取变量 + Value* addVariable(const std::string &name, Value *variable); ///< 添加变量 std::vector>& getGlobals(); ///< 获取全局变量列表 const std::vector>& getConsts() const; ///< 获取常量列表 void enterNewScope(); ///< 进入新的作用域 @@ -1482,7 +1483,7 @@ class Module { void addVariable(const std::string &name, AllocaInst *variable) { variableTable.addVariable(name, variable); } ///< 添加变量 - User* getVariable(const std::string &name) { + Value* getVariable(const std::string &name) { return variableTable.getVariable(name); } ///< 根据名字name和当前作用域获取变量 Function* getFunction(const std::string &name) const {