diff --git a/src/include/midend/IRBuilder.h b/src/include/midend/IRBuilder.h index 760ef85..73e40e5 100644 --- a/src/include/midend/IRBuilder.h +++ b/src/include/midend/IRBuilder.h @@ -126,7 +126,7 @@ class IRBuilder { UnaryInst * createFNotInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name); } ///< 创建浮点取非指令 - UnaryInst * createIToFInst(Value *operand, const std::string &name = "") { + UnaryInst * createItoFInst(Value *operand, const std::string &name = "") { return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name); } ///< 创建整型转浮点指令 UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") { diff --git a/src/include/midend/SysYIRGenerator.h b/src/include/midend/SysYIRGenerator.h index aac6ec9..bd671ee 100644 --- a/src/include/midend/SysYIRGenerator.h +++ b/src/include/midend/SysYIRGenerator.h @@ -59,6 +59,35 @@ private: std::unique_ptr module; IRBuilder builder; + using ValueOrOperator = std::variant; + std::vector BinaryExpStack; ///< 用于存储二元表达式的中缀表达式 + std::vector BinaryExpLenStack; ///< 用于存储该层次的二元表达式的长度 + // 下面是用于后缀表达式的计算的数据结构 + std::vector BinaryRPNStack; ///< 用于存储二元表达式的后缀表达式 + std::vector BinaryOpStack; ///< 用于存储二元表达式中缀表达式转换到后缀表达式的操作符栈 + std::vector BinaryValueStack; ///< 用于存储后缀表达式计算的操作数栈 + + // 约定操作符: + // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT', 9: 'LPAREN', 10: 'RPAREN' + // 这里的操作符是为了方便后缀表达式的计算而设计 + // 其中,'ADD', 'SUB', 'MUL', 'DIV', '%' + // 分别对应加法、减法、乘法、除法和取模 + // 'PLUS' 和 'NEG' 分别对应一元加法和一元减法 + // 'NOT' 对应逻辑非 + // 'LPAREN' 和 'RPAREN' 分别对应左括号和右括号 + enum BinaryOp { + ADD = 1, SUB = 2, MUL = 3, DIV = 4, MOD = 5, PLUS = 6, NEG = 7, NOT = 8, LPAREN = 9, RPAREN = 10, + }; + int getOperatorPrecedence(int op) { + switch (op) { + case MUL: case DIV: case MOD: return 2; + case ADD: case SUB: return 1; + case PLUS: case NEG: case NOT: return 3; + case LPAREN: case RPAREN: return 0; // Parentheses have lowest precedence for stack logic + default: return -1; // Unknown operator + } + } + public: SysYIRGenerator() = default; @@ -97,7 +126,7 @@ public: std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; // std::any visitStmt(SysYParser::StmtContext *ctx) override; std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - // std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; + std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override; // std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override; std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; @@ -131,8 +160,13 @@ public: std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override; std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override; - // std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + std::any visitConstExp(SysYParser::ConstExpContext *ctx) override; + bool isRightAssociative(int op); + Value* promoteType(Value* value, Type* targetType); + Value* computeExp(SysYParser::ExpContext *ctx, Type* targetType = nullptr); + Value* computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType = nullptr); + void compute(); public: // 获取GEP指令的地址 Value* getGEPAddressInst(Value* basePointer, const std::vector& indices); @@ -141,6 +175,7 @@ public: unsigned countArrayDimensions(Type* type); + }; // class SysYIRGenerator } // namespace sysy \ No newline at end of file diff --git a/src/midend/SysYIRGenerator.cpp b/src/midend/SysYIRGenerator.cpp index 03f382b..d095a97 100644 --- a/src/midend/SysYIRGenerator.cpp +++ b/src/midend/SysYIRGenerator.cpp @@ -16,6 +16,438 @@ using namespace std; namespace sysy { +// std::vector BinaryValueStack; ///< 用于存储value的栈 +// std::vector BinaryOpStack; ///< 用于存储二元表达式的操作符栈 +// // 约定操作符: +// // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT' +// enum BinaryOp { +// ADD = 1, +// SUB = 2, +// MUL = 3, +// DIV = 4, +// MOD = 5, +// PLUS = 6, +// NEG = 7, +// NOT = 8 +// }; + +Value *SysYIRGenerator::promoteType(Value *value, Type *targetType) { + //如果是常量则直接返回相应的值 + ConstantInteger* constInt = dynamic_cast(value); + ConstantFloating *constFloat = dynamic_cast(value); + if (constInt) { + if (targetType->isFloat()) { + return ConstantFloating::get(static_cast(constInt->getInt())); + } + return constInt; // 如果目标类型是int,直接返回原值 + } else if (constFloat) { + if (targetType->isInt()) { + return ConstantInteger::get(static_cast(constFloat->getFloat())); + } + return constFloat; // 如果目标类型是float,直接返回原值 + } + + if (value->getType()->isInt() && targetType->isFloat()) { + return builder.createItoFInst(value); + } else if (value->getType()->isFloat() && targetType->isInt()) { + return builder.createFtoIInst(value); + } + // 如果类型已经匹配,直接返回原值 + return value; +} + +bool SysYIRGenerator::isRightAssociative(int op) { + return (op == BinaryOp::PLUS || op == BinaryOp::NEG || op == BinaryOp::NOT); +} + +void SysYIRGenerator::compute() { + + // 先将中缀表达式转换为后缀表达式 + BinaryRPNStack.clear(); + BinaryOpStack.clear(); + + int begin = BinaryExpStack.size() - BinaryExpLenStack.back(), end = BinaryExpStack.size(); + + for (int i = begin; i < end; i++) { + auto item = BinaryExpStack[i]; + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*),直接推入后缀表达式栈 + BinaryRPNStack.push_back(item); // 直接 push_back item (ValueOrOperator类型) + } else { + // 如果是操作符 + int currentOp = std::get(item); + + if (currentOp == LPAREN) { + // 左括号直接入栈 + BinaryOpStack.push_back(currentOp); + } else if (currentOp == RPAREN) { + // 右括号:将操作符栈中的操作符弹出并添加到后缀表达式栈,直到遇到左括号 + while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { + BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int + BinaryOpStack.pop_back(); + } + if (!BinaryOpStack.empty() && BinaryOpStack.back() == LPAREN) { + BinaryOpStack.pop_back(); // 弹出左括号,但不添加到后缀表达式栈 + } else { + // 错误:不匹配的右括号 + std::cerr << "Error: Mismatched parentheses in expression." << std::endl; + return; + } + } else { + // 普通操作符 + while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) { + + int stackTopOp = BinaryOpStack.back(); + // 如果当前操作符优先级低于栈顶操作符优先级 + // 或者 (当前操作符优先级等于栈顶操作符优先级 并且 栈顶操作符是左结合) + if (getOperatorPrecedence(currentOp) < getOperatorPrecedence(stackTopOp) || + (getOperatorPrecedence(currentOp) == getOperatorPrecedence(stackTopOp) && + !isRightAssociative(stackTopOp))) { + + BinaryRPNStack.push_back(stackTopOp); + BinaryOpStack.pop_back(); + } else { + break; // 否则当前操作符入栈 + } + } + BinaryOpStack.push_back(currentOp); // 当前操作符入栈 + } + } + } + // 遍历结束后,将操作符栈中剩余的所有操作符弹出并添加到后缀表达式栈 + while (!BinaryOpStack.empty()) { + if (BinaryOpStack.back() == LPAREN) { + // 错误:不匹配的左括号 + std::cerr << "Error: Mismatched parentheses in expression (unclosed parenthesis)." << std::endl; + return; + } + BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int + BinaryOpStack.pop_back(); + } + + // 弹出BinaryExpStack的表达式 + while(begin < end) { + BinaryExpStack.pop_back(); + BinaryExpLenStack.back()--; + end--; + } + + // 计算后缀表达式 + // 每次计算前清空操作数栈 + BinaryValueStack.clear(); + + // 遍历后缀表达式栈 + Type *commonType = nullptr; + for(const auto &item : BinaryRPNStack) { + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*) 检测他的类型 + Value *value = std::get(item); + if (commonType == nullptr) { + commonType = value->getType(); + } + else if (value->getType() != commonType && value->getType()->isFloat()) { + // 如果当前值的类型与commonType不同且是float类型,则提升为float + commonType = Type::getFloatType(); + break; + } + } else { + continue; + } + } + + for (const auto &item : BinaryRPNStack) { + if (std::holds_alternative(item)) { + // 如果是操作数 (Value*),直接推入操作数栈 + BinaryValueStack.push_back(std::get(item)); + } else { + // 如果是操作符 + int op = std::get(item); + Value *resultValue = nullptr; + Value *lhs = nullptr; + Value *rhs = nullptr; + Value *operand = nullptr; + + switch (op) { + case BinaryOp::ADD: + case BinaryOp::SUB: + case BinaryOp::MUL: + case BinaryOp::DIV: + case BinaryOp::MOD: { + // 二元操作符需要两个操作数 + if (BinaryValueStack.size() < 2) { + std::cerr << "Error: Not enough operands for binary operation: " << op << std::endl; + return; // 或者抛出异常 + } + rhs = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + lhs = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + // 类型转换 + lhs = promoteType(lhs, commonType); + rhs = promoteType(rhs, commonType); + + // 尝试常量折叠 + ConstantValue *lhsConst = dynamic_cast(lhs); + ConstantValue *rhsConst = dynamic_cast(rhs); + + if (lhsConst && rhsConst) { + // 如果都是常量,直接计算结果 + if (commonType == Type::getIntType()) { + int lhsVal = lhsConst->getInt(); + int rhsVal = rhsConst->getInt(); + switch (op) { + case BinaryOp::ADD: resultValue = ConstantInteger::get(lhsVal + rhsVal); break; + case BinaryOp::SUB: resultValue = ConstantInteger::get(lhsVal - rhsVal); break; + case BinaryOp::MUL: resultValue = ConstantInteger::get(lhsVal * rhsVal); break; + case BinaryOp::DIV: + if (rhsVal == 0) { + std::cerr << "Error: Division by zero." << std::endl; + return; + } + resultValue = sysy::ConstantInteger::get(lhsVal / rhsVal); break; + case BinaryOp::MOD: + if (rhsVal == 0) { + std::cerr << "Error: Modulo by zero." << std::endl; + return; + } + resultValue = sysy::ConstantInteger::get(lhsVal % rhsVal); break; + default: + std::cerr << "Error: Unknown binary operator for constants: " << op << std::endl; + return; + } + } else if (commonType == Type::getFloatType()) { + float lhsVal = lhsConst->getFloat(); + float rhsVal = rhsConst->getFloat(); + switch (op) { + case BinaryOp::ADD: resultValue = ConstantFloating::get(lhsVal + rhsVal); break; + case BinaryOp::SUB: resultValue = ConstantFloating::get(lhsVal - rhsVal); break; + case BinaryOp::MUL: resultValue = ConstantFloating::get(lhsVal * rhsVal); break; + case BinaryOp::DIV: + if (rhsVal == 0.0f) { + std::cerr << "Error: Division by zero." << std::endl; + return; + } + resultValue = sysy::ConstantFloating::get(lhsVal / rhsVal); break; + case BinaryOp::MOD: + std::cerr << "Error: Modulo operator not supported for float types." << std::endl; + return; + default: + std::cerr << "Error: Unknown binary operator for float constants: " << op << std::endl; + return; + } + } else { + std::cerr << "Error: Unsupported type for binary constant operation." << std::endl; + return; + } + } else { + // 否则,创建相应的IR指令 + if (commonType == Type::getIntType()) { + switch (op) { + case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break; + case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break; + case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break; + case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break; + case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break; + } + } else if (commonType == Type::getFloatType()) { + switch (op) { + case BinaryOp::ADD: resultValue = builder.createFAddInst(lhs, rhs); break; + case BinaryOp::SUB: resultValue = builder.createFSubInst(lhs, rhs); break; + case BinaryOp::MUL: resultValue = builder.createFMulInst(lhs, rhs); break; + case BinaryOp::DIV: resultValue = builder.createFDivInst(lhs, rhs); break; + case BinaryOp::MOD: + std::cerr << "Error: Modulo operator not supported for float types." << std::endl; + return; + } + } else { + std::cerr << "Error: Unsupported type for binary instruction." << std::endl; + return; + } + } + break; + } + case BinaryOp::PLUS: + case BinaryOp::NEG: + case BinaryOp::NOT: { + // 一元操作符需要一个操作数 + if (BinaryValueStack.empty()) { + std::cerr << "Error: Not enough operands for unary operation: " << op << std::endl; + return; + } + operand = BinaryValueStack.back(); + BinaryValueStack.pop_back(); + + operand = promoteType(operand, commonType); + + // 尝试常量折叠 + ConstantInteger *constInt = dynamic_cast(operand); + ConstantFloating *constFloat = dynamic_cast(operand); + + if (constInt || constFloat) { + // 如果是常量,直接计算结果 + switch (op) { + case BinaryOp::PLUS: resultValue = operand; break; + case BinaryOp::NEG: { + if (constInt) { + resultValue = constInt->getNeg(); + } else if (constFloat) { + resultValue = constFloat->getNeg(); + } else { + std::cerr << "Error: Negation not supported for constant operand type." << std::endl; + return; + } + break; + } + case BinaryOp::NOT: + if (constInt) { + resultValue = sysy::ConstantInteger::get(constInt->getInt() == 0 ? 1 : 0); + } else if (constFloat) { + resultValue = sysy::ConstantInteger::get(constFloat->getFloat() == 0.0f ? 1 : 0); + } else { + std::cerr << "Error: Logical NOT not supported for constant operand type." << std::endl; + return; + } + break; + default: + std::cerr << "Error: Unknown unary operator for constants: " << op << std::endl; + return; + } + } else { + // 否则,创建相应的IR指令 + switch (op) { + case BinaryOp::PLUS: + resultValue = operand; // 一元加指令通常直接返回操作数 + break; + case BinaryOp::NEG: { + if (commonType == sysy::Type::getIntType()) { + resultValue = builder.createNegInst(operand); + } else if (commonType == sysy::Type::getFloatType()) { + resultValue = builder.createFNegInst(operand); + } else { + std::cerr << "Error: Negation not supported for operand type." << std::endl; + return; + } + break; + } + case BinaryOp::NOT: + // 逻辑非 + if (commonType == sysy::Type::getIntType()) { + resultValue = builder.createNotInst(operand); + } else if (commonType == sysy::Type::getFloatType()) { + resultValue = builder.createFNotInst(operand); + } else { + std::cerr << "Error: Logical NOT not supported for operand type." << std::endl; + return; + } + break; + default: + std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl; + return; + } + } + break; + } + default: + std::cerr << "Error: Unknown operator " << op << " encountered in RPN stack." << std::endl; + return; + } + + // 将计算结果或指令结果推入操作数栈 + if (resultValue) { + BinaryValueStack.push_back(resultValue); + } else { + std::cerr << "Error: Result value is null after processing operator " << op << "!" << std::endl; + return; + } + + } + } + + // 后缀表达式处理完毕,操作数栈的栈顶就是最终结果 + if (BinaryValueStack.empty()) { + std::cerr << "Error: No values left in BinaryValueStack after processing RPN." << std::endl; + return; + } + if (BinaryValueStack.size() > 1) { + std::cerr + << "Warning: Multiple values left in BinaryValueStack after processing RPN. Expression might be malformed." + << std::endl; + } + BinaryRPNStack.clear(); // 清空后缀表达式栈 + BinaryOpStack.clear(); // 清空操作符栈 + return; +} + +Value* SysYIRGenerator::computeExp(SysYParser::ExpContext *ctx, Type* targetType){ + if (ctx->addExp() == nullptr) { + assert(false && "ExpContext should have an addExp child!"); + } + BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 + visitAddExp(ctx->addExp()); + + if(targetType == nullptr) { + targetType = Type::getIntType(); // 默认目标类型为int + } + + compute(); + // 最后一个Value应该是最终结果 + + Value* result = BinaryValueStack.back(); + BinaryValueStack.pop_back(); // 移除结果值 + + result = promoteType(result, targetType); // 确保结果类型符合目标类型 + // 检查当前层次的操作符数量 + int ExpLen = BinaryExpLenStack.back(); + BinaryExpLenStack.pop_back(); // 离开层次时将该层次 + if (ExpLen > 0) { + std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; + return nullptr; + } + return result; +} + +Value* SysYIRGenerator::computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType){ + // 根据AddExpContext中的操作符和操作数计算加法表达式 + // 这里假设AddExpContext已经被正确填充 + if (ctx->mulExp().size() == 0) { + assert(false && "AddExpContext should have a mulExp child!"); + } + BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0 + visitMulExp(ctx->mulExp(0)); + // BinaryValueStack.push_back(result); + + for (int i = 1; i < ctx->mulExp().size(); i++) { + auto opNode = dynamic_cast(ctx->children[2*i-1]); + int opType = opNode->getSymbol()->getType(); + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in AddExp."); + } + // BinaryExpStack.push_back(opType); + visitMulExp(ctx->mulExp(i)); + // BinaryValueStack.push_back(operand); + } + if(targetType == nullptr) { + targetType = Type::getIntType(); // 默认目标类型为int + } + // 根据后缀表达式的逻辑计算 + compute(); + // 最后一个Value应该是最终结果 + + Value* result = BinaryValueStack.back(); + BinaryValueStack.pop_back(); // 移除最后一个值,因为它已经被计算 + result = promoteType(result, targetType); // 确保结果类型符合目标类型 + + int ExpLen = BinaryExpLenStack.back(); + BinaryExpLenStack.pop_back(); // 离开层次时将该层次 + if (ExpLen > 0) { + std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl; + return nullptr; + } + return result; +} + Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector& dims){ Type* currentType = baseType; // 从最内层维度开始构建 ArrayType @@ -393,7 +825,8 @@ std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { } std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { - Value* value = std::any_cast(visitExp(ctx->exp())); + // Value* value = std::any_cast(visitExp(ctx->exp())); + Value* value = computeExp(ctx->exp()); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; @@ -408,13 +841,17 @@ std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext return result; } -std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { +std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { Value* value = std::any_cast(visitConstExp(ctx->constExp())); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } +std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext *ctx){ + return computeAddExp(ctx->addExp()); +} + std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { std::vector children; for (const auto &constInitVal : ctx->constInitVal()) @@ -570,8 +1007,8 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { vector indices; if (lVal->exp().size() > 0) { // 如果有下标,访问表达式获取下标值 - for (const auto &exp : lVal->exp()) { - Value* indexValue = std::any_cast(visitExp(exp)); + for (auto &exp : lVal->exp()) { + Value* indexValue = std::any_cast(computeExp(exp)); indices.push_back(indexValue); } } @@ -610,15 +1047,18 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { LValue = getGEPAddressInst(gepBasePointer, gepIndices); } - Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 + // Value* RValue = std::any_cast(visitExp(ctx->exp())); // 右值 // 先推断 LValue 的类型 // 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型 // 如果 LValue 是标量,则直接使用其类型 // 注意:LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断 Type* LType = builder.getIndexedType(variable->getType(), indices); + + Value* RValue = computeExp(ctx->exp(), LType); // 右值计算 Type* RType = RValue->getType(); + // TODO:computeExp处理了类型转换,可以考虑删除判断逻辑 if (LType != RType) { ConstantValue *constValue = dynamic_cast(RValue); if (constValue != nullptr) { @@ -642,7 +1082,7 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { } } else { if (LType == Type::getFloatType()) { - RValue = builder.createIToFInst(RValue); + RValue = builder.createItoFInst(RValue); } else { // 假设如果不是浮点型,就是整型 RValue = builder.createFtoIInst(RValue); } @@ -655,6 +1095,14 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { } +std::any SysYIRGenerator::visitExpStmt(SysYParser::ExpStmtContext *ctx) { + // 访问表达式 + if (ctx->exp() != nullptr) { + computeExp(ctx->exp()); + } + return std::any(); +} + std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { // labels string stream @@ -822,11 +1270,11 @@ std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { Value* returnValue = nullptr; - if (ctx->exp() != nullptr) { - returnValue = std::any_cast(visitExp(ctx->exp())); - } - Type* funcType = builder.getBasicBlock()->getParent()->getReturnType(); + if (ctx->exp() != nullptr) { + returnValue = computeExp(ctx->exp(), funcType); + } + // TODOL 考虑删除类型转换判断逻辑 if (returnValue != nullptr && funcType!= returnValue->getType()) { ConstantValue * constValue = dynamic_cast(returnValue); if (constValue != nullptr) { @@ -849,7 +1297,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } else { if (funcType == Type::getFloatType()) { - returnValue = builder.createIToFInst(returnValue); + returnValue = builder.createItoFInst(returnValue); } else { returnValue = builder.createFtoIInst(returnValue); } @@ -891,7 +1339,8 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::vector dims; for (const auto &exp : ctx->exp()) { - dims.push_back(std::any_cast(visitExp(exp))); + Value* expValue = std::any_cast(computeExp(exp)); + dims.push_back(expValue); } // 1. 获取变量的声明维度数量 @@ -995,16 +1444,23 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { } std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { - if (ctx->exp() != nullptr) - return visitExp(ctx->exp()); - if (ctx->lValue() != nullptr) - return visitLValue(ctx->lValue()); - if (ctx->number() != nullptr) - return visitNumber(ctx->number()); + if (ctx->exp() != nullptr) { + BinaryExpStack.push_back(BinaryOp::LPAREN);BinaryExpLenStack.back()++; + visitExp(ctx->exp()); + BinaryExpStack.push_back(BinaryOp::RPAREN);BinaryExpLenStack.back()++; + } + + if (ctx->lValue() != nullptr) { + // 如果是 lValue,将value压入栈中 + BinaryExpStack.push_back(std::any_cast(visitLValue(ctx->lValue())));BinaryExpLenStack.back()++; + } + if (ctx->number() != nullptr) { + BinaryExpStack.push_back(std::any_cast(visitNumber(ctx->number())));BinaryExpLenStack.back()++; + } if (ctx->string() != nullptr) { cout << "String literal not supported in SysYIRGenerator." << endl; } - return visitNumber(ctx->number()); + return std::any(); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { @@ -1074,7 +1530,7 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) { args[i] = builder.createFtoIInst(args[i]); } else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) { - args[i] = builder.createIToFInst(args[i]); + args[i] = builder.createItoFInst(args[i]); } // 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO:不清楚有没有这种样例 // 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型, @@ -1099,235 +1555,78 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { } std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { - if (ctx->primaryExp() != nullptr) - return visitPrimaryExp(ctx->primaryExp()); - if (ctx->call() != nullptr) - return visitCall(ctx->call()); - - Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); - Value* result = value; - if (ctx->unaryOp()->SUB() != nullptr) { - ConstantValue * constValue = dynamic_cast(value); - if (constValue != nullptr) { - if (constValue->isFloat()) { - result = ConstantFloating::get(-constValue->getFloat()); - } else { - result = ConstantInteger::get(-constValue->getInt()); + if (ctx->primaryExp() != nullptr) { + visitPrimaryExp(ctx->primaryExp()); + } else if (ctx->call() != nullptr) { + BinaryExpStack.push_back(std::any_cast(visitCall(ctx->call())));BinaryExpLenStack.back()++; + } else if (ctx->unaryOp() != nullptr) { + // 遇到一元操作符,将其压入 BinaryExpStack + auto opNode = dynamic_cast(ctx->unaryOp()->children[0]); + int opType = opNode->getSymbol()->getType(); + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::PLUS); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::NEG); BinaryExpLenStack.back()++; break; + case SysYParser::NOT: BinaryExpStack.push_back(BinaryOp::NOT); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in UnaryExp."); } - } else if (value != nullptr) { - if (value->getType() == Type::getIntType()) { - result = builder.createNegInst(value); - } else { - result = builder.createFNegInst(value); - } - } else { - std::cout << "UnExp: value is nullptr." << std::endl; - assert(false); - } - } else if (ctx->unaryOp()->NOT() != nullptr) { - auto constValue = dynamic_cast(value); - if (constValue != nullptr) { - if (constValue->isFloat()) { - result = - ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); - } else { - result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0)); - } - } else if (value != nullptr) { - if (value->getType() == Type::getIntType()) { - result = builder.createNotInst(value); - } else { - result = builder.createFNotInst(value); - } - } else { - std::cout << "UnExp: value is nullptr." << std::endl; - assert(false); - } + visitUnaryExp(ctx->unaryExp()); } - return result; + return std::any(); } std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { std::vector params; - for (const auto &exp : ctx->exp()) - params.push_back(std::any_cast(visitExp(exp))); + for (const auto &exp : ctx->exp()) { + auto param = std::any_cast(computeExp(exp)); + params.push_back(param); + } + return params; } std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { - Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); - + visitUnaryExp(ctx->unaryExp(0)); + for (int i = 1; i < ctx->unaryExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - - Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); - - Type* resultType = result->getType(); - Type* operandType = operand->getType(); - Type* floatType = Type::getFloatType(); - - if (resultType == floatType || operandType == floatType) { - // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 - if (operandType != floatType) { - ConstantValue * constValue = dynamic_cast(operand); - if (constValue != nullptr) { - if(dynamic_cast(constValue)) { - // 如果是整型常量,转换为浮点型 - operand = ConstantFloating::get(static_cast(constValue->getInt())); - } else if (dynamic_cast(constValue)) { - // 如果是浮点型常量,直接使用 - operand = ConstantFloating::get(static_cast(constValue->getFloat())); - } - } - else - operand = builder.createIToFInst(operand); - } else if (resultType != floatType) { - ConstantValue* constResult = dynamic_cast(result); - if (constResult != nullptr) { - if(dynamic_cast(constResult)) { - // 如果是整型常量,转换为浮点型 - result = ConstantFloating::get(static_cast(constResult->getInt())); - } else if (dynamic_cast(constResult)) { - // 如果是浮点型常量,直接使用 - result = ConstantFloating::get(static_cast(constResult->getFloat())); - } - } - else - result = builder.createIToFInst(result); - } - - ConstantFloating* constResult = dynamic_cast(result); - ConstantFloating* constOperand = dynamic_cast(operand); - if (opType == SysYParser::MUL) { - if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantFloating::get(constResult->getFloat() * - constOperand->getFloat()); - } else { - result = builder.createFMulInst(result, operand); - } - } else if (opType == SysYParser::DIV) { - if ((constOperand != nullptr) && (constResult != nullptr)) { - result = ConstantFloating::get(constResult->getFloat() / - constOperand->getFloat()); - } else { - result = builder.createFDivInst(result, operand); - } - } else { - // float类型的取模操作不允许 - std::cout << "MulExp: float type mod operation is not allowed." << std::endl; - assert(false); - } - } else { - ConstantInteger *constResult = dynamic_cast(result); - ConstantInteger *constOperand = dynamic_cast(operand); - if (opType == SysYParser::MUL) { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() * constOperand->getInt()); - else - result = builder.createMulInst(result, operand); - } else if (opType == SysYParser::DIV) { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() / constOperand->getInt()); - else - result = builder.createDivInst(result, operand); - } else { - if ((constOperand != nullptr) && (constResult != nullptr)) - result = ConstantInteger::get(constResult->getInt() % constOperand->getInt()); - else - result = builder.createRemInst(result, operand); - } + switch(opType) { + case SysYParser::MUL: BinaryExpStack.push_back(BinaryOp::MUL); BinaryExpLenStack.back()++; break; + case SysYParser::DIV: BinaryExpStack.push_back(BinaryOp::DIV); BinaryExpLenStack.back()++; break; + case SysYParser::MOD: BinaryExpStack.push_back(BinaryOp::MOD); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in MulExp."); } + visitUnaryExp(ctx->unaryExp(i)); } - - return result; + return std::any(); } std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { - Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); + visitMulExp(ctx->mulExp(0)); for (int i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - - Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); - Type* resultType = result->getType(); - Type* operandType = operand->getType(); - Type* floatType = Type::getFloatType(); - - if (resultType == floatType || operandType == floatType) { - // 类型转换 - if (operandType != floatType) { - ConstantValue * constOperand = dynamic_cast(operand); - if (constOperand != nullptr) { - if(dynamic_cast(constOperand)) { - // 如果是整型常量,转换为浮点型 - operand = ConstantFloating::get(static_cast(constOperand->getInt())); - } else if (dynamic_cast(constOperand)) { - // 如果是浮点型常量,直接使用 - operand = ConstantFloating::get(static_cast(constOperand->getFloat())); - } - } - else - operand = builder.createIToFInst(operand); - } else if (resultType != floatType) { - ConstantValue * constResult = dynamic_cast(result); - if (constResult != nullptr) { - if(dynamic_cast(constResult)) { - // 如果是整型常量,转换为浮点型 - result = ConstantFloating::get(static_cast(constResult->getInt())); - } else if (dynamic_cast(constResult)) { - // 如果是浮点型常量,直接使用 - result = ConstantFloating::get(static_cast(constResult->getFloat())); - } - } - else - result = builder.createIToFInst(result); - } - - ConstantFloating *constResult = dynamic_cast(result); - ConstantFloating *constOperand = dynamic_cast(operand); - if (opType == SysYParser::ADD) { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat()); - else - result = builder.createFAddInst(result, operand); - } else { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat()); - else - result = builder.createFSubInst(result, operand); - } - } else { - ConstantInteger *constResult = dynamic_cast(result); - ConstantInteger *constOperand = dynamic_cast(operand); - if (opType == SysYParser::ADD) { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantInteger::get(constResult->getInt() + constOperand->getInt()); - else - result = builder.createAddInst(result, operand); - } else { - if ((constResult != nullptr) && (constOperand != nullptr)) - result = ConstantInteger::get(constResult->getInt() - constOperand->getInt()); - else - result = builder.createSubInst(result, operand); - } + switch(opType) { + case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break; + case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break; + default: assert(false && "Unexpected operator in AddExp."); } + visitMulExp(ctx->mulExp(i)); } - - return result; + return std::any(); } std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { - Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); + Value* result = computeAddExp(ctx->addExp(0), Type::getIntType()); for (int i = 1; i < ctx->addExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); - Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); + Value* operand = computeAddExp(ctx->addExp(i), Type::getIntType()); Type* resultType = result->getType(); Type* operandType = operand->getType(); @@ -1366,7 +1665,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } } else - result = builder.createIToFInst(result); + result = builder.createItoFInst(result); } if (operandType != floatType) { @@ -1380,7 +1679,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { } } else - operand = builder.createIToFInst(operand); + operand = builder.createItoFInst(operand); } @@ -1407,6 +1706,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { + // TODO:其实已经保证了result是一个int类型的值可以删除冗余判断逻辑 Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); for (int i = 1; i < ctx->relExp().size(); i++) { @@ -1445,7 +1745,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { } } else - result = builder.createIToFInst(result); + result = builder.createItoFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) { @@ -1458,7 +1758,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { } } else - operand = builder.createIToFInst(operand); + operand = builder.createItoFInst(operand); } if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); @@ -1567,7 +1867,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root, assert(false && "Unknown constant type for float conversion."); } else - result.push_back(builder->createIToFInst(value)); + result.push_back(builder->createItoFInst(value)); } else { ConstantValue* constValue = dynamic_cast(value); diff --git a/src/midend/SysYIRPrinter.cpp b/src/midend/SysYIRPrinter.cpp index e2c5e17..0a024d2 100644 --- a/src/midend/SysYIRPrinter.cpp +++ b/src/midend/SysYIRPrinter.cpp @@ -1,7 +1,10 @@ #include "SysYIRPrinter.h" #include #include +#include #include +#include +#include #include #include "IR.h" // 确保IR.h包含了ArrayType、GetElementPtrInst等的定义 @@ -61,16 +64,21 @@ std::string SysYPrinter::getValueName(Value *value) { } else if (auto constInt = dynamic_cast(value)) { // 优先匹配具体的常量类型 return std::to_string(constInt->getInt()); } else if (auto constFloat = dynamic_cast(value)) { // 优先匹配具体的常量类型 - return std::to_string(constFloat->getFloat()); + std::ostringstream oss; + oss << std::scientific << std::setprecision(std::numeric_limits::max_digits10) << constFloat->getFloat(); + return oss.str(); } else if (auto constUndef = dynamic_cast(value)) { // 如果有Undef类型 return "undef"; } else if (auto constVal = dynamic_cast(value)) { // fallback for generic ConstantValue // 这里的逻辑可能需要根据你ConstantValue的实际设计调整 // 确保它能处理所有可能的ConstantValue - if (constVal->getType()->isFloat()) { - return std::to_string(constVal->getFloat()); + if (auto constInt = dynamic_cast(value)) { // 优先匹配具体的常量类型 + return std::to_string(constInt->getInt()); + } else if (auto constFloat = dynamic_cast(value)) { // 优先匹配具体的常量类型 + std::ostringstream oss; + oss << std::scientific << std::setprecision(std::numeric_limits::max_digits10) << constFloat->getFloat(); + return oss.str(); } - return std::to_string(constVal->getInt()); } else if (auto constVar = dynamic_cast(value)) { return constVar->getName(); // 假设ConstantVariable有自己的名字或通过getByIndices获取值 } else if (auto argVar = dynamic_cast(value)) {