diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 55e4396..f38db5c 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1,4 +1,3 @@ -// SysYIRGenerator.cpp #include "SysYIRGenerator.h" #include @@ -30,6 +29,21 @@ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { } std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + std::string type = ctx->bType()->getText(); + for (auto constDef : ctx->constDef()) { + std::string varName = constDef->Ident()->getText(); + symbolTable[varName].second = type; + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + symbolTable[varName] = {allocaName, llvmType}; + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + + if (constDef->constInitVal()) { + std::string value = std::any_cast(constDef->constInitVal()->accept(this)); + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + } + } return nullptr; } @@ -48,11 +62,12 @@ std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { std::string llvmType = getLLVMType(type); std::string allocaName = getNextTemp(); symbolTable[varName] = {allocaName, llvmType}; - irStream << " " << allocaName << " = alloca " << llvmType << ", align " << (type == "float" ? "4" : "4") << "\n"; + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; if (ctx->ASSIGN()) { std::string value = std::any_cast(ctx->initVal()->accept(this)); - irStream << " store " << llvmType << " " << value << ", " << llvmType << "* " << allocaName << ", align " << (type == "float" ? "4" : "4") << "\n"; + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; } return nullptr; } @@ -62,25 +77,30 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { currentReturnType = getLLVMType(ctx->funcType()->getText()); symbolTable.clear(); hasReturn = false; - - irStream << "define " << currentReturnType << " @" << currentFunction << "("; - auto paramsCtx = ctx->funcFParams(); - if (paramsCtx) { - auto params = paramsCtx->funcFParam(); + tempCounter = 0; + + irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); for (size_t i = 0; i < params.size(); ++i) { if (i > 0) irStream << ", "; - auto param = params[i]; - std::string paramName = "%" + std::to_string(i); - std::string paramType = getLLVMType(param->bType()->getText()); - irStream << paramType << " " << paramName; - std::string allocaName = getNextTemp(); - symbolTable[param->Ident()->getText()] = {allocaName, paramType}; - irStream << "\n " << allocaName << " = alloca " << paramType << ", align " << (paramType == "float" ? "4" : "4"); - irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName << ", align " << (paramType == "float" ? "4" : "4"); + std::string paramType = getLLVMType(params[i]->bType()->getText()); + irStream << paramType << " noundef %" << i; } } - irStream << ") {\nentry:\n"; + irStream << ") #0 {\nentry:\n"; + + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string paramName = params[i]->Ident()->getText(); + std::string paramType = getLLVMType(params[i]->bType()->getText()); + symbolTable[paramName] = {"%" + std::to_string(i), paramType}; + } + } + ctx->blockStmt()->accept(this); + if (!hasReturn) { if (currentReturnType == "void") { irStream << " ret void\n"; @@ -88,7 +108,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { irStream << " ret " << currentReturnType << " 0\n"; } } - irStream << "}\n"; + irStream << "}\n\n"; return nullptr; } @@ -105,7 +125,8 @@ std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { std::string varName = ctx->lValue()->Ident()->getText(); std::string lhsType = symbolTable[varName].second; std::string rhs = std::any_cast(ctx->exp()->accept(this)); - irStream << " store " << lhsType << " " << rhs << ", " << lhsType << "* " << lhsAlloca << ", align " << (lhsType == "float" ? "4" : "4") << "\n"; + irStream << " store " << lhsType << " " << rhs << ", " << lhsType + << "* " << lhsAlloca << ", align 4\n"; } else if (ctx->RETURN()) { hasReturn = true; if (ctx->exp()) { @@ -114,6 +135,25 @@ std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { } else { irStream << " ret void\n"; } + } else if (ctx->IF()) { + std::string cond = std::any_cast(ctx->cond()->accept(this)); + std::string trueLabel = "if.then." + std::to_string(tempCounter); + std::string falseLabel = "if.else." + std::to_string(tempCounter); + std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + + irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n\n"; + + irStream << trueLabel << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << mergeLabel << "\n\n"; + + irStream << falseLabel << ":\n"; + if (ctx->ELSE()) { + ctx->stmt(1)->accept(this); + } + irStream << " br label %" << mergeLabel << "\n\n"; + + irStream << mergeLabel << ":\n"; } return nullptr; } @@ -129,7 +169,7 @@ std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::string varName = ctx->lValue()->Ident()->getText(); std::string type = symbolTable[varName].second; std::string temp = getNextTemp(); - irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align " << (type == "float" ? "4" : "4") << "\n"; + irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; return temp; } else if (ctx->exp()) { return ctx->exp()->accept(this); @@ -159,6 +199,24 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; } return temp; + } else if (ctx->Ident()) { + std::string funcName = ctx->Ident()->getText(); + std::vector args; + if (ctx->funcRParams()) { + for (auto argCtx : ctx->funcRParams()->exp()) { + args.push_back(std::any_cast(argCtx->accept(this))); + } + } + + std::string temp = getNextTemp(); + std::string argList = ""; + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) argList += ", "; + argList += args[i]; + } + + irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; + return temp; } return ctx->primaryExp()->accept(this); } @@ -172,7 +230,7 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { std::string temp = getNextTemp(); std::string type = left.substr(0, left.find(' ')); if (op == "*") { - irStream << " " << temp << " = mul " << type << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; } else if (op == "/") { irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; } else if (op == "%") { @@ -192,9 +250,9 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { std::string temp = getNextTemp(); std::string type = left.substr(0, left.find(' ')); if (op == "+") { - irStream << " " << temp << " = add " << type << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; } else if (op == "-") { - irStream << " " << temp << " = sub " << type << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; } left = temp; } @@ -245,9 +303,19 @@ std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { auto eqExps = ctx->eqExp(); std::string left = std::any_cast(eqExps[0]->accept(this)); for (size_t i = 1; i < eqExps.size(); ++i) { - std::string right = std::any_cast(eqExps[i]->accept(this)); + std::string falseLabel = "land.false." + std::to_string(tempCounter); + std::string endLabel = "land.end." + std::to_string(tempCounter++); std::string temp = getNextTemp(); - irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; + + irStream << " " << temp << " = and i1 " << left << ", 1\n"; + irStream << " br i1 " << temp << ", label %" << falseLabel << ", label %" << endLabel << "\n\n"; + + irStream << falseLabel << ":\n"; + std::string right = std::any_cast(eqExps[i]->accept(this)); + irStream << " " << temp << " = and i1 " << temp << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n\n"; + + irStream << endLabel << ":\n"; left = temp; } return left; @@ -257,9 +325,19 @@ std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { auto lAndExps = ctx->lAndExp(); std::string left = std::any_cast(lAndExps[0]->accept(this)); for (size_t i = 1; i < lAndExps.size(); ++i) { - std::string right = std::any_cast(lAndExps[i]->accept(this)); + std::string trueLabel = "lor.true." + std::to_string(tempCounter); + std::string endLabel = "lor.end." + std::to_string(tempCounter++); std::string temp = getNextTemp(); - irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; + + irStream << " " << temp << " = or i1 " << left << ", 0\n"; + irStream << " br i1 " << temp << ", label %" << trueLabel << ", label %" << endLabel << "\n\n"; + + irStream << trueLabel << ":\n"; + std::string right = std::any_cast(lAndExps[i]->accept(this)); + irStream << " " << temp << " = or i1 " << temp << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n\n"; + + irStream << endLabel << ":\n"; left = temp; } return left; diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index 0f1b362..c65e1ff 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -1,28 +1,33 @@ -// SysYIRGenerator.h #pragma once #include "SysYBaseVisitor.h" #include "SysYParser.h" #include #include #include +#include class SysYIRGenerator : public SysYBaseVisitor { public: std::string generateIR(SysYParser::CompUnitContext* unit); std::string getIR() const { return irStream.str(); } + private: std::stringstream irStream; int tempCounter = 0; - std::map> symbolTable; // {varName: {allocaName, type}} + std::map> symbolTable; std::vector globalVars; std::string currentFunction; std::string currentReturnType; std::vector breakStack; std::vector continueStack; bool hasReturn = false; + std::stack loopEndStack; + std::stack loopCondStack; std::string getNextTemp(); std::string getLLVMType(const std::string& type); + + // 访问方法 std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;