From ede6465e8caf1497fb5975646d400b2d663f9158 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Tue, 15 Jul 2025 12:53:03 +0800 Subject: [PATCH 1/2] =?UTF-8?q?[IR]:=E5=A2=9E=E5=8A=A0=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0ret=E6=8C=87=E4=BB=A4=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/LLVMIRGenerator.cpp | 674 -------------------------- src/LLVMIRGenerator_1.cpp | 859 ---------------------------------- src/SysYIRGenerator.cpp | 14 + src/SysYIROptPre.cpp | 7 +- src/include/LLVMIRGenerator.h | 78 --- src/include/SysYIRGenerator.h | 2 + 6 files changed, 21 insertions(+), 1613 deletions(-) delete mode 100644 src/LLVMIRGenerator.cpp delete mode 100644 src/LLVMIRGenerator_1.cpp delete mode 100644 src/include/LLVMIRGenerator.h diff --git a/src/LLVMIRGenerator.cpp b/src/LLVMIRGenerator.cpp deleted file mode 100644 index 0d42ce5..0000000 --- a/src/LLVMIRGenerator.cpp +++ /dev/null @@ -1,674 +0,0 @@ -// LLVMIRGenerator.cpp -// TODO:类型转换及其检查 -// TODO:sysy库函数处理 -// TODO:数组处理 -// TODO:对while、continue、break的测试 -#include "LLVMIRGenerator.h" -#include -using namespace std; -namespace sysy { -std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { - // 初始化自定义IR数据结构 - irModule = std::make_unique(); - irBuilder = sysy::IRBuilder(); // 初始化IR构建器 - tempCounter = 0; - symbolTable.clear(); - tmpTable.clear(); - globalVars.clear(); - inFunction = false; - - visitCompUnit(unit); - return irStream.str(); -} - -std::string LLVMIRGenerator::getNextTemp() { - std::string ret = "%." + std::to_string(tempCounter++); - tmpTable[ret] = "void"; - return ret; -} - -std::string LLVMIRGenerator::getLLVMType(const std::string& type) { - if (type == "int") return "i32"; - if (type == "float") return "float"; - if (type.find("[]") != std::string::npos) - return getLLVMType(type.substr(0, type.size()-2)) + "*"; - return "i32"; -} - -sysy::Type* LLVMIRGenerator::getSysYType(const std::string& typeStr) { - if (typeStr == "int") return sysy::Type::getIntType(); - if (typeStr == "float") return sysy::Type::getFloatType(); - if (typeStr == "void") return sysy::Type::getVoidType(); - // 处理指针类型等 - return sysy::Type::getIntType(); -} - -std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { - auto type_i32 = Type::getIntType(); - auto type_f32 = Type::getFloatType(); - auto type_void = Type::getVoidType(); - auto type_i32p = Type::getPointerType(type_i32); - auto type_f32p = Type::getPointerType(type_f32); - - // 创建运行时库函数 - irModule->createFunction("getint", sysy::FunctionType::get(type_i32, {})); - irModule->createFunction("getch", sysy::FunctionType::get(type_i32, {})); - irModule->createFunction("getfloat", sysy::FunctionType::get(type_f32, {})); - //TODO: 添加更多运行时库函数 - irStream << "declare i32 @getint()\n"; - irStream << "declare i32 @getch()\n"; - irStream << "declare float @getfloat()\n"; - //TODO: 添加更多运行时库函数的文本IR - - for (auto decl : ctx->decl()) { - decl->accept(this); - } - for (auto funcDef : ctx->funcDef()) { - inFunction = true; // 进入函数定义 - funcDef->accept(this); - inFunction = false; // 离开函数定义 - } - return nullptr; -} - -std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { - // TODO:数组初始化 - std::string type = ctx->bType()->getText(); - currentVarType = getLLVMType(type); - - for (auto varDef : ctx->varDef()) { - if (!inFunction) { - // 全局变量声明 - std::string varName = varDef->Ident()->getText(); - std::string llvmType = getLLVMType(type); - std::string value = "0"; // 默认值为 0 - - if (varDef->ASSIGN()) { - value = std::any_cast(varDef->initVal()->accept(this)); - } else { - std::cout << "[WR-Release-01]Warning: Global variable '" << varName - << "' is declared without initialization, defaulting to 0.\n"; - } - irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n"; - globalVars.push_back(varName); // 记录全局变量 - } else { - // 局部变量声明 - varDef->accept(this); - } - } - return nullptr; -} - -std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { - // TODO:数组初始化 - std::string type = ctx->bType()->getText(); - for (auto constDef : ctx->constDef()) { - if (!inFunction) { - // 全局常量声明 - std::string varName = constDef->Ident()->getText(); - std::string llvmType = getLLVMType(type); - std::string value = "0"; // 默认值为 0 - - try { - value = std::any_cast(constDef->constInitVal()->accept(this)); - } catch (...) { - throw std::runtime_error("[ERR-Release-01]Const value must be initialized upon definition."); - } - // 如果是 float 类型,转换为十六进制表示 - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; - value = ss.str(); - } catch (...) { - throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value); - } - } - - irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n"; - globalVars.push_back(varName); // 记录全局变量 - } else { - // 局部常量声明 - std::string varName = constDef->Ident()->getText(); - std::string llvmType = getLLVMType(type); - std::string allocaName = getNextTemp(); - std::string value = "0"; // 默认值为 0 - - try { - value = std::any_cast(constDef->constInitVal()->accept(this)); - } catch (...) { - throw std::runtime_error("Const value must be initialized upon definition."); - } - - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; - value = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + value); - } - } - irStream << " store " << llvmType << " " << value << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - - symbolTable[varName] = {allocaName, llvmType}; - tmpTable[allocaName] = llvmType; - } - } - return nullptr; -} - -std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { - // TODO:数组初始化 - std::string varName = ctx->Ident()->getText(); - std::string type = currentVarType; - std::string llvmType = getLLVMType(type); - std::string allocaName = getNextTemp(); - - - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - - if (ctx->ASSIGN()) { - std::string value = std::any_cast(ctx->initVal()->accept(this)); - - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); - value = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + value); - } - } - irStream << " store " << llvmType << " " << value << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - } - symbolTable[varName] = {allocaName, llvmType}; - tmpTable[allocaName] = llvmType; - return nullptr; -} - -std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { - currentFunction = ctx->Ident()->getText(); - currentReturnType = getLLVMType(ctx->funcType()->getText()); - symbolTable.clear(); - tmpTable.clear(); - tempCounter = 0; - hasReturn = false; - - irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; - if (ctx->funcFParams()) { - auto params = ctx->funcFParams()->funcFParam(); - tempCounter += params.size(); - for (size_t i = 0; i < params.size(); ++i) { - if (i > 0) irStream << ", "; - std::string paramType = getLLVMType(params[i]->bType()->getText()); - irStream << paramType << " noundef %" << i; - symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType}; - tmpTable["%" + std::to_string(i)] = paramType; - } - } - tempCounter++; - irStream << ") #0 {\n"; - - if (ctx->funcFParams()) { - auto params = ctx->funcFParams()->funcFParam(); - for (size_t i = 0; i < params.size(); ++i) { - std::string varName = params[i]->Ident()->getText(); - std::string type = params[i]->bType()->getText(); - std::string llvmType = getLLVMType(type); - std::string allocaName = getNextTemp(); - tmpTable[allocaName] = llvmType; - - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - irStream << " store " << llvmType << " " << symbolTable[varName].first << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - - symbolTable[varName] = {allocaName, llvmType}; - } - } - ctx->blockStmt()->accept(this); - - if (!hasReturn) { - if (currentReturnType == "void") { - irStream << " ret void\n"; - } else { - irStream << " ret " << currentReturnType << " 0\n"; - } - } - irStream << "}\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - for (auto item : ctx->blockItem()) { - item->accept(this); - } - return nullptr; -} -std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) -{ - std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); - std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; - std::string rhs = std::any_cast(ctx->exp()->accept(this)); - - if (lhsType == "float") { - try { - double floatValue = std::stod(rhs); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); - rhs = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + rhs); - } - } - - irStream << " store " << lhsType << " " << rhs << ", " << lhsType - << "* " << lhsAlloca << ", align 4\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) -{ - std::string cond = std::any_cast(ctx->cond()->accept(this)); - std::string trueLabel = "if.then." + std::to_string(tempCounter); - std::string falseLabel = "if.else." + std::to_string(tempCounter); - std::string mergeLabel = "if.end." + std::to_string(tempCounter++); - - irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; - - irStream << trueLabel << ":\n"; - ctx->stmt(0)->accept(this); - irStream << " br label %" << mergeLabel << "\n"; - - irStream << falseLabel << ":\n"; - if (ctx->ELSE()) { - ctx->stmt(1)->accept(this); - } - irStream << " br label %" << mergeLabel << "\n"; - - irStream << mergeLabel << ":\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) -{ - std::string loop_cond = "while.cond." + std::to_string(tempCounter); - std::string loop_body = "while.body." + std::to_string(tempCounter); - std::string loop_end = "while.end." + std::to_string(tempCounter++); - - loopStack.push({loop_end, loop_cond}); - irStream << " br label %" << loop_cond << "\n"; - irStream << loop_cond << ":\n"; - - std::string cond = std::any_cast(ctx->cond()->accept(this)); - irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; - irStream << loop_body << ":\n"; - ctx->stmt()->accept(this); - irStream << " br label %" << loop_cond << "\n"; - irStream << loop_end << ":\n"; - - loopStack.pop(); - return nullptr; -} - -std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) -{ - if (loopStack.empty()) { - throw std::runtime_error("Break statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().breakLabel << "\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) -{ - if (loopStack.empty()) { - throw std::runtime_error("Continue statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().continueLabel << "\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) -{ - hasReturn = true; - if (ctx->exp()) { - std::string value = std::any_cast(ctx->exp()->accept(this)); - irStream << " ret " << currentReturnType << " " << value << "\n"; - } else { - irStream << " ret void\n"; - } - return nullptr; -} - -// std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { -// if (ctx->lValue() && ctx->ASSIGN()) { -// std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); -// std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; -// std::string rhs = std::any_cast(ctx->exp()->accept(this)); -// if (lhsType == "float") { -// try { -// double floatValue = std::stod(rhs); -// uint64_t hexValue = reinterpret_cast(floatValue); -// std::stringstream ss; -// ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); -// rhs = ss.str(); -// } catch (...) { -// throw std::runtime_error("Invalid float literal: " + rhs); -// } -// } -// irStream << " store " << lhsType << " " << rhs << ", " << lhsType -// << "* " << lhsAlloca << ", align 4\n"; -// } else if (ctx->RETURN()) { -// hasReturn = true; -// if (ctx->exp()) { -// std::string value = std::any_cast(ctx->exp()->accept(this)); -// irStream << " ret " << currentReturnType << " " << value << "\n"; -// } else { -// irStream << " ret void\n"; -// } -// } else if (ctx->IF()) { -// std::string cond = std::any_cast(ctx->cond()->accept(this)); -// std::string trueLabel = "if.then." + std::to_string(tempCounter); -// std::string falseLabel = "if.else." + std::to_string(tempCounter); -// std::string mergeLabel = "if.end." + std::to_string(tempCounter++); - -// irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; - -// irStream << trueLabel << ":\n"; -// ctx->stmt(0)->accept(this); -// irStream << " br label %" << mergeLabel << "\n"; - -// irStream << falseLabel << ":\n"; -// if (ctx->ELSE()) { -// ctx->stmt(1)->accept(this); -// } -// irStream << " br label %" << mergeLabel << "\n"; - -// irStream << mergeLabel << ":\n"; -// } else if (ctx->WHILE()) { -// std::string loop_cond = "while.cond." + std::to_string(tempCounter); -// std::string loop_body = "while.body." + std::to_string(tempCounter); -// std::string loop_end = "while.end." + std::to_string(tempCounter++); - -// loopStack.push({loop_end, loop_cond}); -// irStream << " br label %" << loop_cond << "\n"; -// irStream << loop_cond << ":\n"; - -// std::string cond = std::any_cast(ctx->cond()->accept(this)); -// irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; -// irStream << loop_body << ":\n"; -// ctx->stmt(0)->accept(this); -// irStream << " br label %" << loop_cond << "\n"; -// irStream << loop_end << ":\n"; - -// loopStack.pop(); - -// } else if (ctx->BREAK()) { -// if (loopStack.empty()) { -// throw std::runtime_error("Break statement outside of a loop."); -// } -// irStream << " br label %" << loopStack.top().breakLabel << "\n"; -// } else if (ctx->CONTINUE()) { -// if (loopStack.empty()) { -// throw std::runtime_error("Continue statement outside of a loop."); -// } -// irStream << " br label %" << loopStack.top().continueLabel << "\n"; -// } else if (ctx->blockStmt()) { -// ctx->blockStmt()->accept(this); -// } else if (ctx->exp()) { -// ctx->exp()->accept(this); -// } -// return nullptr; -// } - -std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { - std::string varName = ctx->Ident()->getText(); - return symbolTable[varName].first; -} - -// std::any LLVMIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { -// if (ctx->lValue()) { -// std::string allocaPtr = std::any_cast(ctx->lValue()->accept(this)); -// std::string varName = ctx->lValue()->Ident()->getText(); -// std::string type = symbolTable[varName].second; -// std::string temp = getNextTemp(); -// irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; -// tmpTable[temp] = type; -// return temp; -// } else if (ctx->exp()) { -// return ctx->exp()->accept(this); -// } else { -// return ctx->number()->accept(this); -// } -// } - - -std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext *ctx){ - // irStream << "visitPrimExp\n"; - // std::cout << "Type name: " << typeid(*(ctx->primaryExp())).name() << std::endl; - SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp(); - if (auto* lvalCtx = dynamic_cast(pExpCtx)) { - std::string allocaPtr = std::any_cast(lvalCtx->lValue()->accept(this)); - std::string varName = lvalCtx->lValue()->Ident()->getText(); - std::string type = symbolTable[varName].second; - std::string temp = getNextTemp(); - irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; - tmpTable[temp] = type; - return temp; - } else if (auto* expCtx = dynamic_cast(pExpCtx)) { - return expCtx->exp()->accept(this); - } else if (auto* strCtx = dynamic_cast(pExpCtx)) { - return strCtx->string()->accept(this); - } else if (auto* numCtx = dynamic_cast(pExpCtx)) { - return numCtx->number()->accept(this); - } else { - // 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型 - // 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext - std::cout << "Unknown primary expression type." << std::endl; - throw std::runtime_error("Unknown primary expression type."); - } - // return visitChildren(ctx); -} - -std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) { - return ctx->exp()->accept(this); -} - -std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { - if (ctx->ILITERAL()) { - return ctx->ILITERAL()->getText(); - } else if (ctx->FLITERAL()) { - return ctx->FLITERAL()->getText(); - } - return ""; -} - -std::any LLVMIRGenerator::visitString(SysYParser::StringContext *ctx) -{ - if (ctx->STRING()) { - // 处理字符串常量 - std::string str = ctx->STRING()->getText(); - // 去掉引号 - str = str.substr(1, str.size() - 2); - // 转义处理 - std::string escapedStr; - for (char c : str) { - if (c == '\\') { - escapedStr += "\\\\"; - } else if (c == '"') { - escapedStr += "\\\""; - } else { - escapedStr += c; - } - } - return "\"" + escapedStr + "\""; - } - return ctx->STRING()->getText(); -} - -std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) { - if (ctx->unaryOp()) { - std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); - std::string op = ctx->unaryOp()->getText(); - std::string temp = getNextTemp(); - std::string type = operand.substr(0, operand.find(' ')); - tmpTable[temp] = type; - if (op == "-") { - irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; - } else if (op == "!") { - irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; - } - return temp; - } - return ctx->unaryExp()->accept(this); -} - -std::any LLVMIRGenerator::visitCall(SysYParser::CallContext *ctx) -{ - std::string funcName = ctx->Ident()->getText(); - std::vector args; - if (ctx->funcRParams()) { - for (auto argCtx : ctx->funcRParams()->exp()) { - args.push_back(std::any_cast(argCtx->accept(this))); - } - } - std::string temp = getNextTemp(); - std::string argList = ""; - for (size_t i = 0; i < args.size(); ++i) { - if (i > 0) argList += ", "; - argList +=tmpTable[args[i]] + " noundef " + args[i]; - } - irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; - tmpTable[temp] = currentReturnType; - return temp; -} - -std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { - auto unaryExps = ctx->unaryExp(); - std::string left = std::any_cast(unaryExps[0]->accept(this)); - for (size_t i = 1; i < unaryExps.size(); ++i) { - std::string right = std::any_cast(unaryExps[i]->accept(this)); - std::string op = ctx->children[2*i-1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - if (op == "*") { - irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; - } else if (op == "/") { - irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; - } else if (op == "%") { - irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; - } - left = temp; - tmpTable[temp] = type; - } - return left; -} - -std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { - auto mulExps = ctx->mulExp(); - std::string left = std::any_cast(mulExps[0]->accept(this)); - for (size_t i = 1; i < mulExps.size(); ++i) { - std::string right = std::any_cast(mulExps[i]->accept(this)); - std::string op = ctx->children[2*i-1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - if (op == "+") { - irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; - } else if (op == "-") { - irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; - } - left = temp; - tmpTable[temp] = type; - } - return left; -} - -std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { - auto addExps = ctx->addExp(); - std::string left = std::any_cast(addExps[0]->accept(this)); - for (size_t i = 1; i < addExps.size(); ++i) { - std::string right = std::any_cast(addExps[i]->accept(this)); - std::string op = ctx->children[2*i-1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - if (op == "<") { - irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; - } else if (op == ">") { - irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; - } else if (op == "<=") { - irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; - } else if (op == ">=") { - irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; - } - left = temp; - } - return left; -} - -std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { - auto relExps = ctx->relExp(); - std::string left = std::any_cast(relExps[0]->accept(this)); - for (size_t i = 1; i < relExps.size(); ++i) { - std::string right = std::any_cast(relExps[i]->accept(this)); - std::string op = ctx->children[2*i-1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - if (op == "==") { - irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; - } else if (op == "!=") { - irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; - } - left = temp; - } - return left; -} - -std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { - auto eqExps = ctx->eqExp(); - std::string left = std::any_cast(eqExps[0]->accept(this)); - for (size_t i = 1; i < eqExps.size(); ++i) { - std::string falseLabel = "land.false." + std::to_string(tempCounter); - std::string endLabel = "land.end." + std::to_string(tempCounter++); - std::string temp = getNextTemp(); - - irStream << " br label %" << falseLabel << "\n"; - irStream << falseLabel << ":\n"; - std::string right = std::any_cast(eqExps[i]->accept(this)); - irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; - irStream << " br label %" << endLabel << "\n"; - irStream << endLabel << ":\n"; - left = temp; - } - return left; -} - -std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { - auto lAndExps = ctx->lAndExp(); - std::string left = std::any_cast(lAndExps[0]->accept(this)); - for (size_t i = 1; i < lAndExps.size(); ++i) { - std::string trueLabel = "lor.true." + std::to_string(tempCounter); - std::string endLabel = "lor.end." + std::to_string(tempCounter++); - std::string temp = getNextTemp(); - - irStream << " br label %" << trueLabel << "\n"; - irStream << trueLabel << ":\n"; - std::string right = std::any_cast(lAndExps[i]->accept(this)); - irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; - irStream << " br label %" << endLabel << "\n"; - irStream << endLabel << ":\n"; - left = temp; - } - return left; -} -} \ No newline at end of file diff --git a/src/LLVMIRGenerator_1.cpp b/src/LLVMIRGenerator_1.cpp deleted file mode 100644 index 515b5a2..0000000 --- a/src/LLVMIRGenerator_1.cpp +++ /dev/null @@ -1,859 +0,0 @@ -// LLVMIRGenerator.cpp -// TODO:类型转换及其检查 -// TODO:sysy库函数处理 -// TODO:数组处理 -// TODO:对while、continue、break的测试 -#include "LLVMIRGenerator_1.h" -#include -#include -#include - -// namespace sysy { - -std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { - // 初始化 SysY IR 模块 - module = std::make_unique(); - // 清空符号表和临时变量表 - symbolTable.clear(); - tmpTable.clear(); - irSymbolTable.clear(); - irTmpTable.clear(); - tempCounter = 0; - globalVars.clear(); - hasReturn = false; - loopStack = std::stack(); - inFunction = false; - - // 访问编译单元 - visitCompUnit(unit); - return irStream.str(); -} - -std::string LLVMIRGenerator::getNextTemp() { - std::string ret = "%." + std::to_string(tempCounter++); - tmpTable[ret] = "void"; - return ret; -} - -std::string LLVMIRGenerator::getIRTempName() { - return "%" + std::to_string(tempCounter++); -} - -std::string LLVMIRGenerator::getLLVMType(const std::string& type) { - if (type == "int") return "i32"; - if (type == "float") return "float"; - if (type.find("[]") != std::string::npos) - return getLLVMType(type.substr(0, type.size() - 2)) + "*"; - return "i32"; -} - -sysy::Type* LLVMIRGenerator::getIRType(const std::string& type) { - if (type == "int") return sysy::Type::getIntType(); - if (type == "float") return sysy::Type::getFloatType(); - if (type == "void") return sysy::Type::getVoidType(); - if (type.find("[]") != std::string::npos) { - std::string baseType = type.substr(0, type.size() - 2); - return sysy::Type::getPointerType(getIRType(baseType)); - } - return sysy::Type::getIntType(); // 默认 int -} - -void LLVMIRGenerator::setIRPosition(sysy::BasicBlock* block) { - currentIRBlock = block; -} - -std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { - for (auto decl : ctx->decl()) { - decl->accept(this); - } - for (auto funcDef : ctx->funcDef()) { - inFunction = true; - funcDef->accept(this); - inFunction = false; - } - return nullptr; -} - - -std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { - // TODO:数组初始化 - std::string type = ctx->bType()->getText(); - currentVarType = getLLVMType(type); - sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); - - for (auto varDef : ctx->varDef()) { - if (!inFunction) { - // 全局变量(文本 IR) - std::string varName = varDef->Ident()->getText(); - std::string llvmType = getLLVMType(type); - std::string value = "0"; - sysy::Value* initValue = nullptr; - - if (varDef->ASSIGN()) { - value = std::any_cast(varDef->initVal()->accept(this)); - if (irTmpTable.find(value) != irTmpTable.end() && isa(irTmpTable[value])) { - initValue = irTmpTable[value]; - } - } - - if (llvmType == "float" && initValue) { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; - value = ss.str(); - } catch (...) { - throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value); - } - } - irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n"; - globalVars.push_back(varName); - - // 全局变量(SysY IR) - auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); - irSymbolTable[varName] = globalValue; - } else { - varDef->accept(this); - } - } - return nullptr; -} - -std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { - // TODO:数组初始化 - std::string type = ctx->bType()->getText(); - currentVarType = getLLVMType(type); - sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); // 全局变量为指针类型 - - for (auto constDef : ctx->constDef()) { - std::string varName = constDef->Ident()->getText(); - std::string llvmType = getLLVMType(type); - std::string value = "0"; - sysy::Value* initValue = nullptr; - - try { - value = std::any_cast(constDef->constInitVal()->accept(this)); - if (isa(irTmpTable[value])) { - initValue = irTmpTable[value]; - } - } catch (...) { - throw std::runtime_error("Const value must be initialized upon definition."); - } - - if (!inFunction) { - // 全局常量(文本 IR) - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; - value = ss.str(); - } catch (...) { - throw std::runtime_error("[ERR-Release-03]Invalid float literal: " + value); - } - } - irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n"; - globalVars.push_back(varName); - - // 全局常量(SysY IR) - auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); - irSymbolTable[varName] = globalValue; - } else { - // 局部常量(文本 IR) - std::string allocaName = getNextTemp(); - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; - value = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + value); - } - } - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - irStream << " store " << llvmType << " " << value << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - symbolTable[varName] = {allocaName, llvmType}; - tmpTable[allocaName] = llvmType; - - // 局部常量(SysY IR)TODO:这里可能有bug,AI在犯蠢 - sysy::IRBuilder builder(currentIRBlock); - auto allocaInst = builder.createAllocaInst(irType, {}, varName); - builder.createStoreInst(initValue, allocaInst); - irSymbolTable[varName] = allocaInst; - irTmpTable[allocaName] = allocaInst; - } - } - return nullptr; -} - -std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { - // TODO:数组初始化 - std::string varName = ctx->Ident()->getText(); - std::string llvmType = currentVarType; - sysy::Type* irType = sysy::Type::getPointerType(getIRType(currentVarType == "i32" ? "int" : "float")); - std::string allocaName = getNextTemp(); - - // 局部变量(文本 IR) - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - - // 局部变量(SysY IR) - sysy::IRBuilder builder(currentIRBlock); - auto allocaInst = builder.createAllocaInst(irType, {}, varName); - sysy::Value* initValue = nullptr; - - if (ctx->ASSIGN()) { - std::string value = std::any_cast(ctx->initVal()->accept(this)); - if (llvmType == "float") { - try { - double floatValue = std::stod(value); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); - value = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + value); - } - } - irStream << " store " << llvmType << " " << value << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - - if (irTmpTable.find(value) != irTmpTable.end()) { - initValue = irTmpTable[value]; - } - builder.createStoreInst(initValue, allocaInst); - } - - symbolTable[varName] = {allocaName, llvmType}; - tmpTable[allocaName] = llvmType; - irSymbolTable[varName] = allocaInst;//TODO:这里没看懂在干嘛 - irTmpTable[allocaName] = allocaInst;//TODO:这里没看懂在干嘛 - builder.createStoreInst(initValue, allocaInst);//TODO:这里没看懂在干嘛 - return nullptr; -} - -std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { - currentFunction = ctx->Ident()->getText(); - currentReturnType = getLLVMType(ctx->funcType()->getText()); - sysy::Type* irReturnType = getIRType(ctx->funcType()->getText()); - std::vector paramTypes; - - // 清空符号表 - symbolTable.clear(); - tmpTable.clear(); - irSymbolTable.clear(); - irTmpTable.clear(); - tempCounter = 0; - hasReturn = false; - - // 处理函数参数(文本 IR 和 SysY IR) - if (ctx->funcFParams()) { - auto params = ctx->funcFParams()->funcFParam(); - for (size_t i = 0; i < params.size(); ++i) { - std::string paramType = getLLVMType(params[i]->bType()->getText()); - if (i > 0) irStream << ", "; - irStream << paramType << " noundef %" << i; - symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType}; - tmpTable["%" + std::to_string(i)] = paramType; - paramTypes.push_back(getIRType(params[i]->bType()->getText())); - } - tempCounter += params.size(); - } - tempCounter++; - - // 文本 IR 函数定义 - irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; - irStream << ") #0 {\n"; - - // SysY IR 函数定义 - sysy::Type* funcType = sysy::Type::getFunctionType(irReturnType, paramTypes); - currentIRFunction = module->createFunction(currentFunction, funcType); - setIRPosition(currentIRFunction->getEntryBlock()); - - // 处理函数参数分配 - if (ctx->funcFParams()) { - auto params = ctx->funcFParams()->funcFParam(); - for (size_t i = 0; i < params.size(); ++i) { - std::string varName = params[i]->Ident()->getText(); - std::string llvmType = getLLVMType(params[i]->bType()->getText()); - sysy::Type* irType = getIRType(params[i]->bType()->getText()); - std::string allocaName = getNextTemp(); - tmpTable[allocaName] = llvmType; - - // 文本 IR 分配 - irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; - irStream << " store " << llvmType << " %" << i << ", " << llvmType - << "* " << allocaName << ", align 4\n"; - - // SysY IR 分配 - sysy::IRBuilder builder(currentIRBlock); - auto arg = currentIRBlock->createArgument(irType, varName); - auto allocaInst = builder.createAllocaInst(sysy::Type::getPointerType(irType), {}, varName); - builder.createStoreInst(arg, allocaInst); - symbolTable[varName] = {allocaName, llvmType}; - irSymbolTable[varName] = allocaInst; - irTmpTable[allocaName] = allocaInst; - } - } - - ctx->blockStmt()->accept(this); - - if (!hasReturn) { - if (currentReturnType == "void") { - irStream << " ret void\n"; - sysy::IRBuilder builder(currentIRBlock); - builder.createReturnInst(); - } else { - irStream << " ret " << currentReturnType << " 0\n"; - sysy::IRBuilder builder(currentIRBlock); - builder.createReturnInst(sysy::ConstantValue::get(0)); - } - } - irStream << "}\n"; - currentIRFunction = nullptr; - currentIRBlock = nullptr; - return nullptr; -} - -std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { - for (auto item : ctx->blockItem()) { - item->accept(this); - } - return nullptr; -} - -std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { - std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); - std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; - std::string rhs = std::any_cast(ctx->exp()->accept(this)); - sysy::Value* rhsValue = irTmpTable[rhs]; - - // 文本 IR - if (lhsType == "float") { - try { - double floatValue = std::stod(rhs); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); - rhs = ss.str(); - } catch (...) { - // 如果 rhs 不是字面量,假设已正确处理 - throw std::runtime_error("Invalid float literal: " + rhs); - } - } - irStream << " store " << lhsType << " " << rhs << ", " << lhsType - << "* " << lhsAlloca << ", align 4\n"; - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - builder.createStoreInst(rhsValue, irSymbolTable[ctx->lValue()->Ident()->getText()]); - return nullptr; -} - -std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext* ctx) { - std::string cond = std::any_cast(ctx->cond()->accept(this)); - sysy::Value* condValue = irTmpTable[cond]; - std::string trueLabel = "if.then." + std::to_string(tempCounter); - std::string falseLabel = "if.else." + std::to_string(tempCounter); - std::string mergeLabel = "if.end." + std::to_string(tempCounter++); - - // SysY IR 基本块 - sysy::BasicBlock* thenBlock = currentIRFunction->addBasicBlock(trueLabel); - sysy::BasicBlock* elseBlock = ctx->ELSE() ? currentIRFunction->addBasicBlock(falseLabel) : nullptr; - sysy::BasicBlock* mergeBlock = currentIRFunction->addBasicBlock(mergeLabel); - - // 文本 IR - irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" - << (ctx->ELSE() ? falseLabel : mergeLabel) << "\n"; - - // SysY IR 条件分支 - sysy::IRBuilder builder(currentIRBlock); - builder.createCondBrInst(condValue, thenBlock, ctx->ELSE() ? elseBlock : mergeBlock, {}, {}); - - // 处理 then 分支 - setIRPosition(thenBlock); - irStream << trueLabel << ":\n"; - ctx->stmt(0)->accept(this); - irStream << " br label %" << mergeLabel << "\n"; - builder.setPosition(thenBlock, thenBlock->end()); - builder.createUncondBrInst(mergeBlock, {}); - - // 处理 else 分支 - if (ctx->ELSE()) { - setIRPosition(elseBlock); - irStream << falseLabel << ":\n"; - ctx->stmt(1)->accept(this); - irStream << " br label %" << mergeLabel << "\n"; - builder.setPosition(elseBlock, elseBlock->end()); - builder.createUncondBrInst(mergeBlock, {}); - } - - // 合并点 - setIRPosition(mergeBlock); - irStream << mergeLabel << ":\n"; - return nullptr; -} - -std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { - std::string loopCond = "while.cond." + std::to_string(tempCounter); - std::string loopBody = "while.body." + std::to_string(tempCounter); - std::string loopEnd = "while.end." + std::to_string(tempCounter++); - - // SysY IR 基本块 - sysy::BasicBlock* condBlock = currentIRFunction->addBasicBlock(loopCond); - sysy::BasicBlock* bodyBlock = currentIRFunction->addBasicBlock(loopBody); - sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(loopEnd); - - loopStack.push({loopEnd, loopCond, endBlock, condBlock}); - - // 跳转到条件块 - sysy::IRBuilder builder(currentIRBlock); - builder.createUncondBrInst(condBlock, {}); - irStream << " br label %" << loopCond << "\n"; - - // 条件块 - setIRPosition(condBlock); - irStream << loopCond << ":\n"; - std::string cond = std::any_cast(ctx->cond()->accept(this)); - sysy::Value* condValue = irTmpTable[cond]; - irStream << " br i1 " << cond << ", label %" << loopBody << ", label %" << loopEnd << "\n"; - builder.setPosition(condBlock, condBlock->end()); - builder.createCondBrInst(condValue, bodyBlock, endBlock, {}, {}); - - // 循环体 - setIRPosition(bodyBlock); - irStream << loopBody << ":\n"; - ctx->stmt()->accept(this); - irStream << " br label %" << loopCond << "\n"; - builder.setPosition(bodyBlock, bodyBlock->end()); - builder.createUncondBrInst(condBlock, {}); - - // 结束块 - setIRPosition(endBlock); - irStream << loopEnd << ":\n"; - loopStack.pop(); - return nullptr; -} - -std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { - if (loopStack.empty()) { - throw std::runtime_error("Break statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().breakLabel << "\n"; - sysy::IRBuilder builder(currentIRBlock); - builder.createUncondBrInst(loopStack.top().irBreakBlock, {}); - return nullptr; -} - -std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { - if (loopStack.empty()) { - throw std::runtime_error("Continue statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().continueLabel << "\n"; - sysy::IRBuilder builder(currentIRBlock); - builder.createUncondBrInst(loopStack.top().irContinueBlock, {}); - return nullptr; -} - -std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { - hasReturn = true; - sysy::IRBuilder builder(currentIRBlock); - if (ctx->exp()) { - std::string value = std::any_cast(ctx->exp()->accept(this)); - sysy::Value* irValue = irTmpTable[value]; - irStream << " ret " << currentReturnType << " " << value << "\n"; - builder.createReturnInst(irValue); - } else { - irStream << " ret void\n"; - builder.createReturnInst(); - } - return nullptr; -} - -std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { - std::string varName = ctx->Ident()->getText(); - if (irSymbolTable.find(varName) == irSymbolTable.end()) { - throw std::runtime_error("Undefined variable: " + varName); - } - // 对于 LValue,返回分配的指针(文本 IR 和 SysY IR 一致) - return symbolTable[varName].first; -} - -std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext* ctx) { - SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp(); - if (auto* lvalCtx = dynamic_cast(pExpCtx)) { - std::string allocaPtr = std::any_cast(lvalCtx->lValue()->accept(this)); - std::string varName = lvalCtx->lValue()->Ident()->getText(); - std::string type = symbolTable[varName].second; - std::string temp = getNextTemp(); - sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); - - // 文本 IR - irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; - tmpTable[temp] = type; - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - auto loadInst = builder.createLoadInst(irSymbolTable[varName], {}); - irTmpTable[temp] = loadInst; - return temp; - } else if (auto* expCtx = dynamic_cast(pExpCtx)) { - return expCtx->exp()->accept(this); - } else if (auto* strCtx = dynamic_cast(pExpCtx)) { - return strCtx->string()->accept(this); - } else if (auto* numCtx = dynamic_cast(pExpCtx)) { - return numCtx->number()->accept(this); - } else { - // 没有成功转换,说明 ctx->primaryExp() 不是 NumContext 或其他已知类型 - // 可能是其他类型的表达式,或者是一个空的 PrimaryExpContext - std::cout << "Unknown primary expression type." << std::endl; - throw std::runtime_error("Unknown primary expression type."); - } -} - -std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) { - return ctx->exp()->accept(this); -} - -std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { - std::string value; - sysy::Value* irValue = nullptr; - if (ctx->ILITERAL()) { - value = ctx->ILITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stoi(value)); - } else if (ctx->FLITERAL()) { - value = ctx->FLITERAL()->getText(); - irValue = sysy::ConstantValue::get(std::stof(value)); - } else { - value = ""; - } - std::string temp = getNextTemp(); - tmpTable[temp] = ctx->ILITERAL() ? "i32" : "float"; - irTmpTable[temp] = irValue; - return value; -} - -std::any LLVMIRGenerator::visitString(SysYParser::StringContext* ctx) { - if (ctx->STRING()) { - std::string str = ctx->STRING()->getText(); - str = str.substr(1, str.size() - 2); - std::string escapedStr; - for (char c : str) { - if (c == '\\') { - escapedStr += "\\\\"; - } else if (c == '"') { - escapedStr += "\\\""; - } else { - escapedStr += c; - } - } - // TODO: SysY IR 暂不支持字符串常量,返回文本 IR 结果 - return "\"" + escapedStr + "\""; - } - return ctx->STRING()->getText(); -} - - - -std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) { - if (ctx->unaryOp()) { - std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); - sysy::Value* irOperand = irTmpTable[operand]; - std::string op = ctx->unaryOp()->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[operand]; - sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); - tmpTable[temp] = type; - - // 文本 IR - if (op == "-") { - irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; - } else if (op == "!") { - irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; - } - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Instruction::Kind kind = (op == "-") ? (type == "i32" ? sysy::Instruction::kNeg : sysy::Instruction::kFNeg) - : sysy::Instruction::kNot; - auto unaryInst = builder.createUnaryInst(kind, irType, irOperand, temp); - irTmpTable[temp] = unaryInst; - return temp; - } - return ctx->unaryExp()->accept(this); -} - -std::any LLVMIRGenerator::visitCall(SysYParser::CallContext* ctx) { - std::string funcName = ctx->Ident()->getText(); - std::vector args; - std::vector irArgs; - if (ctx->funcRParams()) { - for (auto argCtx : ctx->funcRParams()->exp()) { - std::string arg = std::any_cast(argCtx->accept(this)); - args.push_back(arg); - irArgs.push_back(irTmpTable[arg]); - } - } - std::string temp = getNextTemp(); - std::string argList; - for (size_t i = 0; i < args.size(); ++i) { - if (i > 0) argList += ", "; - argList += tmpTable[args[i]] + " noundef " + args[i]; - } - - // 文本 IR - irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; - tmpTable[temp] = currentReturnType; - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Function* callee = module->getFunction(funcName); - if (!callee) { - throw std::runtime_error("Undefined function: " + funcName); - } - auto callInst = builder.createCallInst(callee, irArgs, temp); - irTmpTable[temp] = callInst; - return temp; -} - -std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { - auto unaryExps = ctx->unaryExp(); - std::string left = std::any_cast(unaryExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - sysy::Type* irType = irLeft->getType(); - - for (size_t i = 1; i < unaryExps.size(); ++i) { - std::string right = std::any_cast(unaryExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - std::string op = ctx->children[2 * i - 1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - tmpTable[temp] = type; - - // 文本 IR - if (op == "*") { - irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; - } else if (op == "/") { - irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; - } else if (op == "%") { - irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; - } - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Instruction::Kind kind; - if (type == "i32") { - if (op == "*") kind = sysy::Instruction::kMul; - else if (op == "/") kind = sysy::Instruction::kDiv; - else kind = sysy::Instruction::kRem; - } else { - if (op == "*") kind = sysy::Instruction::kFMul; - else if (op == "/") kind = sysy::Instruction::kFDiv; - else kind = sysy::Instruction::kFRem; - } - auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); - irTmpTable[temp] = binaryInst; - left = temp; - irLeft = binaryInst; - } - return left; -} - -std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { - auto mulExps = ctx->mulExp(); - std::string left = std::any_cast(mulExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - sysy::Type* irType = irLeft->getType(); - - for (size_t i = 1; i < mulExps.size(); ++i) { - std::string right = std::any_cast(mulExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - std::string op = ctx->children[2 * i - 1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - tmpTable[temp] = type; - - // 文本 IR - if (op == "+") { - irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; - } else if (op == "-") { - irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; - } - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Instruction::Kind kind = (type == "i32") ? (op == "+" ? sysy::Instruction::kAdd : sysy::Instruction::kSub) - : (op == "+" ? sysy::Instruction::kFAdd : sysy::Instruction::kFSub); - auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); - irTmpTable[temp] = binaryInst; - left = temp; - irLeft = binaryInst; - } - return left; -} - -std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { - auto addExps = ctx->addExp(); - std::string left = std::any_cast(addExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 - - for (size_t i = 1; i < addExps.size(); ++i) { - std::string right = std::any_cast(addExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - std::string op = ctx->children[2 * i - 1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - tmpTable[temp] = "i1"; - - // 文本 IR - if (op == "<") { - irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; - } else if (op == ">") { - irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; - } else if (op == "<=") { - irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; - } else if (op == ">=") { - irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; - } - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Instruction::Kind kind; - if (type == "i32") { - if (op == "<") kind = sysy::Instruction::kICmpLT; - else if (op == ">") kind = sysy::Instruction::kICmpGT; - else if (op == "<=") kind = sysy::Instruction::kICmpLE; - else kind = sysy::Instruction::kICmpGE; - } else { - if (op == "<") kind = sysy::Instruction::kFCmpLT; - else if (op == ">") kind = sysy::Instruction::kFCmpGT; - else if (op == "<=") kind = sysy::Instruction::kFCmpLE; - else kind = sysy::Instruction::kFCmpGE; - } - auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); - irTmpTable[temp] = cmpInst; - left = temp; - irLeft = cmpInst; - } - return left; -} - -std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { - auto relExps = ctx->relExp(); - std::string left = std::any_cast(relExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 - - for (size_t i = 1; i < relExps.size(); ++i) { - std::string right = std::any_cast(relExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - std::string op = ctx->children[2 * i - 1]->getText(); - std::string temp = getNextTemp(); - std::string type = tmpTable[left]; - tmpTable[temp] = "i1"; - - // 文本 IR - if (op == "==") { - irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; - } else if (op == "!=") { - irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; - } - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - sysy::Instruction::Kind kind = (type == "i32") ? (op == "==" ? sysy::Instruction::kICmpEQ : sysy::Instruction::kICmpNE) - : (op == "==" ? sysy::Instruction::kFCmpEQ : sysy::Instruction::kFCmpNE); - auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); - irTmpTable[temp] = cmpInst; - left = temp; - irLeft = cmpInst; - } - return left; -} - -std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { - auto eqExps = ctx->eqExp(); - std::string left = std::any_cast(eqExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - - for (size_t i = 1; i < eqExps.size(); ++i) { - std::string falseLabel = "land.false." + std::to_string(tempCounter); - std::string endLabel = "land.end." + std::to_string(tempCounter++); - sysy::BasicBlock* falseBlock = currentIRFunction->addBasicBlock(falseLabel); - sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); - std::string temp = getNextTemp(); - tmpTable[temp] = "i1"; - - // 文本 IR - irStream << " br i1 " << left << ", label %" << falseLabel << ", label %" << endLabel << "\n"; - irStream << falseLabel << ":\n"; - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - builder.createCondBrInst(irLeft, falseBlock, endBlock, {}, {}); - setIRPosition(falseBlock); - - std::string right = std::any_cast(eqExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; - irStream << " br label %" << endLabel << "\n"; - irStream << endLabel << ":\n"; - - // SysY IR 逻辑与(通过基本块实现短路求值) - builder.setPosition(falseBlock, falseBlock->end()); - auto andInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); - builder.createUncondBrInst(endBlock, {}); - irTmpTable[temp] = andInst; - left = temp; - irLeft = andInst; - setIRPosition(endBlock); - } - return left; -} - -std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { - auto lAndExps = ctx->lAndExp(); - std::string left = std::any_cast(lAndExps[0]->accept(this)); - sysy::Value* irLeft = irTmpTable[left]; - - for (size_t i = 1; i < lAndExps.size(); ++i) { - std::string trueLabel = "lor.true." + std::to_string(tempCounter); - std::string endLabel = "lor.end." + std::to_string(tempCounter++); - sysy::BasicBlock* trueBlock = currentIRFunction->addBasicBlock(trueLabel); - sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); - std::string temp = getNextTemp(); - tmpTable[temp] = "i1"; - - // 文本 IR - irStream << " br i1 " << left << ", label %" << trueLabel << ", label %" << endLabel << "\n"; - irStream << trueLabel << ":\n"; - - // SysY IR - sysy::IRBuilder builder(currentIRBlock); - builder.createCondBrInst(irLeft, trueBlock, endBlock, {}, {}); - setIRPosition(trueBlock); - - std::string right = std::any_cast(lAndExps[i]->accept(this)); - sysy::Value* irRight = irTmpTable[right]; - irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; - irStream << " br label %" << endLabel << "\n"; - irStream << endLabel << ":\n"; - - // SysY IR 逻辑或(通过基本块实现短路求值) - builder.setPosition(trueBlock, trueBlock->end()); - auto orInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); - builder.createUncondBrInst(endBlock, {}); - irTmpTable[temp] = orInst; - left = temp; - irLeft = orInst; - setIRPosition(endBlock); - } - return left; -} - -// } // namespace sysy \ No newline at end of file diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 86e19e7..7520891 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -204,6 +204,7 @@ std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ // 更新作用域 module->enterNewScope(); + HasReturnInst = false; auto name = ctx->Ident()->getText(); std::vector paramTypes; @@ -243,6 +244,18 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ visitBlockItem(item); } + if(HasReturnInst == false) { + // 如果没有return语句,则默认返回0 + if (returnType != Type::getVoidType()) { + Value* returnValue = ConstantValue::get(0); + if (returnType == Type::getFloatType()) { + returnValue = ConstantValue::get(0.0f); + } + builder.createReturnInst(returnValue); + } else { + builder.createReturnInst(); + } + } module->leaveScope(); return std::any(); @@ -478,6 +491,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } builder.createReturnInst(returnValue); + HasReturnInst = true; return std::any(); } diff --git a/src/SysYIROptPre.cpp b/src/SysYIROptPre.cpp index 41af234..fb05cb7 100644 --- a/src/SysYIROptPre.cpp +++ b/src/SysYIROptPre.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "IR.h" #include "IRBuilder.h" @@ -458,11 +459,13 @@ void SysYOptPre::SysYAddReturn() { // 如果基本块没有后继块,则添加一个返回指令 if (block->getNumInstructions() == 0) { pBuilder->setPosition(block.get(), block->end()); - pBuilder->createReturnInst({}); + pBuilder->createReturnInst(); } auto thelastinst = block->getInstructions().end(); --thelastinst; if (thelastinst->get()->getKind() != Instruction::kReturn) { + // std::cout << "Warning: Function " << func->getName() << " has no return instruction, adding default return." << std::endl; + pBuilder->setPosition(block.get(), block->end()); // TODO: 如果int float函数缺少返回值是否需要报错 if (func->getReturnType()->isInt()) { @@ -470,7 +473,7 @@ void SysYOptPre::SysYAddReturn() { } else if (func->getReturnType()->isFloat()) { pBuilder->createReturnInst(ConstantValue::get(0.0F)); } else { - pBuilder->createReturnInst({}); + pBuilder->createReturnInst(); } } } diff --git a/src/include/LLVMIRGenerator.h b/src/include/LLVMIRGenerator.h deleted file mode 100644 index e330a4f..0000000 --- a/src/include/LLVMIRGenerator.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once -#include "SysYBaseVisitor.h" -#include "SysYParser.h" -#include "IR.h" -#include "IRBuilder.h" -#include -#include -#include -#include - -class LLVMIRGenerator : public SysYBaseVisitor { -public: - sysy::Module* getIRModule() const { return irModule.get(); } - - std::string generateIR(SysYParser::CompUnitContext* unit); - std::string getIR() const { return irStream.str(); } - -private: - std::unique_ptr irModule; // IR数据结构 - std::stringstream irStream; // 文本输出流 - sysy::IRBuilder irBuilder; // IR构建器 - int tempCounter = 0; - std::string currentVarType; - // std::map symbolTable; - std::map> symbolTable; - std::map tmpTable; - std::vector globalVars; - std::string currentFunction; - std::string currentReturnType; - std::vector breakStack; - std::vector continueStack; - bool hasReturn = false; - - struct LoopLabels { - std::string breakLabel; // break跳转的目标标签 - std::string continueLabel; // continue跳转的目标标签 - }; - std::stack loopStack; // 用于管理循环的break和continue标签 - std::string getNextTemp(); - std::string getLLVMType(const std::string&); - sysy::Type* getSysYType(const std::string&); - - bool inFunction = false; // 标识当前是否处于函数内部 - - // 访问方法 - std::any visitCompUnit(SysYParser::CompUnitContext* ctx); - std::any visitConstDecl(SysYParser::ConstDeclContext* ctx); - std::any visitVarDecl(SysYParser::VarDeclContext* ctx); - std::any visitVarDef(SysYParser::VarDefContext* ctx); - std::any visitFuncDef(SysYParser::FuncDefContext* ctx); - std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx); - // std::any visitStmt(SysYParser::StmtContext* ctx); - std::any visitLValue(SysYParser::LValueContext* ctx); - std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx); - std::any visitPrimExp(SysYParser::PrimExpContext* ctx); - std::any visitParenExp(SysYParser::ParenExpContext* ctx); - std::any visitNumber(SysYParser::NumberContext* ctx); - std::any visitString(SysYParser::StringContext* ctx); - std::any visitCall(SysYParser::CallContext *ctx); - std::any visitUnExp(SysYParser::UnExpContext* ctx); - std::any visitMulExp(SysYParser::MulExpContext* ctx); - std::any visitAddExp(SysYParser::AddExpContext* ctx); - std::any visitRelExp(SysYParser::RelExpContext* ctx); - std::any visitEqExp(SysYParser::EqExpContext* ctx); - std::any visitLAndExp(SysYParser::LAndExpContext* ctx); - std::any visitLOrExp(SysYParser::LOrExpContext* ctx); - std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; - std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; - std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; - std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; - std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; - std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; - - // 统一创建二元操作(同时生成数据结构和文本) - sysy::Value* createBinaryOp(SysYParser::ExpContext* lhs, - SysYParser::ExpContext* rhs, - sysy::Instruction::Kind opKind); -}; \ No newline at end of file diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index 445a856..fe309e8 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -62,6 +62,8 @@ private: public: SysYIRGenerator() = default; + bool HasReturnInst; + public: Module *get() const { return module.get(); } IRBuilder *getBuilder(){ return &builder; } From 50fd9cffe9726302699a0115f986637346ffdbcd Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Wed, 16 Jul 2025 13:04:05 +0800 Subject: [PATCH 2/2] =?UTF-8?q?[IRPrinter&DCE]=E4=BF=AE=E6=94=B9=E5=AE=9A?= =?UTF-8?q?=E4=B9=89=E6=96=B9=E4=BE=BF=E8=B0=83=E8=AF=95=E6=89=93=E5=8D=B0?= =?UTF-8?q?=EF=BC=8C=E5=9C=A8DEC=E4=B8=AD=E5=A2=9E=E5=8A=A0=E8=B0=83?= =?UTF-8?q?=E8=AF=95=E4=BF=A1=E6=81=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/DeadCodeElimination.cpp | 31 +++++++++++++++++++++++++++++-- src/include/DeadCodeElimination.h | 2 ++ src/include/SysYIRPrinter.h | 13 +++++++------ 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/src/DeadCodeElimination.cpp b/src/DeadCodeElimination.cpp index 9abca1c..ffe6022 100644 --- a/src/DeadCodeElimination.cpp +++ b/src/DeadCodeElimination.cpp @@ -1,8 +1,9 @@ #include "DeadCodeElimination.h" +#include +extern int DEBUG; namespace sysy { - void DeadCodeElimination::runDCEPipeline() { const auto& functions = pModule->getFunctions(); for (const auto& function : functions) { @@ -58,6 +59,10 @@ void DeadCodeElimination::eliminateDeadStores(Function* func, bool& changed) { if (changetag) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Store Found ===\n"; + SysYPrinter::printInst(storeInst); + } usedelete(storeInst); iter = instrs.erase(iter); } else { @@ -76,6 +81,10 @@ void DeadCodeElimination::eliminateDeadLoads(Function* func, bool& changed) { if (inst->isBinary() || inst->isUnary() || inst->isLoad()) { if (inst->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Load Binary Unary Found ===\n"; + SysYPrinter::printInst(inst); + } usedelete(inst); iter = instrs.erase(iter); continue; @@ -101,6 +110,10 @@ void DeadCodeElimination::eliminateDeadAllocas(Function* func, bool& changed) { func->getEntryBlock()->getArguments().end(), allocaInst) == func->getEntryBlock()->getArguments().end()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Alloca Found ===\n"; + SysYPrinter::printInst(inst); + } usedelete(inst); iter = instrs.erase(iter); continue; @@ -116,8 +129,12 @@ void DeadCodeElimination::eliminateDeadIndirectiveAllocas(Function* func, bool& FunctionAnalysisInfo* funcInfo = pCFA->getFunctionAnalysisInfo(func); for (auto it = funcInfo->getIndirectAllocas().begin(); it != funcInfo->getIndirectAllocas().end();) { auto &allocaInst = *it; - if (allocaInst->getUses().empty()) { + if (allocaInst->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Indirect Alloca Found ===\n"; + SysYPrinter::printInst(allocaInst.get()); + } it = funcInfo->getIndirectAllocas().erase(it); } else { ++it; @@ -132,6 +149,10 @@ void DeadCodeElimination::eliminateDeadGlobals(bool& changed) { auto& global = *it; if (global->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Global Found ===\n"; + SysYPrinter::printValue(global.get()); + } it = globals.erase(it); } else { ++it; @@ -207,6 +228,12 @@ void DeadCodeElimination::eliminateDeadRedundantLoadStore(Function* func, bool& // 可以优化直接把prevStorePointer的值存到nextStorePointer changed = true; nextStore->setOperand(0, prevStoreValue); + if(DEBUG){ + std::cout << "=== Dead Store Load Store Found(now only del Load) ===\n"; + SysYPrinter::printInst(prevStore); + SysYPrinter::printInst(loadInst); + SysYPrinter::printInst(nextStore); + } usedelete(loadInst); iter = instrs.erase(iter); // 删除 prevStore 这里是不是可以留给删除无用store处理? diff --git a/src/include/DeadCodeElimination.h b/src/include/DeadCodeElimination.h index 2d614bd..72b9935 100644 --- a/src/include/DeadCodeElimination.h +++ b/src/include/DeadCodeElimination.h @@ -2,6 +2,8 @@ #include "IR.h" #include "SysYIRAnalyser.h" +#include "SysYIRPrinter.h" + namespace sysy { class DeadCodeElimination { diff --git a/src/include/SysYIRPrinter.h b/src/include/SysYIRPrinter.h index 114fb05..bfd78bd 100644 --- a/src/include/SysYIRPrinter.h +++ b/src/include/SysYIRPrinter.h @@ -15,15 +15,16 @@ public: public: void printIR(); void printGlobalVariable(); - void printFunction(Function *function); - void printInst(Instruction *pInst); - void printType(Type *type); - void printValue(Value *value); + public: + static void printFunction(Function *function); + static void printInst(Instruction *pInst); + static void printType(Type *type); + static void printValue(Value *value); static std::string getOperandName(Value *operand); - std::string getTypeString(Type *type); - std::string getValueName(Value *value); + static std::string getTypeString(Type *type); + static std::string getValueName(Value *value); }; } // namespace sysy