diff --git a/src/include/midend/SysYIRGenerator.h b/src/include/midend/SysYIRGenerator.h index aac6ec9..bd671ee 100644 --- a/src/include/midend/SysYIRGenerator.h +++ b/src/include/midend/SysYIRGenerator.h @@ -59,6 +59,35 @@ private: std::unique_ptr module; IRBuilder builder; + using ValueOrOperator = std::variant; + std::vector BinaryExpStack; ///< 用于存储二元表达式的中缀表达式 + std::vector BinaryExpLenStack; ///< 用于存储该层次的二元表达式的长度 + // 下面是用于后缀表达式的计算的数据结构 + std::vector BinaryRPNStack; ///< 用于存储二元表达式的后缀表达式 + std::vector BinaryOpStack; ///< 用于存储二元表达式中缀表达式转换到后缀表达式的操作符栈 + std::vector BinaryValueStack; ///< 用于存储后缀表达式计算的操作数栈 + + // 约定操作符: + // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT', 9: 'LPAREN', 10: 'RPAREN' + // 这里的操作符是为了方便后缀表达式的计算而设计 + // 其中,'ADD', 'SUB', 'MUL', 'DIV', '%' + // 分别对应加法、减法、乘法、除法和取模 + // 'PLUS' 和 'NEG' 分别对应一元加法和一元减法 + // 'NOT' 对应逻辑非 + // 'LPAREN' 和 'RPAREN' 分别对应左括号和右括号 + enum BinaryOp { + ADD = 1, SUB = 2, MUL = 3, DIV = 4, MOD = 5, PLUS = 6, NEG = 7, NOT = 8, LPAREN = 9, RPAREN = 10, + }; + int getOperatorPrecedence(int op) { + switch (op) { + case MUL: case DIV: case MOD: return 2; + case ADD: case SUB: return 1; + case PLUS: case NEG: case NOT: return 3; + case LPAREN: case RPAREN: return 0; // Parentheses have lowest precedence for stack logic + default: return -1; // Unknown operator + } + } + public: SysYIRGenerator() = default; @@ -97,7 +126,7 @@ public: std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; // std::any visitStmt(SysYParser::StmtContext *ctx) override; std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - // std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; + std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; // std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override; std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; @@ -131,8 +160,13 @@ public: std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - // std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + bool isRightAssociative(int op); + Value* promoteType(Value* value, Type* targetType); + Value* computeExp(SysYParser::ExpContext *ctx, Type* targetType = nullptr); + Value* computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType = nullptr); + void compute(); public: // 获取GEP指令的地址 Value* getGEPAddressInst(Value* basePointer, const std::vector& indices); @@ -141,6 +175,7 @@ public: unsigned countArrayDimensions(Type* type); + }; // class SysYIRGenerator } // namespace sysy \ No newline at end of file diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp index 38fa52a..d095a97 100644 --- a/src/midend/SysYIRGenerator.cpp +++ b/src/midend/SysYIRGenerator.cpp @@ -16,6 +16,438 @@ using namespace std; namespace sysy { +// std::vector BinaryValueStack; ///< 用于存储value的栈 +// std::vector BinaryOpStack; ///< 用于存储二元表达式的操作符栈 +// // 约定操作符: +// // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT' +// enum BinaryOp { +// ADD = 1, +// SUB = 2, +// MUL = 3, +// DIV = 4, +// MOD = 5, +// PLUS = 6, +// NEG = 7, +// NOT = 8 +// }; + +Value *SysYIRGenerator::promoteType(Value *value, Type *targetType) { + //如果是常量则直接返回相应的值 + ConstantInteger* constInt = dynamic_cast(value); + ConstantFloating *constFloat = dynamic_cast(value); + if (constInt) { + if (targetType->isFloat()) { + return ConstantFloating::get(static_cast(constInt->getInt())); + } + return constInt; // 如果目标类型是int,直接返回原值 + } else if (constFloat) { + if (targetType->isInt()) { + return ConstantInteger::get(static_cast(constFloat->getFloat())); + } + return constFloat; // 如果目标类型是float,直接返回原值 + } + + if (value->getType()->isInt() && targetType->isFloat()) { + return builder.createItoFInst(value); + } else if (value->getType()->isFloat() && targetType->isInt()) { + return builder.createFtoIInst(value); + } + // 如果类型已经匹配,直接返回原值 + return value; +} + +bool SysYIRGenerator::isRightAssociative(int op) { + return (op == BinaryOp::PLUS || op == BinaryOp::NEG || op == BinaryOp::NOT); +} + +void SysYIRGenerator::compute() { + + // 先将中缀表达式转换为后缀表达式 + BinaryRPNStack.clear(); + BinaryOpStack.clear(); + + int begin = BinaryExpStack.size() - BinaryExpLenStack.back(), end = BinaryExpStack.size(); + + for (int i = begin; i < end; i++) { + auto item = BinaryExpStack[i]; + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*),直接推入后缀表达式栈 + BinaryRPNStack.push_back(item); // 直接 push_back item (ValueOrOperator类型) + } else { + // 如果是操作符 + int currentOp = std::get(item); + + if (currentOp == LPAREN) { + // 左括号直接入栈 + BinaryOpStack.push_back(currentOp); + } else if (currentOp == RPAREN) { + // 右括号:将操作符栈中的操作符弹出并添加到后缀表达式栈,直到遇到左括号 + while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { + BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int + BinaryOpStack.pop_back(); + } + if (!BinaryOpStack.empty() && BinaryOpStack.back() == LPAREN) { + BinaryOpStack.pop_back(); // 弹出左括号,但不添加到后缀表达式栈 + } else { + // 错误:不匹配的右括号 + std::cerr << "Error: Mismatched parentheses in expression." << std::endl; + return; + } + } else { + // 普通操作符 + while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { + + int stackTopOp = BinaryOpStack.back(); + // 如果当前操作符优先级低于栈顶操作符优先级 + // 或者 (当前操作符优先级等于栈顶操作符优先级 并且 栈顶操作符是左结合) + if (getOperatorPrecedence(currentOp) < getOperatorPrecedence(stackTopOp) || + (getOperatorPrecedence(currentOp) == getOperatorPrecedence(stackTopOp) && + !isRightAssociative(stackTopOp))) { + + BinaryRPNStack.push_back(stackTopOp); + BinaryOpStack.pop_back(); + } else { + break; // 否则当前操作符入栈 + } + } + BinaryOpStack.push_back(currentOp); // 当前操作符入栈 + } + } + } + // 遍历结束后,将操作符栈中剩余的所有操作符弹出并添加到后缀表达式栈 + while (!BinaryOpStack.empty()) { + if (BinaryOpStack.back() == LPAREN) { + // 错误:不匹配的左括号 + std::cerr << "Error: Mismatched parentheses in expression (unclosed parenthesis)." << std::endl; + return; + } + BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int + BinaryOpStack.pop_back(); + } + + // 弹出BinaryExpStack的表达式 + while(begin < end) { + BinaryExpStack.pop_back(); + BinaryExpLenStack.back()--; + end--; + } + + // 计算后缀表达式 + // 每次计算前清空操作数栈 + BinaryValueStack.clear(); + + // 遍历后缀表达式栈 + Type *commonType = nullptr; + for(const auto &item : BinaryRPNStack) { + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*) 检测他的类型 + Value *value = std::get(item); + if (commonType == nullptr) { + commonType = value->getType(); + } + else if (value->getType() != commonType && value->getType()->isFloat()) { + // 如果当前值的类型与commonType不同且是float类型,则提升为float + commonType = Type::getFloatType(); + break; + } + } else { + continue; + } + } + + for (const auto &item : BinaryRPNStack) { + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*),直接推入操作数栈 + BinaryValueStack.push_back(std::get(item)); + } else { + // 如果是操作符 + int op = std::get(item); + Value *resultValue = nullptr; + Value *lhs = nullptr; + Value *rhs = nullptr; + Value *operand = nullptr; + + switch (op) { + case BinaryOp::ADD: + case BinaryOp::SUB: + case BinaryOp::MUL: + case BinaryOp::DIV: + case BinaryOp::MOD: { + // 二元操作符需要两个操作数 + if (BinaryValueStack.size() < 2) { + std::cerr << "Error: Not enough operands for binary operation: " << op << std::endl; + return; // 或者抛出异常 + } + rhs = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + lhs = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + // 类型转换 + lhs = promoteType(lhs, commonType); + rhs = promoteType(rhs, commonType); + + // 尝试常量折叠 + ConstantValue *lhsConst = dynamic_cast(lhs); + ConstantValue *rhsConst = dynamic_cast(rhs); + + if (lhsConst && rhsConst) { + // 如果都是常量,直接计算结果 + if (commonType == Type::getIntType()) { + int lhsVal = lhsConst->getInt(); + int rhsVal = rhsConst->getInt(); + switch (op) { + case BinaryOp::ADD: resultValue = ConstantInteger::get(lhsVal + rhsVal); break; + case BinaryOp::SUB: resultValue = ConstantInteger::get(lhsVal - rhsVal); break; + case BinaryOp::MUL: resultValue = ConstantInteger::get(lhsVal * rhsVal); break; + case BinaryOp::DIV: + if (rhsVal == 0) { + std::cerr << "Error: Division by zero." << std::endl; + return; + } + resultValue = sysy::ConstantInteger::get(lhsVal / rhsVal); break; + case BinaryOp::MOD: + if (rhsVal == 0) { + std::cerr << "Error: Modulo by zero." << std::endl; + return; + } + resultValue = sysy::ConstantInteger::get(lhsVal % rhsVal); break; + default: + std::cerr << "Error: Unknown binary operator for constants: " << op << std::endl; + return; + } + } else if (commonType == Type::getFloatType()) { + float lhsVal = lhsConst->getFloat(); + float rhsVal = rhsConst->getFloat(); + switch (op) { + case BinaryOp::ADD: resultValue = ConstantFloating::get(lhsVal + rhsVal); break; + case BinaryOp::SUB: resultValue = ConstantFloating::get(lhsVal - rhsVal); break; + case BinaryOp::MUL: resultValue = ConstantFloating::get(lhsVal * rhsVal); break; + case BinaryOp::DIV: + if (rhsVal == 0.0f) { + std::cerr << "Error: Division by zero." << std::endl; + return; + } + resultValue = sysy::ConstantFloating::get(lhsVal / rhsVal); break; + case BinaryOp::MOD: + std::cerr << "Error: Modulo operator not supported for float types." << std::endl; + return; + default: + std::cerr << "Error: Unknown binary operator for float constants: " << op << std::endl; + return; + } + } else { + std::cerr << "Error: Unsupported type for binary constant operation." << std::endl; + return; + } + } else { + // 否则,创建相应的IR指令 + if (commonType == Type::getIntType()) { + switch (op) { + case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break; + case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break; + case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break; + case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break; + case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break; + } + } else if (commonType == Type::getFloatType()) { + switch (op) { + case BinaryOp::ADD: resultValue = builder.createFAddInst(lhs, rhs); break; + case BinaryOp::SUB: resultValue = builder.createFSubInst(lhs, rhs); break; + case BinaryOp::MUL: resultValue = builder.createFMulInst(lhs, rhs); break; + case BinaryOp::DIV: resultValue = builder.createFDivInst(lhs, rhs); break; + case BinaryOp::MOD: + std::cerr << "Error: Modulo operator not supported for float types." << std::endl; + return; + } + } else { + std::cerr << "Error: Unsupported type for binary instruction." << std::endl; + return; + } + } + break; + } + case BinaryOp::PLUS: + case BinaryOp::NEG: + case BinaryOp::NOT: { + // 一元操作符需要一个操作数 + if (BinaryValueStack.empty()) { + std::cerr << "Error: Not enough operands for unary operation: " << op << std::endl; + return; + } + operand = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + + operand = promoteType(operand, commonType); + + // 尝试常量折叠 + ConstantInteger *constInt = dynamic_cast(operand); + ConstantFloating *constFloat = dynamic_cast(operand); + + if (constInt || constFloat) { + // 如果是常量,直接计算结果 + switch (op) { + case BinaryOp::PLUS: resultValue = operand; break; + case BinaryOp::NEG: { + if (constInt) { + resultValue = constInt->getNeg(); + } else if (constFloat) { + resultValue = constFloat->getNeg(); + } else { + std::cerr << "Error: Negation not supported for constant operand type." << std::endl; + return; + } + break; + } + case BinaryOp::NOT: + if (constInt) { + resultValue = sysy::ConstantInteger::get(constInt->getInt() == 0 ? 1 : 0); + } else if (constFloat) { + resultValue = sysy::ConstantInteger::get(constFloat->getFloat() == 0.0f ? 1 : 0); + } else { + std::cerr << "Error: Logical NOT not supported for constant operand type." << std::endl; + return; + } + break; + default: + std::cerr << "Error: Unknown unary operator for constants: " << op << std::endl; + return; + } + } else { + // 否则,创建相应的IR指令 + switch (op) { + case BinaryOp::PLUS: + resultValue = operand; // 一元加指令通常直接返回操作数 + break; + case BinaryOp::NEG: { + if (commonType == sysy::Type::getIntType()) { + resultValue = builder.createNegInst(operand); + } else if (commonType == sysy::Type::getFloatType()) { + resultValue = builder.createFNegInst(operand); + } else { + std::cerr << "Error: Negation not supported for operand type." << std::endl; + return; + } + break; + } + case BinaryOp::NOT: + // 逻辑非 + if (commonType == sysy::Type::getIntType()) { + resultValue = builder.createNotInst(operand); + } else if (commonType == sysy::Type::getFloatType()) { + resultValue = builder.createFNotInst(operand); + } else { + std::cerr << "Error: Logical NOT not supported for operand type." << std::endl; + return; + } + break; + default: + std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl; + return; + } + } + break; + } + default: + std::cerr << "Error: Unknown operator " << op << " encountered in RPN stack." << std::endl; + return; + } + + // 将计算结果或指令结果推入操作数栈 + if (resultValue) { + BinaryValueStack.push_back(resultValue); + } else { + std::cerr << "Error: Result value is null after processing operator " << op << "!" << std::endl; + return; + } + + } + } + + // 后缀表达式处理完毕,操作数栈的栈顶就是最终结果 + if (BinaryValueStack.empty()) { + std::cerr << "Error: No values left in BinaryValueStack after processing RPN." << std::endl; + return; + } + if (BinaryValueStack.size() > 1) { + std::cerr + << "Warning: Multiple values left in BinaryValueStack after processing RPN. Expression might be malformed." + << std::endl; + } + BinaryRPNStack.clear(); // 清空后缀表达式栈 + BinaryOpStack.clear(); // 清空操作符栈 + return; +} + +Value* SysYIRGenerator::computeExp(SysYParser::ExpContext *ctx, Type* targetType){ + if (ctx->addExp() == nullptr) { + assert(false && "ExpContext should have an addExp child!"); + } + BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 + visitAddExp(ctx->addExp()); + + if(targetType == nullptr) { + targetType = Type::getIntType(); // 默认目标类型为int + } + + compute(); + // 最后一个Value应该是最终结果 + + Value* result = BinaryValueStack.back(); + BinaryValueStack.pop_back(); // 移除结果值 + + result = promoteType(result, targetType); // 确保结果类型符合目标类型 + // 检查当前层次的操作符数量 + int ExpLen = BinaryExpLenStack.back(); + BinaryExpLenStack.pop_back(); // 离开层次时将该层次 + if (ExpLen > 0) { + std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; + return nullptr; + } + return result; +} + +Value* SysYIRGenerator::computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType){ + // 根据AddExpContext中的操作符和操作数计算加法表达式 + // 这里假设AddExpContext已经被正确填充 + if (ctx->mulExp().size() == 0) { + assert(false && "AddExpContext should have a mulExp child!"); + } + BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 + visitMulExp(ctx->mulExp(0)); + // BinaryValueStack.push_back(result); + + for (int i = 1; i < ctx->mulExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in AddExp."); + } + // BinaryExpStack.push_back(opType); + visitMulExp(ctx->mulExp(i)); + // BinaryValueStack.push_back(operand); + } + if(targetType == nullptr) { + targetType = Type::getIntType(); // 默认目标类型为int + } + // 根据后缀表达式的逻辑计算 + compute(); + // 最后一个Value应该是最终结果 + + Value* result = BinaryValueStack.back(); + BinaryValueStack.pop_back(); // 移除最后一个值,因为它已经被计算 + result = promoteType(result, targetType); // 确保结果类型符合目标类型 + + int ExpLen = BinaryExpLenStack.back(); + BinaryExpLenStack.pop_back(); // 离开层次时将该层次 + if (ExpLen > 0) { + std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; + return nullptr; + } + return result; +} + Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector& dims){ Type* currentType = baseType; // 从最内层维度开始构建 ArrayType @@ -132,8 +564,8 @@ std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *c return std::any(); } -std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){ - Type* type = std::any_cast(visitBType(ctx->bType())); +std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx) { + Type *type = std::any_cast(visitBType(ctx->bType())); for (const auto constDef : ctx->constDef()) { std::vector dims = {}; std::string name = constDef->Ident()->getText(); @@ -144,19 +576,112 @@ std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){ } } - ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); + Type *variableType = type; + if (!dims.empty()) { + variableType = buildArrayType(type, dims); // 构建完整的 ArrayType + } + + // 显式地为局部常量在栈上分配空间 + // alloca 的类型将是指针指向常量类型,例如 `int*` 或 `int[2][3]*` + AllocaInst *alloca = builder.createAllocaInst(Type::getPointerType(variableType), {}, name); + + ArrayValueTree *root = std::any_cast(constDef->constInitVal()->accept(this)); ValueCounter values; Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; - // 创建局部常量,并更新符号表 - Type* variableType = type; - if (!dims.empty()) { - variableType = buildArrayType(type, dims); // 构建完整的 ArrayType + // 根据维度信息进行 store 初始化 + if (dims.empty()) { // 标量常量初始化 + // 局部常量必须有初始值,且通常是单个值 + if (!values.getValues().empty()) { + builder.createStoreInst(values.getValue(0), alloca); + } else { + // 错误处理:局部标量常量缺少初始化值 + // 或者可以考虑默认初始化为0,但这通常不符合常量的语义 + assert(false && "Local scalar constant must have an initialization value!"); + return std::any(); // 直接返回,避免继续执行 + } + } else { // 数组常量初始化 + const std::vector &counterValues = values.getValues(); + const std::vector &counterNumbers = values.getNumbers(); + int numElements = 1; + std::vector dimSizes; + for (Value *dimVal : dims) { + if (ConstantInteger *constInt = dynamic_cast(dimVal)) { + int dimSize = constInt->getInt(); + numElements *= dimSize; + dimSizes.push_back(dimSize); + } + // TODO else 错误处理:数组维度必须是常量(对于静态分配) + else { + assert(false && "Array dimension must be a constant integer!"); + return std::any(); // 直接返回,避免继续执行 + } + } + unsigned int elementSizeInBytes = type->getSize(); + unsigned int totalSizeInBytes = numElements * elementSizeInBytes; + + // 检查是否所有初始化值都是零 + bool allValuesAreZero = false; + if (counterValues.empty()) { // 如果没有提供初始化值,通常视为全零初始化 + allValuesAreZero = true; + } else { + allValuesAreZero = true; + for (Value *val : counterValues) { + if (ConstantInteger *constInt = dynamic_cast(val)) { + if (constInt->getInt() != 0) { + allValuesAreZero = false; + break; + } + } else { // 如果不是常量整数,则不能确定是零 + allValuesAreZero = false; + break; + } + } + } + + if (allValuesAreZero) { + builder.createMemsetInst(alloca, ConstantInteger::get(0), ConstantInteger::get(totalSizeInBytes), + ConstantInteger::get(0)); + } else { + int linearIndexOffset = 0; // 用于追踪当前处理的线性索引的偏移量 + for (int k = 0; k < counterValues.size(); ++k) { + // 当前 Value 的值和重复次数 + Value *currentValue = counterValues[k]; + unsigned currentRepeatNum = counterNumbers[k]; + + for (unsigned i = 0; i < currentRepeatNum; ++i) { + std::vector currentIndices; + int tempLinearIndex = linearIndexOffset + i; // 使用偏移量和当前重复次数内的索引 + + // 将线性索引转换为多维索引 + for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx) { + currentIndices.insert(currentIndices.begin(), + ConstantInteger::get(static_cast(tempLinearIndex % dimSizes[dimIdx]))); + tempLinearIndex /= dimSizes[dimIdx]; + } + + // 对于局部数组,alloca 本身就是 GEP 的基指针。 + // GEP 的第一个索引必须是 0,用于“步过”整个数组。 + std::vector gepIndicesForInit; + gepIndicesForInit.push_back(ConstantInteger::get(0)); + gepIndicesForInit.insert(gepIndicesForInit.end(), currentIndices.begin(), currentIndices.end()); + + // 计算元素的地址 + Value *elementAddress = getGEPAddressInst(alloca, gepIndicesForInit); + // 生成 store 指令 + builder.createStoreInst(currentValue, elementAddress); + } + // 更新线性索引偏移量,以便下一次迭代从正确的位置开始 + linearIndexOffset += currentRepeatNum; + } + } } - module->createConstVar(name, Type::getPointerType(variableType), values, dims); + + // 更新符号表,将常量名称与 AllocaInst 关联起来 + module->addVariable(name, alloca); } - return 0; + return std::any(); } std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { @@ -300,7 +825,8 @@ std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { } std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { - Value* value = std::any_cast(visitExp(ctx->exp())); + // Value* value = std::any_cast(visitExp(ctx->exp())); + Value* value = computeExp(ctx->exp()); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; @@ -315,13 +841,17 @@ std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext return result; } -std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { +std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { Value* value = std::any_cast(visitConstExp(ctx->constExp())); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } +std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext *ctx){ + return computeAddExp(ctx->addExp()); +} + std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { std::vector children; for (const auto &constInitVal : ctx->constInitVal()) @@ -477,8 +1007,8 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { vector indices; if (lVal->exp().size() > 0) { // 如果有下标,访问表达式获取下标值 - for (const auto &exp : lVal->exp()) { - Value* indexValue = std::any_cast(visitExp(exp)); + for (auto &exp : lVal->exp()) { + Value* indexValue = std::any_cast(computeExp(exp)); indices.push_back(indexValue); } } @@ -517,15 +1047,18 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { LValue = getGEPAddressInst(gepBasePointer, gepIndices); } - Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 + // Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 // 先推断 LValue 的类型 // 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型 // 如果 LValue 是标量,则直接使用其类型 // 注意:LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断 Type* LType = builder.getIndexedType(variable->getType(), indices); + + Value* RValue = computeExp(ctx->exp(), LType); // 右值计算 Type* RType = RValue->getType(); + // TODO:computeExp处理了类型转换,可以考虑删除判断逻辑 if (LType != RType) { ConstantValue *constValue = dynamic_cast(RValue); if (constValue != nullptr) { @@ -549,7 +1082,7 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { } } else { if (LType == Type::getFloatType()) { - RValue = builder.createIToFInst(RValue); + RValue = builder.createItoFInst(RValue); } else { // 假设如果不是浮点型,就是整型 RValue = builder.createFtoIInst(RValue); } @@ -562,6 +1095,14 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { } +std::any SysYIRGenerator::visitExpStmt(SysYParser::ExpStmtContext *ctx) { + // 访问表达式 + if (ctx->exp() != nullptr) { + computeExp(ctx->exp()); + } + return std::any(); +} + std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { // labels string stream @@ -729,11 +1270,11 @@ std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { Value* returnValue = nullptr; - if (ctx->exp() != nullptr) { - returnValue = std::any_cast(visitExp(ctx->exp())); - } - Type* funcType = builder.getBasicBlock()->getParent()->getReturnType(); + if (ctx->exp() != nullptr) { + returnValue = computeExp(ctx->exp(), funcType); + } + // TODOL 考虑删除类型转换判断逻辑 if (returnValue != nullptr && funcType!= returnValue->getType()) { ConstantValue * constValue = dynamic_cast(returnValue); if (constValue != nullptr) { @@ -756,7 +1297,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } else { if (funcType == Type::getFloatType()) { - returnValue = builder.createIToFInst(returnValue); + returnValue = builder.createItoFInst(returnValue); } else { returnValue = builder.createFtoIInst(returnValue); } @@ -798,7 +1339,8 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::vector dims; for (const auto &exp : ctx->exp()) { - dims.push_back(std::any_cast(visitExp(exp))); + Value* expValue = std::any_cast(computeExp(exp)); + dims.push_back(expValue); } // 1. 获取变量的声明维度数量 @@ -902,16 +1444,23 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { } std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { - if (ctx->exp() != nullptr) - return visitExp(ctx->exp()); - if (ctx->lValue() != nullptr) - return visitLValue(ctx->lValue()); - if (ctx->number() != nullptr) - return visitNumber(ctx->number()); + if (ctx->exp() != nullptr) { + BinaryExpStack.push_back(BinaryOp::LPAREN);BinaryExpLenStack.back()++; + visitExp(ctx->exp()); + BinaryExpStack.push_back(BinaryOp::RPAREN);BinaryExpLenStack.back()++; + } + + if (ctx->lValue() != nullptr) { + // 如果是 lValue,将value压入栈中 + BinaryExpStack.push_back(std::any_cast(visitLValue(ctx->lValue())));BinaryExpLenStack.back()++; + } + if (ctx->number() != nullptr) { + BinaryExpStack.push_back(std::any_cast(visitNumber(ctx->number())));BinaryExpLenStack.back()++; + } if (ctx->string() != nullptr) { cout << "String literal not supported in SysYIRGenerator." << endl; } - return visitNumber(ctx->number()); + return std::any(); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { @@ -981,7 +1530,7 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { args[i] = builder.createFtoIInst(args[i]); } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { - args[i] = builder.createIToFInst(args[i]); + args[i] = builder.createItoFInst(args[i]); } // 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO:不清楚有没有这种样例 // 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型, @@ -1006,235 +1555,78 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { } std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { - if (ctx->primaryExp() != nullptr) - return visitPrimaryExp(ctx->primaryExp()); - if (ctx->call() != nullptr) - return visitCall(ctx->call()); - - Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); - Value* result = value; - if (ctx->unaryOp()->SUB() != nullptr) { - ConstantValue * constValue = dynamic_cast(value); - if (constValue != nullptr) { - if (constValue->isFloat()) { - result = ConstantFloating::get(-constValue->getFloat()); - } else { - result = ConstantInteger::get(-constValue->getInt()); + if (ctx->primaryExp() != nullptr) { + visitPrimaryExp(ctx->primaryExp()); + } else if (ctx->call() != nullptr) { + BinaryExpStack.push_back(std::any_cast(visitCall(ctx->call())));BinaryExpLenStack.back()++; + } else if (ctx->unaryOp() != nullptr) { + // 遇到一元操作符,将其压入 BinaryExpStack + auto opNode = dynamic_cast(ctx->unaryOp()->children[0]); + int opType = opNode->getSymbol()->getType(); + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::PLUS); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::NEG); BinaryExpLenStack.back()++; break; + case SysYParser::NOT: BinaryExpStack.push_back(BinaryOp::NOT); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in UnaryExp."); } - } else if (value != nullptr) { - if (value->getType() == Type::getIntType()) { - result = builder.createNegInst(value); - } else { - result = builder.createFNegInst(value); - } - } else { - std::cout << "UnExp: value is nullptr." << std::endl; - assert(false); - } - } else if (ctx->unaryOp()->NOT() != nullptr) { - auto constValue = dynamic_cast(value); - if (constValue != nullptr) { - if (constValue->isFloat()) { - result = - ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); - } else { - result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0)); - } - } else if (value != nullptr) { - if (value->getType() == Type::getIntType()) { - result = builder.createNotInst(value); - } else { - result = builder.createFNotInst(value); - } - } else { - std::cout << "UnExp: value is nullptr." << std::endl; - assert(false); - } + visitUnaryExp(ctx->unaryExp()); } - return result; + return std::any(); } std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { std::vector params; - for (const auto &exp : ctx->exp()) - params.push_back(std::any_cast(visitExp(exp))); + for (const auto &exp : ctx->exp()) { + auto param = std::any_cast(computeExp(exp)); + params.push_back(param); + } + return params; } std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { - Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); - + visitUnaryExp(ctx->unaryExp(0)); + for (int i = 1; i < ctx->unaryExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - - Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); - - Type* resultType = result->getType(); - Type* operandType = operand->getType(); - Type* floatType = Type::getFloatType(); - - if (resultType == floatType || operandType == floatType) { - // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 - if (operandType != floatType) { - ConstantValue * constValue = dynamic_cast(operand); - if (constValue != nullptr) { - if(dynamic_cast(constValue)) { - // 如果是整型常量,转换为浮点型 - operand = ConstantFloating::get(static_cast(constValue->getInt())); - } else if (dynamic_cast(constValue)) { - // 如果是浮点型常量,直接使用 - operand = ConstantFloating::get(static_cast(constValue->getFloat())); - } - } - else - operand = builder.createIToFInst(operand); - } else if (resultType != floatType) { - ConstantValue* constResult = dynamic_cast(result); - if (constResult != nullptr) { - if(dynamic_cast(constResult)) { - // 如果是整型常量,转换为浮点型 - result = ConstantFloating::get(static_cast(constResult->getInt())); - } else if (dynamic_cast(constResult)) { - // 如果是浮点型常量,直接使用 - result = ConstantFloating::get(static_cast(constResult->getFloat())); - } - } - else - result = builder.createIToFInst(result); - } - - ConstantFloating* constResult = dynamic_cast(result); - ConstantFloating* constOperand = dynamic_cast(operand); - if (opType == SysYParser::MUL) { - if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantFloating::get(constResult->getFloat() * - constOperand->getFloat()); - } else { - result = builder.createFMulInst(result, operand); - } - } else if (opType == SysYParser::DIV) { - if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantFloating::get(constResult->getFloat() / - constOperand->getFloat()); - } else { - result = builder.createFDivInst(result, operand); - } - } else { - // float类型的取模操作不允许 - std::cout << "MulExp: float type mod operation is not allowed." << std::endl; - assert(false); - } - } else { - ConstantInteger *constResult = dynamic_cast(result); - ConstantInteger *constOperand = dynamic_cast(operand); - if (opType == SysYParser::MUL) { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() * constOperand->getInt()); - else - result = builder.createMulInst(result, operand); - } else if (opType == SysYParser::DIV) { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() / constOperand->getInt()); - else - result = builder.createDivInst(result, operand); - } else { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() % constOperand->getInt()); - else - result = builder.createRemInst(result, operand); - } + switch(opType) { + case SysYParser::MUL: BinaryExpStack.push_back(BinaryOp::MUL); BinaryExpLenStack.back()++; break; + case SysYParser::DIV: BinaryExpStack.push_back(BinaryOp::DIV); BinaryExpLenStack.back()++; break; + case SysYParser::MOD: BinaryExpStack.push_back(BinaryOp::MOD); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in MulExp."); } + visitUnaryExp(ctx->unaryExp(i)); } - - return result; + return std::any(); } std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { - Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); + visitMulExp(ctx->mulExp(0)); for (int i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - - Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); - Type* resultType = result->getType(); - Type* operandType = operand->getType(); - Type* floatType = Type::getFloatType(); - - if (resultType == floatType || operandType == floatType) { - // 类型转换 - if (operandType != floatType) { - ConstantValue * constOperand = dynamic_cast(operand); - if (constOperand != nullptr) { - if(dynamic_cast(constOperand)) { - // 如果是整型常量,转换为浮点型 - operand = ConstantFloating::get(static_cast(constOperand->getInt())); - } else if (dynamic_cast(constOperand)) { - // 如果是浮点型常量,直接使用 - operand = ConstantFloating::get(static_cast(constOperand->getFloat())); - } - } - else - operand = builder.createIToFInst(operand); - } else if (resultType != floatType) { - ConstantValue * constResult = dynamic_cast(result); - if (constResult != nullptr) { - if(dynamic_cast(constResult)) { - // 如果是整型常量,转换为浮点型 - result = ConstantFloating::get(static_cast(constResult->getInt())); - } else if (dynamic_cast(constResult)) { - // 如果是浮点型常量,直接使用 - result = ConstantFloating::get(static_cast(constResult->getFloat())); - } - } - else - result = builder.createIToFInst(result); - } - - ConstantFloating *constResult = dynamic_cast(result); - ConstantFloating *constOperand = dynamic_cast(operand); - if (opType == SysYParser::ADD) { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat()); - else - result = builder.createFAddInst(result, operand); - } else { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat()); - else - result = builder.createFSubInst(result, operand); - } - } else { - ConstantInteger *constResult = dynamic_cast(result); - ConstantInteger *constOperand = dynamic_cast(operand); - if (opType == SysYParser::ADD) { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantInteger::get(constResult->getInt() + constOperand->getInt()); - else - result = builder.createAddInst(result, operand); - } else { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantInteger::get(constResult->getInt() - constOperand->getInt()); - else - result = builder.createSubInst(result, operand); - } + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in AddExp."); } + visitMulExp(ctx->mulExp(i)); } - - return result; + return std::any(); } std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { - Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); + Value* result = computeAddExp(ctx->addExp(0), Type::getIntType()); for (int i = 1; i < ctx->addExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); + Value* operand = computeAddExp(ctx->addExp(i), Type::getIntType()); Type* resultType = result->getType(); Type* operandType = operand->getType(); @@ -1273,7 +1665,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } } else - result = builder.createIToFInst(result); + result = builder.createItoFInst(result); } if (operandType != floatType) { @@ -1287,7 +1679,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } } else - operand = builder.createIToFInst(operand); + operand = builder.createItoFInst(operand); } @@ -1314,6 +1706,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { + // TODO:其实已经保证了result是一个int类型的值可以删除冗余判断逻辑 Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); for (int i = 1; i < ctx->relExp().size(); i++) { @@ -1352,7 +1745,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { } } else - result = builder.createIToFInst(result); + result = builder.createItoFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) { @@ -1365,7 +1758,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { } } else - operand = builder.createIToFInst(operand); + operand = builder.createItoFInst(operand); } if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); @@ -1474,7 +1867,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, assert(false && "Unknown constant type for float conversion."); } else - result.push_back(builder->createIToFInst(value)); + result.push_back(builder->createItoFInst(value)); } else { ConstantValue* constValue = dynamic_cast(value);