From 3d60a94894a3541b84e9a49fadd54fd83ba5b0ab Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Mon, 10 Mar 2025 21:43:20 +0800 Subject: [PATCH] [lab2] testfile01 finished --- src/SysYIRGenerator.cpp | 129 ++++++++++++++++++++++++---------------- src/SysYIRGenerator.h | 26 ++++---- test/10_test.sy | 2 +- 3 files changed, 93 insertions(+), 64 deletions(-) diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 1acead1..55e4396 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1,3 +1,4 @@ +// SysYIRGenerator.cpp #include "SysYIRGenerator.h" #include @@ -14,7 +15,7 @@ std::string SysYIRGenerator::getLLVMType(const std::string& type) { if (type == "int") return "i32"; if (type == "float") return "float"; if (type.find("[]") != std::string::npos) - return getLLVMType(type.substr(0, type.size() - 2)) + "*"; + return getLLVMType(type.substr(0, type.size()-2)) + "*"; return "i32"; } @@ -29,26 +30,40 @@ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { } std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { - // 常量声明暂不处理(LLVM IR 中常量通常内联) return nullptr; } std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { + std::string type = ctx->bType()->getText(); for (auto varDef : ctx->varDef()) { + symbolTable[varDef->Ident()->getText()].second = type; varDef->accept(this); } return nullptr; } +std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { + std::string varName = ctx->Ident()->getText(); + std::string type = symbolTable[varName].second; + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + symbolTable[varName] = {allocaName, llvmType}; + irStream << " " << allocaName << " = alloca " << llvmType << ", align " << (type == "float" ? "4" : "4") << "\n"; + + if (ctx->ASSIGN()) { + std::string value = std::any_cast(ctx->initVal()->accept(this)); + irStream << " store " << llvmType << " " << value << ", " << llvmType << "* " << allocaName << ", align " << (type == "float" ? "4" : "4") << "\n"; + } + return nullptr; +} + std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { currentFunction = ctx->Ident()->getText(); + currentReturnType = getLLVMType(ctx->funcType()->getText()); symbolTable.clear(); + hasReturn = false; - // 函数头 - std::string returnType = getLLVMType(ctx->funcType()->getText()); - irStream << "define " << returnType << " @" << currentFunction << "("; - - // 参数 + irStream << "define " << currentReturnType << " @" << currentFunction << "("; auto paramsCtx = ctx->funcFParams(); if (paramsCtx) { auto params = paramsCtx->funcFParam(); @@ -58,26 +73,22 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::string paramName = "%" + std::to_string(i); std::string paramType = getLLVMType(param->bType()->getText()); irStream << paramType << " " << paramName; - - // 分配参数 std::string allocaName = getNextTemp(); - symbolTable[param->Ident()->getText()] = allocaName; - irStream << "\n " << allocaName << " = alloca " << paramType; - irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName; + symbolTable[param->Ident()->getText()] = {allocaName, paramType}; + irStream << "\n " << allocaName << " = alloca " << paramType << ", align " << (paramType == "float" ? "4" : "4"); + irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName << ", align " << (paramType == "float" ? "4" : "4"); } } irStream << ") {\nentry:\n"; - - // 函数体 ctx->blockStmt()->accept(this); - - // 默认返回值 - if (returnType == "void") { - irStream << " ret void\n"; - } else { - irStream << " ret " << returnType << " 0\n"; + if (!hasReturn) { + if (currentReturnType == "void") { + irStream << " ret void\n"; + } else { + irStream << " ret " << currentReturnType << " 0\n"; + } } - irStream << "}\n\n"; + irStream << "}\n"; return nullptr; } @@ -89,16 +100,17 @@ std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { } std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { - if (ctx->lValue() && ctx->exp()) { - // 赋值语句 - std::string lhs = std::any_cast(ctx->lValue()->accept(this)); + if (ctx->lValue() && ctx->ASSIGN()) { + std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); + std::string varName = ctx->lValue()->Ident()->getText(); + std::string lhsType = symbolTable[varName].second; std::string rhs = std::any_cast(ctx->exp()->accept(this)); - irStream << " store " << getLLVMType("") << " " << rhs << ", " << getLLVMType("") << "* " << lhs << "\n"; + irStream << " store " << lhsType << " " << rhs << ", " << lhsType << "* " << lhsAlloca << ", align " << (lhsType == "float" ? "4" : "4") << "\n"; } else if (ctx->RETURN()) { - // 返回语句 + hasReturn = true; if (ctx->exp()) { std::string value = std::any_cast(ctx->exp()->accept(this)); - irStream << " ret " << getLLVMType("") << " " << value << "\n"; + irStream << " ret " << currentReturnType << " " << value << "\n"; } else { irStream << " ret void\n"; } @@ -108,12 +120,22 @@ std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { std::string varName = ctx->Ident()->getText(); - if (symbolTable.find(varName) == symbolTable.end()) { - std::string allocaName = getNextTemp(); - symbolTable[varName] = allocaName; - irStream << " " << allocaName << " = alloca " << getLLVMType("") << "\n"; + return symbolTable[varName].first; +} + +std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { + if (ctx->lValue()) { + std::string allocaPtr = std::any_cast(ctx->lValue()->accept(this)); + std::string varName = ctx->lValue()->Ident()->getText(); + std::string type = symbolTable[varName].second; + std::string temp = getNextTemp(); + irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align " << (type == "float" ? "4" : "4") << "\n"; + return temp; + } else if (ctx->exp()) { + return ctx->exp()->accept(this); + } else { + return ctx->number()->accept(this); } - return symbolTable[varName]; } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { @@ -130,10 +152,11 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); std::string op = ctx->unaryOp()->getText(); std::string temp = getNextTemp(); + std::string type = operand.substr(0, operand.find(' ')); if (op == "-") { - irStream << " " << temp << " = sub " << getLLVMType("") << " 0, " << operand << "\n"; + irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; } else if (op == "!") { - irStream << " " << temp << " = xor " << getLLVMType("") << " " << operand << ", 1\n"; + irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; } return temp; } @@ -145,14 +168,15 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { std::string left = std::any_cast(unaryExps[0]->accept(this)); for (size_t i = 1; i < unaryExps.size(); ++i) { std::string right = std::any_cast(unaryExps[i]->accept(this)); - std::string op = ctx->children[2 * i - 1]->getText(); + std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); + std::string type = left.substr(0, left.find(' ')); if (op == "*") { - irStream << " " << temp << " = mul " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = mul " << type << " " << left << ", " << right << "\n"; } else if (op == "/") { - irStream << " " << temp << " = sdiv " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; } else if (op == "%") { - irStream << " " << temp << " = srem " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; } left = temp; } @@ -164,12 +188,13 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { std::string left = std::any_cast(mulExps[0]->accept(this)); for (size_t i = 1; i < mulExps.size(); ++i) { std::string right = std::any_cast(mulExps[i]->accept(this)); - std::string op = ctx->children[2 * i - 1]->getText(); + std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); + std::string type = left.substr(0, left.find(' ')); if (op == "+") { - irStream << " " << temp << " = add " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = add " << type << " " << left << ", " << right << "\n"; } else if (op == "-") { - irStream << " " << temp << " = sub " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = sub " << type << " " << left << ", " << right << "\n"; } left = temp; } @@ -181,16 +206,17 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { std::string left = std::any_cast(addExps[0]->accept(this)); for (size_t i = 1; i < addExps.size(); ++i) { std::string right = std::any_cast(addExps[i]->accept(this)); - std::string op = ctx->children[2 * i - 1]->getText(); + std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); + std::string type = left.substr(0, left.find(' ')); if (op == "<") { - irStream << " " << temp << " = icmp slt " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; } else if (op == ">") { - irStream << " " << temp << " = icmp sgt " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; } else if (op == "<=") { - irStream << " " << temp << " = icmp sle " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; } else if (op == ">=") { - irStream << " " << temp << " = icmp sge " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; } left = temp; } @@ -202,12 +228,13 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { std::string left = std::any_cast(relExps[0]->accept(this)); for (size_t i = 1; i < relExps.size(); ++i) { std::string right = std::any_cast(relExps[i]->accept(this)); - std::string op = ctx->children[2 * i - 1]->getText(); + std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); + std::string type = left.substr(0, left.find(' ')); if (op == "==") { - irStream << " " << temp << " = icmp eq " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; } else if (op == "!=") { - irStream << " " << temp << " = icmp ne " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; } left = temp; } @@ -220,7 +247,7 @@ std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { for (size_t i = 1; i < eqExps.size(); ++i) { std::string right = std::any_cast(eqExps[i]->accept(this)); std::string temp = getNextTemp(); - irStream << " " << temp << " = and " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; left = temp; } return left; @@ -232,7 +259,7 @@ std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { for (size_t i = 1; i < lAndExps.size(); ++i) { std::string right = std::any_cast(lAndExps[i]->accept(this)); std::string temp = getNextTemp(); - irStream << " " << temp << " = or " << getLLVMType("") << " " << left << ", " << right << "\n"; + irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; left = temp; } return left; diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index c11d8ab..0f1b362 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -1,3 +1,4 @@ +// SysYIRGenerator.h #pragma once #include "SysYBaseVisitor.h" #include "SysYParser.h" @@ -7,29 +8,30 @@ class SysYIRGenerator : public SysYBaseVisitor { public: - std::string generateIR(SysYParser::CompUnitContext* unit); // 公共接口,用于生成 IR - std::string getIR() const { return irStream.str(); } // 获取生成的 IR - + std::string generateIR(SysYParser::CompUnitContext* unit); + std::string getIR() const { return irStream.str(); } private: std::stringstream irStream; int tempCounter = 0; - std::map symbolTable; // 符号表 - std::vector globalVars; // 全局变量 - std::string currentFunction; // 当前函数名 - std::vector breakStack; // break 目标标签栈 - std::vector continueStack; // continue 目标标签栈 + std::map> symbolTable; // {varName: {allocaName, type}} + std::vector globalVars; + std::string currentFunction; + std::string currentReturnType; + std::vector breakStack; + std::vector continueStack; + bool hasReturn = false; - std::string getNextTemp(); // 获取下一个临时变量名 - std::string getLLVMType(const std::string& type); // 获取 LLVM 类型 - - // 访问方法 + std::string getNextTemp(); + std::string getLLVMType(const std::string& type); std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitVarDef(SysYParser::VarDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; std::any visitNumber(SysYParser::NumberContext* ctx) override; std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; std::any visitMulExp(SysYParser::MulExpContext* ctx) override; diff --git a/test/10_test.sy b/test/10_test.sy index 05dab29..bf84af2 100644 --- a/test/10_test.sy +++ b/test/10_test.sy @@ -1,7 +1,7 @@ //test file for backend lab int main() { - const int a = 1; + int a; const int b = 2; int c;