From bb73ce3b5adbb7f6b51df2e0e0126908b47266ee Mon Sep 17 00:00:00 2001
From: lixuanwang
Date: Wed, 28 May 2025 23:49:02 +0800
Subject: [PATCH] merging branch lab2-IRGen into master

---
Review notes (text between "---" and the first diff is ignored by `git am`):
- getLLVMType maps `void`; float constants are rounded to single precision and
  encoded with memcpy; globals resolve inside function bodies; calls use the
  callee's signature; `&&` / `||` short-circuit; int/float/i1 operands are
  converted where they meet.
- NOTE(review): src/sysyc.cpp prints getIR() of a generator that never ran
  generateIR(), so `-s llvmir` prints nothing. The parse-tree variable is not
  visible in this patch's context, so that call still has to be added.
- NOTE(review): indentation of unchanged context lines is reconstructed and the
  blob ids on the `index` lines are carried over; use
  `git apply --ignore-whitespace` if the context no longer matches.

 .gitignore              |  12 +-
 src/LLVMIRGenerator.cpp | 546 ++++++++++++++++++++++++++++++++++++++++
 src/LLVMIRGenerator.h   |  83 ++++++
 src/sysyc.cpp           |   7 +-
 4 files changed, 646 insertions(+), 2 deletions(-)
 create mode 100644 src/LLVMIRGenerator.cpp
 create mode 100644 src/LLVMIRGenerator.h

diff --git a/.gitignore b/.gitignore
index 2e16250..76d16bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,4 +39,14 @@ doxygen
 !/testdata/performance/*.out
 build/
 .antlr
-.vscode/
\ No newline at end of file
+.vscode/
+
+tmp
+
+GPATH
+GRTAGS
+GTAGS
+
+__init__.py
+
+*.pyc
\ No newline at end of file
diff --git a/src/LLVMIRGenerator.cpp b/src/LLVMIRGenerator.cpp
new file mode 100644
index 0000000..b349207
--- /dev/null
+++ b/src/LLVMIRGenerator.cpp
@@ -0,0 +1,546 @@
+// LLVMIRGenerator.cpp
+// Emits textual LLVM IR for a SysY parse tree.
+// TODO: SysY runtime library functions (unknown callees are assumed to return i32)
+// TODO: arrays and array initializers
+// TODO: constant evaluation of non-literal global initializers
+// TODO: tests for while / continue / break
+#include "LLVMIRGenerator.h"
+
+#include <any>
+#include <cstddef>
+#include <cstdint>
+#include <cstring>
+#include <functional>
+#include <iomanip>
+#include <map>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <utility>
+
+namespace {
+
+// True when `value` is a literal rather than an SSA register (%x) or a global (@x).
+bool isLiteral(const std::string& value) {
+    return !value.empty() && value[0] != '%' && value[0] != '@';
+}
+
+// True when the literal text spells a floating-point number. A hexadecimal
+// literal is floating-point only if it has a '.' or a binary exponent.
+bool isFloatLiteral(const std::string& text) {
+    const bool isHex = text.find("0x") != std::string::npos || text.find("0X") != std::string::npos;
+    return text.find_first_of(isHex ? ".pP" : ".eE") != std::string::npos;
+}
+
+// Spells a literal as an LLVM `float` constant: the value is rounded to single
+// precision, widened back to double, and printed as the 16-digit hex image of
+// that double. Registers and globals are returned unchanged.
+std::string formatFloatConstant(const std::string& value) {
+    if (!isLiteral(value)) return value;
+    double widened = 0.0;
+    try {
+        widened = static_cast<double>(static_cast<float>(std::stod(value)));
+    } catch (const std::exception&) {
+        throw std::runtime_error("Invalid float literal: " + value);
+    }
+    static_assert(sizeof(std::uint64_t) == sizeof(double), "double must be 64 bits wide");
+    std::uint64_t bits = 0;
+    std::memcpy(&bits, &widened, sizeof bits);  // defined behaviour, unlike a reinterpret_cast
+    std::ostringstream out;
+    out << "0x" << std::hex << std::uppercase << std::setw(16) << std::setfill('0') << bits;
+    return out.str();
+}
+
+// Text to print for `value` when it is used as an operand of type `type`.
+std::string operandText(const std::string& value, const std::string& type) {
+    return type == "float" ? formatFloatConstant(value) : value;
+}
+
+}  // namespace
+
+std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) {
+    visitCompUnit(unit);
+    return irStream.str();
+}
+
+std::string LLVMIRGenerator::getNextTemp() {
+    std::string ret = "%." + std::to_string(tempCounter++);
+    tmpTable[ret] = "void";  // placeholder until the emitting code records the real type
+    return ret;
+}
+
+std::string LLVMIRGenerator::getLLVMType(const std::string& type) {
+    if (type == "int") return "i32";
+    if (type == "float") return "float";
+    if (type == "void") return "void";
+    if (type.size() >= 2 && type.compare(type.size() - 2, 2, "[]") == 0)
+        return getLLVMType(type.substr(0, type.size() - 2)) + "*";
+    return "i32";
+}
+
+std::string LLVMIRGenerator::typeOf(const std::string& value) {
+    if (isLiteral(value)) return isFloatLiteral(value) ? "float" : "i32";
+    const auto it = tmpTable.find(value);
+    return (it != tmpTable.end() && !it->second.empty()) ? it->second : "i32";
+}
+
+std::string LLVMIRGenerator::coerce(const std::string& value, const std::string& target) {
+    const std::string source = typeOf(value);
+    if (value.empty() || source == target) return value;
+    if (isLiteral(value) && target != "i1") {
+        // Literals are converted here instead of emitting a cast instruction.
+        if (target == "float") return value + ".0";
+        if (target == "i32") return std::to_string(static_cast<long long>(std::stod(value)));
+        return value;
+    }
+    std::string instruction;
+    if (target == "i1" && source == "float") {
+        instruction = "fcmp une float " + formatFloatConstant(value) + ", 0x0000000000000000";
+    } else if (target == "i1" && source == "i32") {
+        instruction = "icmp ne i32 " + value + ", 0";
+    } else if (source == "i1" && target == "i32") {
+        instruction = "zext i1 " + value + " to i32";
+    } else if (source == "i1" && target == "float") {
+        instruction = "uitofp i1 " + value + " to float";
+    } else if (source == "i32" && target == "float") {
+        instruction = "sitofp i32 " + value + " to float";
+    } else if (source == "float" && target == "i32") {
+        instruction = "fptosi float " + value + " to i32";
+    } else {
+        return value;  // no conversion known for this pair (e.g. pointer types)
+    }
+    const std::string temp = getNextTemp();
+    irStream << "  " << temp << " = " << instruction << "\n";
+    tmpTable[temp] = target;
+    return temp;
+}
+
+std::pair<std::string, std::string> LLVMIRGenerator::resolveVariable(const std::string& name) {
+    const auto local = symbolTable.find(name);
+    if (local != symbolTable.end()) return local->second;
+    const auto global = globalVars.find(name);
+    if (global != globalVars.end()) return {"@" + name, global->second};
+    throw std::runtime_error("Use of undeclared variable: " + name);
+}
+
+void LLVMIRGenerator::emitGlobal(const std::string& name, const std::string& llvmType,
+                                 const std::string& value, bool isConstant) {
+    if (!isLiteral(value)) {
+        throw std::runtime_error("Initializer of global '" + name + "' is not a literal constant.");
+    }
+    irStream << "@" << name << " = dso_local " << (isConstant ? "constant" : "global") << " "
+             << llvmType << " " << operandText(coerce(value, llvmType), llvmType) << ", align 4\n";
+    globalVars[name] = llvmType;
+}
+
+void LLVMIRGenerator::declareLocal(const std::string& name, const std::string& llvmType,
+                                   const std::string* initValue) {
+    const std::string slot = getNextTemp();
+    irStream << "  " << slot << " = alloca " << llvmType << ", align 4\n";
+    if (initValue != nullptr) {
+        const std::string value = coerce(*initValue, llvmType);
+        irStream << "  store " << llvmType << " " << operandText(value, llvmType) << ", "
+                 << llvmType << "* " << slot << ", align 4\n";
+    }
+    symbolTable[name] = {slot, llvmType};
+    tmpTable[slot] = llvmType;
+}
+
+std::string LLVMIRGenerator::emitBinary(const std::string& op, const std::string& lhs,
+                                        const std::string& rhs) {
+    struct Opcode {
+        const char* integer;
+        const char* floating;
+        bool isComparison;
+    };
+    static const std::map<std::string, Opcode> kOpcodes = {
+        {"*", {"mul nsw", "fmul", false}},
+        {"/", {"sdiv", "fdiv", false}},
+        {"%", {"srem", "frem", false}},
+        {"+", {"add nsw", "fadd", false}},
+        {"-", {"sub nsw", "fsub", false}},
+        {"<", {"icmp slt", "fcmp olt", true}},
+        {">", {"icmp sgt", "fcmp ogt", true}},
+        {"<=", {"icmp sle", "fcmp ole", true}},
+        {">=", {"icmp sge", "fcmp oge", true}},
+        {"==", {"icmp eq", "fcmp oeq", true}},
+        {"!=", {"icmp ne", "fcmp une", true}},
+    };
+    const auto opcode = kOpcodes.find(op);
+    if (opcode == kOpcodes.end()) throw std::runtime_error("Unsupported binary operator: " + op);
+
+    // Mixed int/float operands are computed in float; i1 results widen to i32.
+    const std::string type = (typeOf(lhs) == "float" || typeOf(rhs) == "float") ? "float" : "i32";
+    const std::string left = coerce(lhs, type);
+    const std::string right = coerce(rhs, type);
+    const std::string temp = getNextTemp();
+    irStream << "  " << temp << " = " << (type == "float" ? opcode->second.floating : opcode->second.integer)
+             << " " << type << " " << operandText(left, type) << ", " << operandText(right, type) << "\n";
+    tmpTable[temp] = opcode->second.isComparison ? "i1" : type;
+    return temp;
+}
+
+std::string LLVMIRGenerator::emitShortCircuit(const std::string& prefix, const std::string& lhs,
+                                              bool shortValue,
+                                              const std::function<std::string()>& evaluateRhs) {
+    const std::string id = std::to_string(tempCounter++);
+    const std::string shortLabel = prefix + ".short." + id;
+    const std::string rhsLabel = prefix + ".rhs." + id;
+    const std::string tailLabel = prefix + ".tail." + id;
+    const std::string endLabel = prefix + ".end." + id;
+
+    // The right operand is evaluated only when the left one does not already
+    // decide the result (left == shortValue skips it).
+    const std::string left = coerce(lhs, "i1");
+    irStream << "  br i1 " << left << ", label %" << (shortValue ? shortLabel : rhsLabel)
+             << ", label %" << (shortValue ? rhsLabel : shortLabel) << "\n";
+    irStream << shortLabel << ":\n";
+    irStream << "  br label %" << endLabel << "\n";
+    irStream << rhsLabel << ":\n";
+    const std::string right = coerce(evaluateRhs(), "i1");
+    // The rhs may have opened blocks of its own; funnel it through a block with
+    // a known label so the phi below can name its predecessor.
+    irStream << "  br label %" << tailLabel << "\n";
+    irStream << tailLabel << ":\n";
+    irStream << "  br label %" << endLabel << "\n";
+    irStream << endLabel << ":\n";
+    const std::string result = getNextTemp();
+    irStream << "  " << result << " = phi i1 [ " << (shortValue ? "true" : "false") << ", %"
+             << shortLabel << " ], [ " << right << ", %" << tailLabel << " ]\n";
+    tmpTable[result] = "i1";
+    return result;
+}
+
+std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
+    for (auto decl : ctx->decl()) {
+        decl->accept(this);
+    }
+    for (auto funcDef : ctx->funcDef()) {
+        inFunction = true;  // declarations visited from here on are locals
+        funcDef->accept(this);
+        inFunction = false;
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
+    // TODO: array initialization
+    currentVarType = getLLVMType(ctx->bType()->getText());
+
+    for (auto varDef : ctx->varDef()) {
+        if (inFunction) {
+            varDef->accept(this);  // locals are emitted by visitVarDef
+            continue;
+        }
+        std::string value = "0";  // a global without an initializer is zero
+        if (varDef->ASSIGN()) {
+            value = std::any_cast<std::string>(varDef->initVal()->accept(this));
+        }
+        emitGlobal(varDef->Ident()->getText(), currentVarType, value, /*isConstant=*/false);
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
+    // TODO: array initialization
+    const std::string llvmType = getLLVMType(ctx->bType()->getText());
+    for (auto constDef : ctx->constDef()) {
+        const std::string name = constDef->Ident()->getText();
+        if (constDef->constInitVal() == nullptr) {
+            throw std::runtime_error("Const value must be initialized upon definition.");
+        }
+        std::string value;
+        try {
+            value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
+        } catch (const std::bad_any_cast&) {
+            throw std::runtime_error("Unsupported initializer for const '" + name + "'.");
+        }
+        if (inFunction) {
+            declareLocal(name, llvmType, &value);
+        } else {
+            emitGlobal(name, llvmType, value, /*isConstant=*/true);
+        }
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) {
+    // TODO: array initialization
+    const std::string name = ctx->Ident()->getText();
+    const std::string llvmType = currentVarType;
+    if (ctx->ASSIGN()) {
+        const std::string value = std::any_cast<std::string>(ctx->initVal()->accept(this));
+        declareLocal(name, llvmType, &value);
+    } else {
+        declareLocal(name, llvmType, nullptr);
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
+    currentFunction = ctx->Ident()->getText();
+    currentReturnType = getLLVMType(ctx->funcType()->getText());
+    symbolTable.clear();
+    tmpTable.clear();
+    tempCounter = 0;
+    hasReturn = false;
+
+    FunctionInfo info;
+    info.returnType = currentReturnType;
+    irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "(";
+    if (ctx->funcFParams()) {
+        const auto params = ctx->funcFParams()->funcFParam();
+        tempCounter += params.size();
+        for (size_t i = 0; i < params.size(); ++i) {
+            const std::string paramType = getLLVMType(params[i]->bType()->getText());
+            const std::string paramReg = "%" + std::to_string(i);
+            if (i > 0) irStream << ", ";
+            irStream << paramType << " noundef " << paramReg;
+            symbolTable[params[i]->Ident()->getText()] = {paramReg, paramType};
+            tmpTable[paramReg] = paramType;
+            info.paramTypes.push_back(paramType);
+        }
+    }
+    tempCounter++;  // presumably skips the number LLVM gives the unnamed entry block
+    functionTable[currentFunction] = info;  // registered before the body so recursion resolves
+    irStream << ") #0 {\n";
+
+    // Spill every parameter to a stack slot so it can be assigned like a local.
+    if (ctx->funcFParams()) {
+        for (auto param : ctx->funcFParams()->funcFParam()) {
+            const std::string name = param->Ident()->getText();
+            const std::pair<std::string, std::string> incoming = symbolTable[name];
+            declareLocal(name, incoming.second, &incoming.first);
+        }
+    }
+    ctx->blockStmt()->accept(this);
+
+    // Falling off the end still needs a terminator.
+    if (!hasReturn) {
+        if (currentReturnType == "void") {
+            irStream << "  ret void\n";
+        } else {
+            irStream << "  ret " << currentReturnType << " " << operandText("0", currentReturnType) << "\n";
+        }
+    }
+    irStream << "}\n";
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
+    for (auto item : ctx->blockItem()) {
+        item->accept(this);
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
+    hasReturn = false;  // only a `return` visited last leaves the block terminated
+    if (ctx->lValue() && ctx->ASSIGN()) {
+        const std::string target = std::any_cast<std::string>(ctx->lValue()->accept(this));
+        const std::string type = resolveVariable(ctx->lValue()->Ident()->getText()).second;
+        const std::string value = coerce(std::any_cast<std::string>(ctx->exp()->accept(this)), type);
+        irStream << "  store " << type << " " << operandText(value, type) << ", " << type << "* "
+                 << target << ", align 4\n";
+    } else if (ctx->RETURN()) {
+        std::string value;
+        if (ctx->exp()) {
+            value = coerce(std::any_cast<std::string>(ctx->exp()->accept(this)), currentReturnType);
+        }
+        if (value.empty() || currentReturnType == "void") {
+            irStream << "  ret void\n";
+        } else {
+            irStream << "  ret " << currentReturnType << " " << operandText(value, currentReturnType) << "\n";
+        }
+        hasReturn = true;
+    } else if (ctx->IF()) {
+        const std::string cond = coerce(std::any_cast<std::string>(ctx->cond()->accept(this)), "i1");
+        const std::string id = std::to_string(tempCounter++);
+        const std::string thenLabel = "if.then." + id;
+        const std::string elseLabel = "if.else." + id;
+        const std::string endLabel = "if.end." + id;
+
+        irStream << "  br i1 " << cond << ", label %" << thenLabel << ", label %" << elseLabel << "\n";
+        irStream << thenLabel << ":\n";
+        ctx->stmt(0)->accept(this);
+        irStream << "  br label %" << endLabel << "\n";
+        irStream << elseLabel << ":\n";
+        if (ctx->ELSE()) {
+            ctx->stmt(1)->accept(this);
+        }
+        irStream << "  br label %" << endLabel << "\n";
+        irStream << endLabel << ":\n";
+        hasReturn = false;  // the merge block is still open
+    } else if (ctx->WHILE()) {
+        const std::string id = std::to_string(tempCounter++);
+        const std::string condLabel = "while.cond." + id;
+        const std::string bodyLabel = "while.body." + id;
+        const std::string endLabel = "while.end." + id;
+
+        loopStack.push({endLabel, condLabel});
+        irStream << "  br label %" << condLabel << "\n";
+        irStream << condLabel << ":\n";
+        const std::string cond = coerce(std::any_cast<std::string>(ctx->cond()->accept(this)), "i1");
+        irStream << "  br i1 " << cond << ", label %" << bodyLabel << ", label %" << endLabel << "\n";
+        irStream << bodyLabel << ":\n";
+        ctx->stmt(0)->accept(this);
+        irStream << "  br label %" << condLabel << "\n";
+        irStream << endLabel << ":\n";
+        loopStack.pop();
+        hasReturn = false;  // the exit block is still open
+    } else if (ctx->BREAK()) {
+        if (loopStack.empty()) {
+            throw std::runtime_error("Break statement outside of a loop.");
+        }
+        irStream << "  br label %" << loopStack.top().breakLabel << "\n";
+    } else if (ctx->CONTINUE()) {
+        if (loopStack.empty()) {
+            throw std::runtime_error("Continue statement outside of a loop.");
+        }
+        irStream << "  br label %" << loopStack.top().continueLabel << "\n";
+    } else if (ctx->blockStmt()) {
+        ctx->blockStmt()->accept(this);
+    } else if (ctx->exp()) {
+        ctx->exp()->accept(this);
+    }
+    return nullptr;
+}
+
+std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
+    return resolveVariable(ctx->Ident()->getText()).first;
+}
+
+std::any LLVMIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
+    if (ctx->lValue()) {
+        const std::string pointer = std::any_cast<std::string>(ctx->lValue()->accept(this));
+        const std::string type = resolveVariable(ctx->lValue()->Ident()->getText()).second;
+        const std::string temp = getNextTemp();
+        irStream << "  " << temp << " = load " << type << ", " << type << "* " << pointer << ", align 4\n";
+        tmpTable[temp] = type;
+        return temp;
+    }
+    if (ctx->exp()) return ctx->exp()->accept(this);
+    return ctx->number()->accept(this);
+}
+
+std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
+    if (ctx->ILITERAL()) {
+        // Base 0 accepts decimal, octal (leading 0) and hex (0x) spellings; LLVM IR
+        // only takes decimal integers, so the literal is re-spelled in decimal.
+        return std::to_string(std::stoll(ctx->ILITERAL()->getText(), nullptr, 0));
+    }
+    if (ctx->FLITERAL()) return ctx->FLITERAL()->getText();
+    return std::string();  // std::string, not a C string, so callers' any_cast succeeds
+}
+
+std::any LLVMIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
+    if (ctx->unaryOp()) {
+        const std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
+        const std::string op = ctx->unaryOp()->getText();
+        if (op == "!") return emitBinary("==", operand, "0");  // !x is (x == 0)
+        if (op != "-") return operand;  // unary plus
+        if (isLiteral(operand)) {
+            // Fold the sign into the literal so constant initializers stay constant.
+            return operand[0] == '-' ? operand.substr(1) : "-" + operand;
+        }
+        if (typeOf(operand) != "float") return emitBinary("-", "0", operand);
+        const std::string temp = getNextTemp();
+        irStream << "  " << temp << " = fneg float " << operand << "\n";
+        tmpTable[temp] = "float";
+        return temp;
+    }
+    if (ctx->Ident()) {
+        const std::string callee = ctx->Ident()->getText();
+        const auto known = functionTable.find(callee);
+        const bool isKnown = known != functionTable.end();
+        // NOTE(review): callees not defined earlier in this unit (e.g. runtime
+        // library functions) are assumed to return i32 until they are declared.
+        const std::string returnType = isKnown ? known->second.returnType : "i32";
+
+        std::string argList;
+        if (ctx->funcRParams()) {
+            const auto argExps = ctx->funcRParams()->exp();
+            for (size_t i = 0; i < argExps.size(); ++i) {
+                const std::string raw = std::any_cast<std::string>(argExps[i]->accept(this));
+                std::string type = typeOf(raw) == "i1" ? "i32" : typeOf(raw);
+                if (isKnown && i < known->second.paramTypes.size()) type = known->second.paramTypes[i];
+                const std::string arg = coerce(raw, type);
+                if (i > 0) argList += ", ";
+                argList += type + " noundef " + operandText(arg, type);
+            }
+        }
+        if (returnType == "void") {
+            irStream << "  call void @" << callee << "(" << argList << ")\n";
+            return std::string();
+        }
+        const std::string temp = getNextTemp();
+        irStream << "  " << temp << " = call " << returnType << " @" << callee << "(" << argList << ")\n";
+        tmpTable[temp] = returnType;
+        return temp;
+    }
+    return ctx->primaryExp()->accept(this);
+}
+
+std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
+    const auto operands = ctx->unaryExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        const std::string rhs = std::any_cast<std::string>(operands[i]->accept(this));
+        result = emitBinary(ctx->children[2 * i - 1]->getText(), result, rhs);
+    }
+    return result;
+}
+
+std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
+    const auto operands = ctx->mulExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        const std::string rhs = std::any_cast<std::string>(operands[i]->accept(this));
+        result = emitBinary(ctx->children[2 * i - 1]->getText(), result, rhs);
+    }
+    return result;
+}
+
+std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
+    const auto operands = ctx->addExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        const std::string rhs = std::any_cast<std::string>(operands[i]->accept(this));
+        result = emitBinary(ctx->children[2 * i - 1]->getText(), result, rhs);
+    }
+    return result;
+}
+
+std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
+    const auto operands = ctx->relExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        const std::string rhs = std::any_cast<std::string>(operands[i]->accept(this));
+        result = emitBinary(ctx->children[2 * i - 1]->getText(), result, rhs);
+    }
+    return result;
+}
+
+std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
+    const auto operands = ctx->eqExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        auto* rhs = operands[i];
+        result = emitShortCircuit("land", result, /*shortValue=*/false, [this, rhs] {
+            return std::any_cast<std::string>(rhs->accept(this));
+        });
+    }
+    return result;
+}
+
+std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
+    const auto operands = ctx->lAndExp();
+    std::string result = std::any_cast<std::string>(operands[0]->accept(this));
+    for (size_t i = 1; i < operands.size(); ++i) {
+        auto* rhs = operands[i];
+        result = emitShortCircuit("lor", result, /*shortValue=*/true, [this, rhs] {
+            return std::any_cast<std::string>(rhs->accept(this));
+        });
+    }
+    return result;
+}
diff --git a/src/LLVMIRGenerator.h b/src/LLVMIRGenerator.h
new file mode 100644
index 0000000..20650ab
--- /dev/null
+++ b/src/LLVMIRGenerator.h
@@ -0,0 +1,83 @@
+#pragma once
+#include "SysYBaseVisitor.h"
+#include "SysYParser.h"
+#include <any>
+#include <functional>
+#include <map>
+#include <sstream>
+#include <stack>
+#include <string>
+#include <utility>
+#include <vector>
+
+// Walks a SysY parse tree and emits textual LLVM IR.
+class LLVMIRGenerator : public SysYBaseVisitor {
+public:
+    // Visits `unit` and returns all IR emitted so far.
+    std::string generateIR(SysYParser::CompUnitContext* unit);
+    // Returns the IR emitted so far; empty until generateIR() has run.
+    std::string getIR() const { return irStream.str(); }
+
+private:
+    struct LoopLabels {
+        std::string breakLabel;     // branch target of `break`
+        std::string continueLabel;  // branch target of `continue`
+    };
+    struct FunctionInfo {
+        std::string returnType;               // LLVM return type
+        std::vector<std::string> paramTypes;  // LLVM parameter types, in order
+    };
+
+    std::stringstream irStream;
+    int tempCounter = 0;
+    std::string currentVarType;  // LLVM type of the declaration being visited
+    std::map<std::string, std::pair<std::string, std::string>> symbolTable;  // local -> {slot, type}
+    std::map<std::string, std::string> tmpTable;        // register -> LLVM type
+    std::map<std::string, std::string> globalVars;      // global -> LLVM type
+    std::map<std::string, FunctionInfo> functionTable;  // functions defined so far
+    std::string currentFunction;
+    std::string currentReturnType;
+    bool hasReturn = false;   // true when the statement visited last was a `return`
+    bool inFunction = false;  // true while a function definition is being visited
+    std::stack<LoopLabels> loopStack;  // labels of the enclosing loops, innermost on top
+
+    std::string getNextTemp();
+    std::string getLLVMType(const std::string& type);
+
+    // Returns the LLVM type of a literal or of a register recorded in tmpTable (i32 if unknown).
+    std::string typeOf(const std::string& value);
+    // Converts `value` to `target`, emitting a cast when needed; returns the converted operand.
+    std::string coerce(const std::string& value, const std::string& target);
+    // Returns {pointer, type} of a local, else of a global; throws std::runtime_error if undeclared.
+    std::pair<std::string, std::string> resolveVariable(const std::string& name);
+    // Emits a global definition with a literal initializer and records its type.
+    void emitGlobal(const std::string& name, const std::string& llvmType, const std::string& value,
+                    bool isConstant);
+    // Emits an alloca for `name`, plus a store when `initValue` is non-null, and records the slot.
+    void declareLocal(const std::string& name, const std::string& llvmType,
+                      const std::string* initValue);
+    // Emits `lhs op rhs` after bringing both operands to a common type; returns the result register.
+    std::string emitBinary(const std::string& op, const std::string& lhs, const std::string& rhs);
+    // Emits a short-circuit && (shortValue=false) or || (shortValue=true); returns the i1 result.
+    std::string emitShortCircuit(const std::string& prefix, const std::string& lhs, bool shortValue,
+                                 const std::function<std::string()>& evaluateRhs);
+
+    // Visitor methods
+    std::any visitCompUnit(SysYParser::CompUnitContext* ctx);
+    std::any visitConstDecl(SysYParser::ConstDeclContext* ctx);
+    std::any visitVarDecl(SysYParser::VarDeclContext* ctx);
+    std::any visitVarDef(SysYParser::VarDefContext* ctx);
+    std::any visitFuncDef(SysYParser::FuncDefContext* ctx);
+    std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx);
+    std::any visitStmt(SysYParser::StmtContext* ctx);
+    std::any visitLValue(SysYParser::LValueContext* ctx);
+    std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx);
+    std::any visitNumber(SysYParser::NumberContext* ctx);
+    std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx);
+    std::any visitMulExp(SysYParser::MulExpContext* ctx);
+    std::any visitAddExp(SysYParser::AddExpContext* ctx);
+    std::any visitRelExp(SysYParser::RelExpContext* ctx);
+    std::any visitEqExp(SysYParser::EqExpContext* ctx);
+    std::any visitLAndExp(SysYParser::LAndExpContext* ctx);
+    std::any visitLOrExp(SysYParser::LOrExpContext* ctx);
+};
diff --git a/src/sysyc.cpp b/src/sysyc.cpp
index d18df3c..6ec6eb7 100644
--- a/src/sysyc.cpp
+++ b/src/sysyc.cpp
@@ -9,6 +9,7 @@ using namespace antlr4;
 #include "ASTPrinter.h"
 #include "Backend.h"
 #include "SysYIRGenerator.h"
+#include "LLVMIRGenerator.h"
 using namespace sysy;
 
 static string argStopAfter;
@@ -20,7 +21,7 @@ void usage(int code = EXIT_FAILURE) {
                     "Supported options:\n"
                     " -h \tprint help message and exit\n";
                     " -f \tpretty-format the input file\n";
-                    " -s {ast,ir,asm}\tstop after generating AST/IR/Assembly\n";
+                    " -s {ast,ir,asm,llvmir}\tstop after generating AST/IR/Assembly\n";
   cerr << msg;
   exit(code);
 }
@@ -83,6 +84,10 @@ int main(int argc, char **argv) {
   if (argStopAfter == "ir") {
     moduleIR->print(cout);
     return EXIT_SUCCESS;
+  } else if (argStopAfter == "llvmir") {
+    LLVMIRGenerator llvmirGenerator;
+    cout << llvmirGenerator.getIR();
+    return EXIT_SUCCESS;
   }
 
   // generate assembly