From b0b03ff55b18548bd78748aeaf511f6b201c6923 Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Mon, 10 Mar 2025 16:50:18 +0800 Subject: [PATCH] [lab2] runnable --- .gitignore | 6 +- src/SysYIRGenerator.cpp | 261 ++++++++++++++++++++++++++++++++++++---- src/SysYIRGenerator.h | 54 +++++---- src/sysyc.cpp | 31 +++-- 4 files changed, 288 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index ae441e3..3649ac4 100644 --- a/.gitignore +++ b/.gitignore @@ -40,4 +40,8 @@ doxygen build .antlr -tmp \ No newline at end of file +tmp + +GPATH +GRTAGS +GTAGS \ No newline at end of file diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 3eccc14..1acead1 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1,30 +1,239 @@ -#include "IR.h" -#include -#include -using namespace std; #include "SysYIRGenerator.h" +#include -namespace sysy { - -any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { - // create the IR module - auto pModule = new Module(); - assert(pModule); - module.reset(pModule); - // generates globals and functions - visitChildren(ctx); - // return the IR module - return pModule; -} -std::any -SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { - return visitChildren(ctx); -} -std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { - return visitChildren(ctx); -} -std::any SysYIRGenerator::visitString(SysYParser::StringContext *ctx) { - return visitChildren(ctx); +std::string SysYIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { + visitCompUnit(unit); + return irStream.str(); } -} // namespace sysy \ No newline at end of file +std::string SysYIRGenerator::getNextTemp() { + return "%" + std::to_string(tempCounter++); +} + +std::string SysYIRGenerator::getLLVMType(const std::string& type) { + if (type == "int") return "i32"; + if (type == "float") return "float"; + if (type.find("[]") != std::string::npos) + return getLLVMType(type.substr(0, type.size() - 2)) + "*"; + return "i32"; +} + +std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { + for (auto decl : ctx->decl()) { + decl->accept(this); + } + for (auto funcDef : ctx->funcDef()) { + funcDef->accept(this); + } + return nullptr; +} + +std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + // 常量声明暂不处理(LLVM IR 中常量通常内联) + return nullptr; +} + +std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { + for (auto varDef : ctx->varDef()) { + varDef->accept(this); + } + return nullptr; +} + +std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { + currentFunction = ctx->Ident()->getText(); + symbolTable.clear(); + + // 函数头 + std::string returnType = getLLVMType(ctx->funcType()->getText()); + irStream << "define " << returnType << " @" << currentFunction << "("; + + // 参数 + auto paramsCtx = ctx->funcFParams(); + if (paramsCtx) { + auto params = paramsCtx->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + if (i > 0) irStream << ", "; + auto param = params[i]; + std::string paramName = "%" + std::to_string(i); + std::string paramType = getLLVMType(param->bType()->getText()); + irStream << paramType << " " << paramName; + + // 分配参数 + std::string allocaName = getNextTemp(); + symbolTable[param->Ident()->getText()] = allocaName; + irStream << "\n " << allocaName << " = alloca " << paramType; + irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName; + } + } + irStream << ") {\nentry:\n"; + + // 函数体 + ctx->blockStmt()->accept(this); + + // 默认返回值 + if (returnType == "void") { + irStream << " ret void\n"; + } else { + irStream << " ret " << returnType << " 0\n"; + } + irStream << "}\n\n"; + return nullptr; +} + +std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + for (auto item : ctx->blockItem()) { + item->accept(this); + } + return nullptr; +} + +std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { + if (ctx->lValue() && ctx->exp()) { + // 赋值语句 + std::string lhs = std::any_cast(ctx->lValue()->accept(this)); + std::string rhs = std::any_cast(ctx->exp()->accept(this)); + irStream << " store " << getLLVMType("") << " " << rhs << ", " << getLLVMType("") << "* " << lhs << "\n"; + } else if (ctx->RETURN()) { + // 返回语句 + if (ctx->exp()) { + std::string value = std::any_cast(ctx->exp()->accept(this)); + irStream << " ret " << getLLVMType("") << " " << value << "\n"; + } else { + irStream << " ret void\n"; + } + } + return nullptr; +} + +std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { + std::string varName = ctx->Ident()->getText(); + if (symbolTable.find(varName) == symbolTable.end()) { + std::string allocaName = getNextTemp(); + symbolTable[varName] = allocaName; + irStream << " " << allocaName << " = alloca " << getLLVMType("") << "\n"; + } + return symbolTable[varName]; +} + +std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { + if (ctx->ILITERAL()) { + return "i32 " + ctx->ILITERAL()->getText(); + } else if (ctx->FLITERAL()) { + return "float " + ctx->FLITERAL()->getText(); + } + return ""; +} + +std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (ctx->unaryOp()) { + std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); + std::string op = ctx->unaryOp()->getText(); + std::string temp = getNextTemp(); + if (op == "-") { + irStream << " " << temp << " = sub " << getLLVMType("") << " 0, " << operand << "\n"; + } else if (op == "!") { + irStream << " " << temp << " = xor " << getLLVMType("") << " " << operand << ", 1\n"; + } + return temp; + } + return ctx->primaryExp()->accept(this); +} + +std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { + auto unaryExps = ctx->unaryExp(); + std::string left = std::any_cast(unaryExps[0]->accept(this)); + for (size_t i = 1; i < unaryExps.size(); ++i) { + std::string right = std::any_cast(unaryExps[i]->accept(this)); + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + if (op == "*") { + irStream << " " << temp << " = mul " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == "/") { + irStream << " " << temp << " = sdiv " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == "%") { + irStream << " " << temp << " = srem " << getLLVMType("") << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { + auto mulExps = ctx->mulExp(); + std::string left = std::any_cast(mulExps[0]->accept(this)); + for (size_t i = 1; i < mulExps.size(); ++i) { + std::string right = std::any_cast(mulExps[i]->accept(this)); + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + if (op == "+") { + irStream << " " << temp << " = add " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == "-") { + irStream << " " << temp << " = sub " << getLLVMType("") << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { + auto addExps = ctx->addExp(); + std::string left = std::any_cast(addExps[0]->accept(this)); + for (size_t i = 1; i < addExps.size(); ++i) { + std::string right = std::any_cast(addExps[i]->accept(this)); + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + if (op == "<") { + irStream << " " << temp << " = icmp slt " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == ">") { + irStream << " " << temp << " = icmp sgt " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == "<=") { + irStream << " " << temp << " = icmp sle " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == ">=") { + irStream << " " << temp << " = icmp sge " << getLLVMType("") << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { + auto relExps = ctx->relExp(); + std::string left = std::any_cast(relExps[0]->accept(this)); + for (size_t i = 1; i < relExps.size(); ++i) { + std::string right = std::any_cast(relExps[i]->accept(this)); + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + if (op == "==") { + irStream << " " << temp << " = icmp eq " << getLLVMType("") << " " << left << ", " << right << "\n"; + } else if (op == "!=") { + irStream << " " << temp << " = icmp ne " << getLLVMType("") << " " << left << ", " << right << "\n"; + } + left = temp; + } + return left; +} + +std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { + auto eqExps = ctx->eqExp(); + std::string left = std::any_cast(eqExps[0]->accept(this)); + for (size_t i = 1; i < eqExps.size(); ++i) { + std::string right = std::any_cast(eqExps[i]->accept(this)); + std::string temp = getNextTemp(); + irStream << " " << temp << " = and " << getLLVMType("") << " " << left << ", " << right << "\n"; + left = temp; + } + return left; +} + +std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { + auto lAndExps = ctx->lAndExp(); + std::string left = std::any_cast(lAndExps[0]->accept(this)); + for (size_t i = 1; i < lAndExps.size(); ++i) { + std::string right = std::any_cast(lAndExps[i]->accept(this)); + std::string temp = getNextTemp(); + irStream << " " << temp << " = or " << getLLVMType("") << " " << left << ", " << right << "\n"; + left = temp; + } + return left; +} \ No newline at end of file diff --git a/src/SysYIRGenerator.h b/src/SysYIRGenerator.h index bfc1865..c11d8ab 100644 --- a/src/SysYIRGenerator.h +++ b/src/SysYIRGenerator.h @@ -1,29 +1,41 @@ #pragma once - -#include "IR.h" -#include "IRBuilder.h" #include "SysYBaseVisitor.h" #include "SysYParser.h" -#include - -namespace sysy { +#include +#include +#include class SysYIRGenerator : public SysYBaseVisitor { +public: + std::string generateIR(SysYParser::CompUnitContext* unit); // 公共接口,用于生成 IR + std::string getIR() const { return irStream.str(); } // 获取生成的 IR + private: - std::unique_ptr module; - IRBuilder builder; + std::stringstream irStream; + int tempCounter = 0; + std::map symbolTable; // 符号表 + std::vector globalVars; // 全局变量 + std::string currentFunction; // 当前函数名 + std::vector breakStack; // break 目标标签栈 + std::vector continueStack; // continue 目标标签栈 -public: - SysYIRGenerator() = default; + std::string getNextTemp(); // 获取下一个临时变量名 + std::string getLLVMType(const std::string& type); // 获取 LLVM 类型 -public: - Module *get() const { return module.get(); } - -public: - std::any visitCompUnit(SysYParser::CompUnitContext *ctx) override; - std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; - std::any visitNumber(SysYParser::NumberContext *ctx) override; - std::any visitString(SysYParser::StringContext *ctx) override; -}; // class SysYIRGenerator - -} // namespace sysy \ No newline at end of file + // 访问方法 + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitStmt(SysYParser::StmtContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; +}; \ No newline at end of file diff --git a/src/sysyc.cpp b/src/sysyc.cpp index d18df3c..b67c12d 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -18,9 +18,9 @@ static bool argFormat = false; void usage(int code = EXIT_FAILURE) { const char *msg = "Usage: sysyc [options] inputfile\n\n" "Supported options:\n" - " -h \tprint help message and exit\n"; - " -f \tpretty-format the input file\n"; - " -s {ast,ir,asm}\tstop after generating AST/IR/Assembly\n"; + " -h \tprint help message and exit\n" + " -f \tpretty-format the input file\n" + " -s {ast,ir,asm}\tstop after generating AST/IR/Assembly\n"; cerr << msg; exit(code); } @@ -51,14 +51,14 @@ void parseArgs(int argc, char **argv) { int main(int argc, char **argv) { parseArgs(argc, argv); - // open the input file + // 打开输入文件 ifstream fin(argInputFile); if (not fin) { cerr << "Failed to open file " << argv[1]; return EXIT_FAILURE; } - // parse sysy source to AST + // 解析 SysY 源码为 AST ANTLRInputStream input(fin); SysYLexer lexer(&input); CommonTokenStream tokens(&lexer); @@ -69,28 +69,27 @@ int main(int argc, char **argv) { return EXIT_SUCCESS; } - // pretty format the input file + // 格式化输入文件 if (argFormat) { ASTPrinter printer; printer.visitCompUnit(moduleAST); return EXIT_SUCCESS; } - // visit AST to generate IR + // 遍历 AST 生成 IR SysYIRGenerator generator; - generator.visitCompUnit(moduleAST); - auto moduleIR = generator.get(); + generator.generateIR(moduleAST); // 使用公共接口生成 IR if (argStopAfter == "ir") { - moduleIR->print(cout); + cout << generator.getIR(); // 输出生成的 IR return EXIT_SUCCESS; } - // generate assembly - CodeGen codegen(moduleIR); - string asmCode = codegen.code_gen(); - cout << asmCode << endl; - if (argStopAfter == "asm") - return EXIT_SUCCESS; + // // 生成汇编代码 + // CodeGen codegen(generator.getIR()); // 假设 CodeGen 接受字符串作为输入 + // string asmCode = codegen.code_gen(); + // cout << asmCode << endl; + // if (argStopAfter == "asm") + // return EXIT_SUCCESS; return EXIT_SUCCESS; } \ No newline at end of file