From c47d522e3a01fb3e13307aad9443011ad09d76ad Mon Sep 17 00:00:00 2001 From: Lixuanwang Date: Mon, 9 Jun 2025 19:29:59 +0800 Subject: [PATCH] [IR Gen] debugging expreimental IR generator --- src/LLVMIRGenerator_1.cpp | 855 ++++++++++++++++++++++++++++++++++++++ src/LLVMIRGenerator_1.h | 99 +++++ 2 files changed, 954 insertions(+) create mode 100644 src/LLVMIRGenerator_1.cpp create mode 100644 src/LLVMIRGenerator_1.h diff --git a/src/LLVMIRGenerator_1.cpp b/src/LLVMIRGenerator_1.cpp new file mode 100644 index 0000000..c7da24b --- /dev/null +++ b/src/LLVMIRGenerator_1.cpp @@ -0,0 +1,855 @@ +// LLVMIRGenerator.cpp +// TODO:类型转换及其检查 +// TODO:sysy库函数处理 +// TODO:数组处理 +// TODO:对while、continue、break的测试 +#include "LLVMIRGenerator_1.h" +#include +#include +#include + +namespace sysy { + +std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { + // 初始化 SysY IR 模块 + module = std::make_unique(); + // 清空符号表和临时变量表 + symbolTable.clear(); + tmpTable.clear(); + irSymbolTable.clear(); + irTmpTable.clear(); + tempCounter = 0; + globalVars.clear(); + hasReturn = false; + loopStack = std::stack(); + inFunction = false; + + // 访问编译单元 + visitCompUnit(unit); + return irStream.str(); +} + +std::string LLVMIRGenerator::getNextTemp() { + std::string ret = "%." + std::to_string(tempCounter++); + tmpTable[ret] = "void"; + return ret; +} + +std::string LLVMIRGenerator::getIRTempName() { + return "%" + std::to_string(tempCounter++); +} + +std::string LLVMIRGenerator::getLLVMType(const std::string& type) { + if (type == "int") return "i32"; + if (type == "float") return "float"; + if (type.find("[]") != std::string::npos) + return getLLVMType(type.substr(0, type.size() - 2)) + "*"; + return "i32"; +} + +sysy::Type* LLVMIRGenerator::getIRType(const std::string& type) { + if (type == "int") return sysy::Type::getIntType(); + if (type == "float") return sysy::Type::getFloatType(); + if (type == "void") return sysy::Type::getVoidType(); + if (type.find("[]") != std::string::npos) { + std::string baseType = type.substr(0, type.size() - 2); + return sysy::Type::getPointerType(getIRType(baseType)); + } + return sysy::Type::getIntType(); // 默认 int +} + +void LLVMIRGenerator::setIRPosition(sysy::BasicBlock* block) { + currentIRBlock = block; +} + +std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) { + for (auto decl : ctx->decl()) { + decl->accept(this); + } + for (auto funcDef : ctx->funcDef()) { + inFunction = true; + funcDef->accept(this); + inFunction = false; + } + return nullptr; +} + + +std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + currentVarType = getLLVMType(type); + sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); + + for (auto varDef : ctx->varDef()) { + if (!inFunction) { + // 全局变量(文本 IR) + std::string varName = varDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; + sysy::Value* initValue = nullptr; + + if (varDef->ASSIGN()) { + value = std::any_cast(varDef->initVal()->accept(this)); + if (irTmpTable.find(value) != irTmpTable.end() && isa(irTmpTable[value])) { + initValue = irTmpTable[value]; + } + } + + if (llvmType == "float" && initValue) { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("[ERR-Release-02]Invalid float literal: " + value); + } + } + irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); + + // 全局变量(SysY IR) + auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); + irSymbolTable[varName] = globalValue; + } else { + varDef->accept(this); + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { + // TODO:数组初始化 + std::string type = ctx->bType()->getText(); + currentVarType = getLLVMType(type); + sysy::Type* irType = sysy::Type::getPointerType(getIRType(type)); // 全局变量为指针类型 + + for (auto constDef : ctx->constDef()) { + std::string varName = constDef->Ident()->getText(); + std::string llvmType = getLLVMType(type); + std::string value = "0"; + sysy::Value* initValue = nullptr; + + try { + value = std::any_cast(constDef->constInitVal()->accept(this)); + if (isa(irTmpTable[value])) { + initValue = irTmpTable[value]; + } + } catch (...) { + throw std::runtime_error("Const value must be initialized upon definition."); + } + + if (!inFunction) { + // 全局常量(文本 IR) + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("[ERR-Release-03]Invalid float literal: " + value); + } + } + irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n"; + globalVars.push_back(varName); + + // 全局常量(SysY IR) + auto globalValue = module->createGlobalValue(varName, irType, {}, initValue); + irSymbolTable[varName] = globalValue; + } else { + // 局部常量(文本 IR) + std::string allocaName = getNextTemp(); + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + + // 局部常量(SysY IR) + sysy::IRBuilder builder(currentIRBlock); + auto allocaInst = builder.createAllocaInst(irType, {}, varName); + builder.createStoreInst(initValue, allocaInst); + irSymbolTable[varName] = allocaInst; + irTmpTable[allocaName] = allocaInst; + } + } + return nullptr; +} + +std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { + // TODO:数组初始化 + std::string varName = ctx->Ident()->getText(); + std::string llvmType = currentVarType; + sysy::Type* irType = sysy::Type::getPointerType(getIRType(currentVarType == "i32" ? "int" : "float")); + std::string allocaName = getNextTemp(); + + // 局部变量(文本 IR) + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + + // 局部变量(SysY IR) + sysy::IRBuilder builder(currentIRBlock); + auto allocaInst = builder.createAllocaInst(irType, {}, varName); + sysy::Value* initValue = nullptr; + + if (ctx->ASSIGN()) { + std::string value = std::any_cast(ctx->initVal()->accept(this)); + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } + irStream << " store " << llvmType << " " << value << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + if (irTmpTable.find(value) != irTmpTable.end()) { + initValue = irTmpTable[value]; + } + builder.createStoreInst(initValue, allocaInst); + } + + symbolTable[varName] = {allocaName, llvmType}; + tmpTable[allocaName] = llvmType; + irSymbolTable[varName] = allocaInst; + irTmpTable[allocaName] = allocaInst; + builder.createStoreInst(initValue, allocaInst); + return nullptr; +} + +std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { + currentFunction = ctx->Ident()->getText(); + currentReturnType = getLLVMType(ctx->funcType()->getText()); + sysy::Type* irReturnType = getIRType(ctx->funcType()->getText()); + std::vector paramTypes; + + // 清空符号表 + symbolTable.clear(); + tmpTable.clear(); + irSymbolTable.clear(); + irTmpTable.clear(); + tempCounter = 0; + hasReturn = false; + + // 处理函数参数(文本 IR 和 SysY IR) + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string paramType = getLLVMType(params[i]->bType()->getText()); + if (i > 0) irStream << ", "; + irStream << paramType << " noundef %" << i; + symbolTable[params[i]->Ident()->getText()] = {"%" + std::to_string(i), paramType}; + tmpTable["%" + std::to_string(i)] = paramType; + paramTypes.push_back(getIRType(params[i]->bType()->getText())); + } + tempCounter += params.size(); + } + tempCounter++; + + // 文本 IR 函数定义 + irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; + irStream << ") #0 {\n"; + + // SysY IR 函数定义 + sysy::Type* funcType = sysy::Type::getFunctionType(irReturnType, paramTypes); + currentIRFunction = module->createFunction(currentFunction, funcType); + setIRPosition(currentIRFunction->getEntryBlock()); + + // 处理函数参数分配 + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string varName = params[i]->Ident()->getText(); + std::string llvmType = getLLVMType(params[i]->bType()->getText()); + sysy::Type* irType = getIRType(params[i]->bType()->getText()); + std::string allocaName = getNextTemp(); + tmpTable[allocaName] = llvmType; + + // 文本 IR 分配 + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " %" << i << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + // SysY IR 分配 + sysy::IRBuilder builder(currentIRBlock); + auto arg = currentIRBlock->createArgument(irType, varName); + auto allocaInst = builder.createAllocaInst(sysy::Type::getPointerType(irType), {}, varName); + builder.createStoreInst(arg, allocaInst); + symbolTable[varName] = {allocaName, llvmType}; + irSymbolTable[varName] = allocaInst; + irTmpTable[allocaName] = allocaInst; + } + } + + ctx->blockStmt()->accept(this); + + if (!hasReturn) { + if (currentReturnType == "void") { + irStream << " ret void\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createReturnInst(); + } else { + irStream << " ret " << currentReturnType << " 0\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createReturnInst(sysy::ConstantValue::get(0)); + } + } + irStream << "}\n"; + currentIRFunction = nullptr; + currentIRBlock = nullptr; + return nullptr; +} + +std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { + for (auto item : ctx->blockItem()) { + item->accept(this); + } + return nullptr; +} + +std::any LLVMIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext* ctx) { + std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); + std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; + std::string rhs = std::any_cast(ctx->exp()->accept(this)); + sysy::Value* rhsValue = irTmpTable[rhs]; + + // 文本 IR + if (lhsType == "float") { + try { + double floatValue = std::stod(rhs); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + rhs = ss.str(); + } catch (...) { + // 如果 rhs 不是字面量,假设已正确处理 + } + } + irStream << " store " << lhsType << " " << rhs << ", " << lhsType + << "* " << lhsAlloca << ", align 4\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createStoreInst(rhsValue, irSymbolTable[ctx->lValue()->Ident()->getText()]); + return nullptr; +} + +std::any LLVMIRGenerator::visitIfStmt(SysYParser::IfStmtContext* ctx) { + std::string cond = std::any_cast(ctx->cond()->accept(this)); + sysy::Value* condValue = irTmpTable[cond]; + std::string trueLabel = "if.then." + std::to_string(tempCounter); + std::string falseLabel = "if.else." + std::to_string(tempCounter); + std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + + // SysY IR 基本块 + sysy::BasicBlock* thenBlock = currentIRFunction->addBasicBlock(trueLabel); + sysy::BasicBlock* elseBlock = ctx->ELSE() ? currentIRFunction->addBasicBlock(falseLabel) : nullptr; + sysy::BasicBlock* mergeBlock = currentIRFunction->addBasicBlock(mergeLabel); + + // 文本 IR + irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" + << (ctx->ELSE() ? falseLabel : mergeLabel) << "\n"; + + // SysY IR 条件分支 + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(condValue, thenBlock, ctx->ELSE() ? elseBlock : mergeBlock, {}, {}); + + // 处理 then 分支 + setIRPosition(thenBlock); + irStream << trueLabel << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + builder.setPosition(thenBlock, thenBlock->end()); + builder.createUncondBrInst(mergeBlock, {}); + + // 处理 else 分支 + if (ctx->ELSE()) { + setIRPosition(elseBlock); + irStream << falseLabel << ":\n"; + ctx->stmt(1)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + builder.setPosition(elseBlock, elseBlock->end()); + builder.createUncondBrInst(mergeBlock, {}); + } + + // 合并点 + setIRPosition(mergeBlock); + irStream << mergeLabel << ":\n"; + return nullptr; +} + +std::any LLVMIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext* ctx) { + std::string loopCond = "while.cond." + std::to_string(tempCounter); + std::string loopBody = "while.body." + std::to_string(tempCounter); + std::string loopEnd = "while.end." + std::to_string(tempCounter++); + + // SysY IR 基本块 + sysy::BasicBlock* condBlock = currentIRFunction->addBasicBlock(loopCond); + sysy::BasicBlock* bodyBlock = currentIRFunction->addBasicBlock(loopBody); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(loopEnd); + + loopStack.push({loopEnd, loopCond, endBlock, condBlock}); + + // 跳转到条件块 + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(condBlock, {}); + irStream << " br label %" << loopCond << "\n"; + + // 条件块 + setIRPosition(condBlock); + irStream << loopCond << ":\n"; + std::string cond = std::any_cast(ctx->cond()->accept(this)); + sysy::Value* condValue = irTmpTable[cond]; + irStream << " br i1 " << cond << ", label %" << loopBody << ", label %" << loopEnd << "\n"; + builder.setPosition(condBlock, condBlock->end()); + builder.createCondBrInst(condValue, bodyBlock, endBlock, {}, {}); + + // 循环体 + setIRPosition(bodyBlock); + irStream << loopBody << ":\n"; + ctx->stmt()->accept(this); + irStream << " br label %" << loopCond << "\n"; + builder.setPosition(bodyBlock, bodyBlock->end()); + builder.createUncondBrInst(condBlock, {}); + + // 结束块 + setIRPosition(endBlock); + irStream << loopEnd << ":\n"; + loopStack.pop(); + return nullptr; +} + +std::any LLVMIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext* ctx) { + if (loopStack.empty()) { + throw std::runtime_error("Break statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().breakLabel << "\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(loopStack.top().irBreakBlock, {}); + return nullptr; +} + +std::any LLVMIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext* ctx) { + if (loopStack.empty()) { + throw std::runtime_error("Continue statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().continueLabel << "\n"; + sysy::IRBuilder builder(currentIRBlock); + builder.createUncondBrInst(loopStack.top().irContinueBlock, {}); + return nullptr; +} + +std::any LLVMIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) { + hasReturn = true; + sysy::IRBuilder builder(currentIRBlock); + if (ctx->exp()) { + std::string value = std::any_cast(ctx->exp()->accept(this)); + sysy::Value* irValue = irTmpTable[value]; + irStream << " ret " << currentReturnType << " " << value << "\n"; + builder.createReturnInst(irValue); + } else { + irStream << " ret void\n"; + builder.createReturnInst(); + } + return nullptr; +} + +std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { + std::string varName = ctx->Ident()->getText(); + if (irSymbolTable.find(varName) == irSymbolTable.end()) { + throw std::runtime_error("Undefined variable: " + varName); + } + // 对于 LValue,返回分配的指针(文本 IR 和 SysY IR 一致) + return symbolTable[varName].first; +} + +std::any LLVMIRGenerator::visitPrimExp(SysYParser::PrimExpContext* ctx) { + SysYParser::PrimaryExpContext* pExpCtx = ctx->primaryExp(); + if (auto* lvalCtx = dynamic_cast(pExpCtx)) { + std + +::string allocaPtr = std::any_cast(lvalCtx->lValue()->accept(this)); + std::string varName = lvalCtx->lValue()->Ident()->getText(); + std::string type = symbolTable[varName].second; + std::string temp = getNextTemp(); + sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); + + // 文本 IR + irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align 4\n"; + tmpTable[temp] = type; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + auto loadInst = builder.createLoadInst(irSymbolTable[varName], {}); + irTmpTable[temp] = loadInst; + return temp; + } else if (auto* expCtx = dynamic_cast(pExpCtx)) { + return expCtx->exp()->accept(this); + } else if (auto* strCtx = dynamic_cast(pExpCtx)) { + return strCtx->string()->accept(this); + } else if (auto* numCtx = dynamic_cast(pExpCtx)) { + return numCtx->number()->accept(this); + } else { + throw std::runtime_error("Unknown primary expression type."); + } +} + +std::any LLVMIRGenerator::visitParenExp(SysYParser::ParenExpContext* ctx) { + return ctx->exp()->accept(this); +} + +std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { + std::string value; + sysy::Value* irValue = nullptr; + if (ctx->ILITERAL()) { + value = ctx->ILITERAL()->getText(); + irValue = sysy::ConstantValue::get(std::stoi(value)); + } else if (ctx->FLITERAL()) { + value = ctx->FLITERAL()->getText(); + irValue = sysy::ConstantValue::get(std::stof(value)); + } else { + value = ""; + } + std::string temp = getNextTemp(); + tmpTable[temp] = ctx->ILITERAL() ? "i32" : "float"; + irTmpTable[temp] = irValue; + return value; +} + +std::any LLVMIRGenerator::visitString(SysYParser::StringContext* ctx) { + if (ctx->STRING()) { + std::string str = ctx->STRING()->getText(); + str = str.substr(1, str.size() - 2); + std::string escapedStr; + for (char c : str) { + if (c == '\\') { + escapedStr += "\\\\"; + } else if (c == '"') { + escapedStr += "\\\""; + } else { + escapedStr += c; + } + } + // TODO: SysY IR 暂不支持字符串常量,返回文本 IR 结果 + return "\"" + escapedStr + "\""; + } + return ctx->STRING()->getText(); +} + +std::any LLVMIRGenerator::visitCall(SysYParser::CallContext* ctx) { + std::string funcName = ctx->Ident()->getText(); + std::vector args; + std::vector irArgs; + if (ctx->funcRParams()) { + for (auto argCtx : ctx->funcRParams()->exp()) { + std::string arg = std::any_cast(argCtx->accept(this)); + args.push_back(arg); + irArgs.push_back(irTmpTable[arg]); + } + } + std::string temp = getNextTemp(); + std::string argList; + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) argList += ", "; + argList += tmpTable[args[i]] + " noundef " + args[i]; + } + + // 文本 IR + irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n"; + tmpTable[temp] = currentReturnType; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Function* callee = module->getFunction(funcName); + if (!callee) { + throw std::runtime_error("Undefined function: " + funcName); + } + auto callInst = builder.createCallInst(callee, irArgs, temp); + irTmpTable[temp] = callInst; + return temp; +} + +std::any LLVMIRGenerator::visitUnExp(SysYParser::UnExpContext* ctx) { + if (ctx->unaryOp()) { + std::string operand = std::any_cast(ctx->unaryExp()->accept(this)); + sysy::Value* irOperand = irTmpTable[operand]; + std::string op = ctx->unaryOp()->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[operand]; + sysy::Type* irType = getIRType(type == "i32" ? "int" : "float"); + tmpTable[temp] = type; + + // 文本 IR + if (op == "-") { + irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n"; + } else if (op == "!") { + irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (op == "-") ? (type == "i32" ? sysy::Instruction::kNeg : sysy::Instruction::kFNeg) + : sysy::Instruction::kNot; + auto unaryInst = builder.createUnaryInst(kind, irType, irOperand, temp); + irTmpTable[temp] = unaryInst; + return temp; + } + return ctx->unaryExp()->accept(this); +} + +std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) { + auto unaryExps = ctx->unaryExp(); + std::string left = std::any_cast(unaryExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = irLeft->getType(); + + for (size_t i = 1; i < unaryExps.size(); ++i) { + std::string right = std::any_cast(unaryExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = type; + + // 文本 IR + if (op == "*") { + irStream << " " << temp << " = mul nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "/") { + irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n"; + } else if (op == "%") { + irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind; + if (type == "i32") { + if (op == "*") kind = sysy::Instruction::kMul; + else if (op == "/") kind = sysy::Instruction::kDiv; + else kind = sysy::Instruction::kRem; + } else { + if (op == "*") kind = sysy::Instruction::kFMul; + else if (op == "/") kind = sysy::Instruction::kFDiv; + else kind = sysy::Instruction::kFRem; + } + auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = binaryInst; + left = temp; + irLeft = binaryInst; + } + return left; +} + +std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) { + auto mulExps = ctx->mulExp(); + std::string left = std::any_cast(mulExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = irLeft->getType(); + + for (size_t i = 1; i < mulExps.size(); ++i) { + std::string right = std::any_cast(mulExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = type; + + // 文本 IR + if (op == "+") { + irStream << " " << temp << " = add nsw " << type << " " << left << ", " << right << "\n"; + } else if (op == "-") { + irStream << " " << temp << " = sub nsw " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (type == "i32") ? (op == "+" ? sysy::Instruction::kAdd : sysy::Instruction::kSub) + : (op == "+" ? sysy::Instruction::kFAdd : sysy::Instruction::kFSub); + auto binaryInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = binaryInst; + left = temp; + irLeft = binaryInst; + } + return left; +} + +std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { + auto addExps = ctx->addExp(); + std::string left = std::any_cast(addExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 + + for (size_t i = 1; i < addExps.size(); ++i) { + std::string right = std::any_cast(addExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = "i1"; + + // 文本 IR + if (op == "<") { + irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; + } else if (op == ">") { + irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n"; + } else if (op == "<=") { + irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n"; + } else if (op == ">=") { + irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind; + if (type == "i32") { + if (op == "<") kind = sysy::Instruction::kICmpLT; + else if (op == ">") kind = sysy::Instruction::kICmpGT; + else if (op == "<=") kind = sysy::Instruction::kICmpLE; + else kind = sysy::Instruction::kICmpGE; + } else { + if (op == "<") kind = sysy::Instruction::kFCmpLT; + else if (op == ">") kind = sysy::Instruction::kFCmpGT; + else if (op == "<=") kind = sysy::Instruction::kFCmpLE; + else kind = sysy::Instruction::kFCmpGE; + } + auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = cmpInst; + left = temp; + irLeft = cmpInst; + } + return left; +} + +std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { + auto relExps = ctx->relExp(); + std::string left = std::any_cast(relExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + sysy::Type* irType = sysy::Type::getIntType(); // 比较结果为 i1 + + for (size_t i = 1; i < relExps.size(); ++i) { + std::string right = std::any_cast(relExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + std::string op = ctx->children[2 * i - 1]->getText(); + std::string temp = getNextTemp(); + std::string type = tmpTable[left]; + tmpTable[temp] = "i1"; + + // 文本 IR + if (op == "==") { + irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; + } else if (op == "!=") { + irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n"; + } + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + sysy::Instruction::Kind kind = (type == "i32") ? (op == "==" ? sysy::Instruction::kICmpEQ : sysy::Instruction::kICmpNE) + : (op == "==" ? sysy::Instruction::kFCmpEQ : sysy::Instruction::kFCmpNE); + auto cmpInst = builder.createBinaryInst(kind, irType, irLeft, irRight, temp); + irTmpTable[temp] = cmpInst; + left = temp; + irLeft = cmpInst; + } + return left; +} + +std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) { + auto eqExps = ctx->eqExp(); + std::string left = std::any_cast(eqExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + + for (size_t i = 1; i < eqExps.size(); ++i) { + std::string falseLabel = "land.false." + std::to_string(tempCounter); + std::string endLabel = "land.end." + std::to_string(tempCounter++); + sysy::BasicBlock* falseBlock = currentIRFunction->addBasicBlock(falseLabel); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); + std::string temp = getNextTemp(); + tmpTable[temp] = "i1"; + + // 文本 IR + irStream << " br i1 " << left << ", label %" << falseLabel << ", label %" << endLabel << "\n"; + irStream << falseLabel << ":\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(irLeft, falseBlock, endBlock, {}, {}); + setIRPosition(falseBlock); + + std::string right = std::any_cast(eqExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + irStream << " " << temp << " = and i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + + // SysY IR 逻辑与(通过基本块实现短路求值) + builder.setPosition(falseBlock, falseBlock->end()); + auto andInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); + builder.createUncondBrInst(endBlock, {}); + irTmpTable[temp] = andInst; + left = temp; + irLeft = andInst; + setIRPosition(endBlock); + } + return left; +} + +std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) { + auto lAndExps = ctx->lAndExp(); + std::string left = std::any_cast(lAndExps[0]->accept(this)); + sysy::Value* irLeft = irTmpTable[left]; + + for (size_t i = 1; i < lAndExps.size(); ++i) { + std::string trueLabel = "lor.true." + std::to_string(tempCounter); + std::string endLabel = "lor.end." + std::to_string(tempCounter++); + sysy::BasicBlock* trueBlock = currentIRFunction->addBasicBlock(trueLabel); + sysy::BasicBlock* endBlock = currentIRFunction->addBasicBlock(endLabel); + std::string temp = getNextTemp(); + tmpTable[temp] = "i1"; + + // 文本 IR + irStream << " br i1 " << left << ", label %" << trueLabel << ", label %" << endLabel << "\n"; + irStream << trueLabel << ":\n"; + + // SysY IR + sysy::IRBuilder builder(currentIRBlock); + builder.createCondBrInst(irLeft, trueBlock, endBlock, {}, {}); + setIRPosition(trueBlock); + + std::string right = std::any_cast(lAndExps[i]->accept(this)); + sysy::Value* irRight = irTmpTable[right]; + irStream << " " << temp << " = or i1 " << left << ", " << right << "\n"; + irStream << " br label %" << endLabel << "\n"; + irStream << endLabel << ":\n"; + + // SysY IR 逻辑或(通过基本块实现短路求值) + builder.setPosition(trueBlock, trueBlock->end()); + auto orInst = builder.createBinaryInst(sysy::Instruction::kICmpEQ, sysy::Type::getIntType(), irLeft, irRight, temp); + builder.createUncondBrInst(endBlock, {}); + irTmpTable[temp] = orInst; + left = temp; + irLeft = orInst; + setIRPosition(endBlock); + } + return left; +} + +} // namespace sysy \ No newline at end of file diff --git a/src/LLVMIRGenerator_1.h b/src/LLVMIRGenerator_1.h new file mode 100644 index 0000000..5c851b6 --- /dev/null +++ b/src/LLVMIRGenerator_1.h @@ -0,0 +1,99 @@ +#pragma once +#include "SysYBaseVisitor.h" +#include "SysYParser.h" +#include "IR.h" // 引入 SysY IR 头文件 +#include "IRBuilder.h" +#include +#include +#include +#include +#include + +class LLVMIRGenerator : public SysYBaseVisitor { +public: + // 生成 IR(文本和数据结构) + std::string generateIR(SysYParser::CompUnitContext* unit); + + // 获取文本格式的 LLVM IR + std::string getIR() const { return irStream.str(); } + + // 获取 SysY IR 数据结构 + sysy::Module* getModule() const { return module.get(); } + +private: + // 文本输出相关 + std::stringstream irStream; + int tempCounter = 0; // 临时变量计数器 + std::string currentVarType; // 当前变量类型(文本 IR 用) + + // 符号表:映射变量名到 {分配地址/寄存器, 类型}(文本 IR) + std::map> symbolTable; + // 临时变量表:映射临时变量名到类型(文本 IR) + std::map tmpTable; + std::vector globalVars; // 全局变量列表(文本 IR) + + // SysY IR 数据结构 + std::unique_ptr module; // SysY IR 模块 + // 符号表:映射变量名到 SysY IR 的 Value 指针 + std::map irSymbolTable; + // 临时变量表:映射临时变量名到 SysY IR 的 Value 指针 + std::map irTmpTable; + + // 当前上下文 + std::string currentFunction; // 当前函数名(文本 IR) + std::string currentReturnType; // 当前函数返回类型(文本 IR) + sysy::Function* currentIRFunction = nullptr; // 当前 SysY IR 函数 + sysy::BasicBlock* currentIRBlock = nullptr; // 当前 SysY IR 基本块 + + // 循环控制 + std::vector breakStack; // break 标签栈(文本 IR) + std::vector continueStack; // continue 标签栈(文本 IR) + bool hasReturn = false; // 是否有返回语句(文本 IR) + + struct LoopLabels { + std::string breakLabel; // break 跳转目标标签(文本 IR) + std::string continueLabel; // continue 跳转目标标签(文本 IR) + sysy::BasicBlock* irBreakBlock = nullptr; // break 跳转目标块(SysY IR) + sysy::BasicBlock* irContinueBlock = nullptr; // continue 跳转目标块(SysY IR) + }; + std::stack loopStack; // 管理循环的 break 和 continue 标签 + + bool inFunction = false; // 标记是否在函数内部 + + // 辅助函数(文本 IR) + std::string getNextTemp(); // 获取下一个临时变量名 + std::string getLLVMType(const std::string& type); // 转换 SysY 类型到 LLVM 类型 + + // 辅助函数(SysY IR) + sysy::Type* getIRType(const std::string& type); // 转换 SysY 类型到 SysY IR 类型 + std::string getIRTempName(); // 获取 SysY IR 临时变量名 + void setIRPosition(sysy::BasicBlock* block); // 设置当前 IR 插入点 + + // 访问方法 + std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; + std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; + std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; + std::any visitVarDef(SysYParser::VarDefContext* ctx) override; + std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; + std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; + std::any visitLValue(SysYParser::LValueContext* ctx) override; + // std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override; + std::any visitPrimExp(SysYParser::PrimExpContext* ctx) override; + std::any visitParenExp(SysYParser::ParenExpContext* ctx) override; + std::any visitNumber(SysYParser::NumberContext* ctx) override; + std::any visitString(SysYParser::StringContext* ctx) override; + std::any visitCall(SysYParser::CallContext* ctx) override; + std::any visitUnExp(SysYParser::UnExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitAssignStmt(SysYParser::AssignStmtContext* ctx) override; + std::any visitIfStmt(SysYParser::IfStmtContext* ctx) override; + std::any visitWhileStmt(SysYParser::WhileStmtContext* ctx) override; + std::any visitBreakStmt(SysYParser::BreakStmtContext* ctx) override; + std::any visitContinueStmt(SysYParser::ContinueStmtContext* ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override; +}; \ No newline at end of file