Files
mysysy/src/LLVMIRGenerator.cpp
2025-05-29 17:14:42 +08:00

566 lines
22 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// LLVMIRGenerator.cpp
// TODO: type conversion and type checking
// TODO: handling of the SysY runtime library functions
// TODO: array handling
// TODO: tests for while / continue / break
#include "LLVMIRGenerator.h"

#include <any>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iomanip>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
// Produces the textual LLVM IR for a whole compilation unit.
// Visiting the tree appends to irStream; the accumulated text is returned.
std::string LLVMIRGenerator::generateIR(SysYParser::CompUnitContext* unit) {
    visitCompUnit(unit);
    std::string generated = irStream.str();
    return generated;
}
// Hands out the next temporary register name ("%.<n>") and registers it in
// tmpTable with the placeholder type "void"; callers that know the real type
// overwrite that entry afterwards.
std::string LLVMIRGenerator::getNextTemp() {
    const auto id = tempCounter++;
    std::string name = "%.";
    name += std::to_string(id);
    tmpTable[name] = "void";
    return name;
}
// Maps a SysY source type name to its LLVM IR spelling:
//   "int"   -> "i32"
//   "float" -> "float"
//   "void"  -> "void"
//   "T[]"   -> the mapping of T followed by "*" (one "*" per trailing "[]")
// Any other name (including an already-mapped "i32") falls back to "i32".
std::string LLVMIRGenerator::getLLVMType(const std::string& type) {
    if (type == "int") return "i32";
    if (type == "float") return "float";
    // Without this case a void function was declared as returning i32 and the
    // `currentReturnType == "void"` check in visitFuncDef could never be true.
    if (type == "void") return "void";
    if (type.find("[]") != std::string::npos)
        return getLLVMType(type.substr(0, type.size() - 2)) + "*";
    return "i32";
}
// Visits every top-level declaration of the unit first and every function
// definition afterwards. inFunction is raised only while a function definition
// is being visited, which is how the declaration visitors tell global scope
// from local scope.
std::any LLVMIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
    const auto globalDecls = ctx->decl();
    for (std::size_t i = 0; i < globalDecls.size(); ++i) {
        globalDecls[i]->accept(this);
    }
    const auto functions = ctx->funcDef();
    for (std::size_t i = 0; i < functions.size(); ++i) {
        inFunction = true;   // entering a function definition
        functions[i]->accept(this);
        inFunction = false;  // leaving the function definition
    }
    return nullptr;
}
// Handles a (non-const) variable declaration.
// At global scope every variable becomes an LLVM global initialized with its
// initializer text, or zero when there is none. Inside a function each
// definition is delegated to visitVarDef, which reads the element type from
// currentVarType.
// Throws std::runtime_error when a global float initializer is not a numeric
// literal.
std::any LLVMIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
    // TODO: array initialization
    const std::string llvmType = getLLVMType(ctx->bType()->getText());
    currentVarType = llvmType;

    // Renders a numeric literal in the 16-digit hexadecimal form LLVM uses for
    // floating-point constants. The value is rounded to float first because
    // LLVM rejects `float` constants that are not exactly representable.
    const auto toFloatConstant = [](const std::string& literal) -> std::string {
        double widened = 0.0;
        try {
            widened = static_cast<double>(static_cast<float>(std::stod(literal)));
        } catch (const std::exception&) {
            throw std::runtime_error("Invalid float literal: " + literal);
        }
        std::uint64_t bits = 0;
        std::memcpy(&bits, &widened, sizeof bits);  // well-defined type punning
        std::ostringstream hex;
        hex << "0x" << std::hex << std::uppercase << std::setw(16) << std::setfill('0') << bits;
        return hex.str();
    };

    for (auto* varDef : ctx->varDef()) {
        if (inFunction) {
            // Local variable: visitVarDef emits the alloca/store pair.
            varDef->accept(this);
            continue;
        }
        // Global variable.
        const std::string varName = varDef->Ident()->getText();
        std::string value = "0";  // globals default to zero
        if (varDef->ASSIGN()) {
            value = std::any_cast<std::string>(varDef->initVal()->accept(this));
        }
        if (llvmType == "float") {
            // Previously the text was emitted verbatim, so `float 0` (the
            // default) and inexact literals such as 0.1 produced invalid IR.
            value = toFloatConstant(value);
        }
        irStream << "@" << varName << " = dso_local global " << llvmType << " " << value << ", align 4\n";
        globalVars.push_back(varName);  // remember the global's name
    }
    return nullptr;
}
// Handles a const declaration.
// Global scope: emits `@name = dso_local constant <type> <value>` and records
// the name in globalVars.
// Function scope: emits an alloca plus a store of the initializer and records
// the slot in symbolTable / tmpTable.
// Throws std::runtime_error when the initializer is missing or does not yield
// a string, or when a float initializer that must be a literal is not numeric.
std::any LLVMIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
    // TODO: array initialization
    const std::string llvmType = getLLVMType(ctx->bType()->getText());
    const bool isFloat = (llvmType == "float");

    // Renders a numeric literal in the 16-digit hexadecimal form LLVM uses for
    // floating-point constants. The value is rounded to float first: the raw
    // double bits of e.g. 0.1 are not representable as a float and LLVM
    // rejects such a constant for type `float`.
    const auto toFloatConstant = [](const std::string& literal) -> std::string {
        double widened = 0.0;
        try {
            widened = static_cast<double>(static_cast<float>(std::stod(literal)));
        } catch (const std::exception&) {
            throw std::runtime_error("Invalid float literal: " + literal);
        }
        std::uint64_t bits = 0;
        std::memcpy(&bits, &widened, sizeof bits);  // well-defined type punning
        std::ostringstream hex;
        hex << "0x" << std::hex << std::uppercase << std::setw(16) << std::setfill('0') << bits;
        return hex.str();
    };

    for (auto* constDef : ctx->constDef()) {
        const std::string varName = constDef->Ident()->getText();

        // The slot name is reserved before the initializer is visited so that
        // temporaries keep the numbering they had before this rewrite.
        std::string allocaName;
        if (inFunction) {
            allocaName = getNextTemp();
        }

        if (!constDef->constInitVal()) {
            throw std::runtime_error("Const value must be initialized upon definition.");
        }
        std::string value;
        try {
            value = std::any_cast<std::string>(constDef->constInitVal()->accept(this));
        } catch (const std::bad_any_cast&) {
            // Only a non-string result is reported here; errors raised while
            // visiting the initializer propagate with their own message.
            throw std::runtime_error("Const value must be initialized upon definition.");
        }

        if (!inFunction) {
            // Global constant: the initializer has to be a literal.
            if (isFloat) {
                value = toFloatConstant(value);
            }
            irStream << "@" << varName << " = dso_local constant " << llvmType << " " << value << ", align 4\n";
            globalVars.push_back(varName);  // remember the global's name
            continue;
        }

        // Local constant.
        irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
        // A value that is already a register (the result of an expression) is
        // stored as is; only literals are rewritten into LLVM's float form.
        const bool isRegister = !value.empty() && (value[0] == '%' || value[0] == '@');
        if (isFloat && !isRegister) {
            value = toFloatConstant(value);
        }
        irStream << " store " << llvmType << " " << value << ", " << llvmType
                 << "* " << allocaName << ", align 4\n";
        symbolTable[varName] = {allocaName, llvmType};
        tmpTable[allocaName] = llvmType;
    }
    return nullptr;
}
// Emits a local variable definition: an alloca for the variable and, when an
// initializer is present, a store of its value. The slot is recorded in
// symbolTable (name -> {slot, type}) and tmpTable (slot -> type).
// Throws std::runtime_error when a float initializer literal is not numeric.
std::any LLVMIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) {
    // TODO: array initialization
    const std::string varName = ctx->Ident()->getText();
    // currentVarType already holds an LLVM type (visitVarDecl stores the
    // result of getLLVMType in it), so it is used directly.
    const std::string llvmType = currentVarType;
    const std::string allocaName = getNextTemp();
    irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";

    if (ctx->ASSIGN()) {
        std::string value = std::any_cast<std::string>(ctx->initVal()->accept(this));
        // Only literals are rewritten. A register produced by an expression
        // used to be fed to std::stod and rejected as "Invalid float literal".
        const bool isRegister = !value.empty() && (value[0] == '%' || value[0] == '@');
        if (llvmType == "float" && !isRegister) {
            double widened = 0.0;
            try {
                // Round to float (the old mask truncated the mantissa and
                // shifted a possibly 32-bit unsigned long by 32 bits).
                widened = static_cast<double>(static_cast<float>(std::stod(value)));
            } catch (const std::exception&) {
                throw std::runtime_error("Invalid float literal: " + value);
            }
            std::uint64_t bits = 0;
            std::memcpy(&bits, &widened, sizeof bits);  // well-defined type punning
            std::ostringstream hex;
            hex << "0x" << std::hex << std::uppercase << std::setw(16) << std::setfill('0') << bits;
            value = hex.str();
        }
        irStream << " store " << llvmType << " " << value << ", " << llvmType
                 << "* " << allocaName << ", align 4\n";
    }

    symbolTable[varName] = {allocaName, llvmType};
    tmpTable[allocaName] = llvmType;
    return nullptr;
}
// Emits one function definition.
// Resets the per-function state (symbolTable, tmpTable, tempCounter,
// hasReturn), prints the `define` line with parameters %0..%(n-1), spills each
// parameter into its own alloca so the body can treat it like a local, visits
// the body, and appends a default `ret` when hasReturn is still false.
std::any LLVMIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
    currentFunction = ctx->Ident()->getText();
    // "void" is resolved here explicitly so the `ret void` fallback below is
    // reachable regardless of how getLLVMType treats unknown names.
    const std::string declaredReturn = ctx->funcType()->getText();
    currentReturnType = (declaredReturn == "void") ? std::string("void") : getLLVMType(declaredReturn);
    symbolTable.clear();
    tmpTable.clear();
    tempCounter = 0;
    hasReturn = false;

    // Fetch the parameter list once; it stays empty for a parameterless function.
    decltype(ctx->funcFParams()->funcFParam()) params;
    if (ctx->funcFParams()) {
        params = ctx->funcFParams()->funcFParam();
    }

    irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "(";
    for (std::size_t i = 0; i < params.size(); ++i) {
        if (i > 0) irStream << ", ";
        const std::string paramType = getLLVMType(params[i]->bType()->getText());
        irStream << paramType << " noundef %" << i;
        tmpTable["%" + std::to_string(i)] = paramType;
    }
    // NOTE(review): attribute group #0 is referenced here but not defined in
    // the code visible in this file section - confirm it is emitted elsewhere.
    irStream << ") #0 {\n";

    // Skip the numbers used by the parameters, plus one more.
    // NOTE(review): the extra increment presumably reserves the number LLVM
    // gives the unnamed entry block - confirm.
    tempCounter += params.size();
    tempCounter++;

    for (std::size_t i = 0; i < params.size(); ++i) {
        const std::string varName = params[i]->Ident()->getText();
        const std::string llvmType = getLLVMType(params[i]->bType()->getText());
        const std::string incoming = "%" + std::to_string(i);
        const std::string allocaName = getNextTemp();
        tmpTable[allocaName] = llvmType;
        irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n";
        irStream << " store " << llvmType << " " << incoming << ", " << llvmType
                 << "* " << allocaName << ", align 4\n";
        symbolTable[varName] = {allocaName, llvmType};
    }

    ctx->blockStmt()->accept(this);

    if (!hasReturn) {
        if (currentReturnType == "void") {
            irStream << " ret void\n";
        } else if (currentReturnType == "float") {
            // An integer `0` is not a valid constant for type float.
            irStream << " ret float 0.0\n";
        } else {
            irStream << " ret " << currentReturnType << " 0\n";
        }
    }
    irStream << "}\n";
    return nullptr;
}
// Visits the items of a block in source order; the block itself emits nothing.
std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
    const auto items = ctx->blockItem();
    for (auto it = items.begin(); it != items.end(); ++it) {
        (*it)->accept(this);
    }
    return nullptr;
}
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx)
{
std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second;
std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
if (lhsType == "float") {
try {
double floatValue = std::stod(rhs);
uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
std::stringstream ss;
ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
rhs = ss.str();
} catch (...) {
throw std::runtime_error("Invalid float literal: " + rhs);
}
}
irStream << " store " << lhsType << " " << rhs << ", " << lhsType
<< "* " << lhsAlloca << ", align 4\n";
return nullptr;
}
std::any visitIfStmt(SysYParser::IfStmtContext *ctx)
{
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
std::string trueLabel = "if.then." + std::to_string(tempCounter);
std::string falseLabel = "if.else." + std::to_string(tempCounter);
std::string mergeLabel = "if.end." + std::to_string(tempCounter++);
irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n";
irStream << trueLabel << ":\n";
ctx->stmt(0)->accept(this);
irStream << " br label %" << mergeLabel << "\n";
irStream << falseLabel << ":\n";
if (ctx->ELSE()) {
ctx->stmt(1)->accept(this);
}
irStream << " br label %" << mergeLabel << "\n";
irStream << mergeLabel << ":\n";
return nullptr;
}
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx)
{
std::string loop_cond = "while.cond." + std::to_string(tempCounter);
std::string loop_body = "while.body." + std::to_string(tempCounter);
std::string loop_end = "while.end." + std::to_string(tempCounter++);
loopStack.push({loop_end, loop_cond});
irStream << " br label %" << loop_cond << "\n";
irStream << loop_cond << ":\n";
std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n";
irStream << loop_body << ":\n";
ctx->stmt(0)->accept(this);
irStream << " br label %" << loop_cond << "\n";
irStream << loop_end << ":\n";
loopStack.pop();
return nullptr;
}
std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx)
{
if (loopStack.empty()) {
throw std::runtime_error("Break statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().breakLabel << "\n";
return nullptr;
}
std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx)
{
if (loopStack.empty()) {
throw std::runtime_error("Continue statement outside of a loop.");
}
irStream << " br label %" << loopStack.top().continueLabel << "\n";
return nullptr;
}
// std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
// if (ctx->lValue() && ctx->ASSIGN()) {
// std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
// std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second;
// std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
// if (lhsType == "float") {
// try {
// double floatValue = std::stod(rhs);
// uint64_t hexValue = reinterpret_cast<uint64_t&>(floatValue);
// std::stringstream ss;
// ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32));
// rhs = ss.str();
// } catch (...) {
// throw std::runtime_error("Invalid float literal: " + rhs);
// }
// }
// irStream << " store " << lhsType << " " << rhs << ", " << lhsType
// << "* " << lhsAlloca << ", align 4\n";
// } else if (ctx->RETURN()) {
// hasReturn = true;
// if (ctx->exp()) {
// std::string value = std::any_cast<std::string>(ctx->exp()->accept(this));
// irStream << " ret " << currentReturnType << " " << value << "\n";
// } else {
// irStream << " ret void\n";
// }
// } else if (ctx->IF()) {
// std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
// std::string trueLabel = "if.then." + std::to_string(tempCounter);
// std::string falseLabel = "if.else." + std::to_string(tempCounter);
// std::string mergeLabel = "if.end." + std::to_string(tempCounter++);
// irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n";
// irStream << trueLabel << ":\n";
// ctx->stmt(0)->accept(this);
// irStream << " br label %" << mergeLabel << "\n";
// irStream << falseLabel << ":\n";
// if (ctx->ELSE()) {
// ctx->stmt(1)->accept(this);
// }
// irStream << " br label %" << mergeLabel << "\n";
// irStream << mergeLabel << ":\n";
// } else if (ctx->WHILE()) {
// std::string loop_cond = "while.cond." + std::to_string(tempCounter);
// std::string loop_body = "while.body." + std::to_string(tempCounter);
// std::string loop_end = "while.end." + std::to_string(tempCounter++);
// loopStack.push({loop_end, loop_cond});
// irStream << " br label %" << loop_cond << "\n";
// irStream << loop_cond << ":\n";
// std::string cond = std::any_cast<std::string>(ctx->cond()->accept(this));
// irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n";
// irStream << loop_body << ":\n";
// ctx->stmt(0)->accept(this);
// irStream << " br label %" << loop_cond << "\n";
// irStream << loop_end << ":\n";
// loopStack.pop();
// } else if (ctx->BREAK()) {
// if (loopStack.empty()) {
// throw std::runtime_error("Break statement outside of a loop.");
// }
// irStream << " br label %" << loopStack.top().breakLabel << "\n";
// } else if (ctx->CONTINUE()) {
// if (loopStack.empty()) {
// throw std::runtime_error("Continue statement outside of a loop.");
// }
// irStream << " br label %" << loopStack.top().continueLabel << "\n";
// } else if (ctx->blockStmt()) {
// ctx->blockStmt()->accept(this);
// } else if (ctx->exp()) {
// ctx->exp()->accept(this);
// }
// return nullptr;
// }
// Resolves an identifier to the pointer that holds its value.
// Locals and parameters resolve to the alloca recorded in symbolTable; names
// recorded in globalVars resolve to "@<name>".
// Throws std::runtime_error for an identifier that is neither.
std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
    const std::string varName = ctx->Ident()->getText();

    // find() instead of operator[]: a lookup must not create an empty entry.
    // Entries with an empty slot are ignored because other visitors index
    // symbolTable with operator[] and may have inserted such placeholders.
    const auto local = symbolTable.find(varName);
    if (local != symbolTable.end() && !local->second.first.empty()) {
        return local->second.first;
    }

    // Globals are only recorded by name, and symbolTable is cleared at every
    // function definition, so they are resolved here.
    // NOTE(review): the type of a global is not recorded anywhere visible, so
    // callers reading symbolTable[name].second still get an empty type.
    for (const auto& global : globalVars) {
        if (global == varName) {
            return "@" + varName;
        }
    }
    throw std::runtime_error("Undefined variable: " + varName);
}
// Evaluates a primary expression.
// A variable reference is loaded from its slot into a fresh temporary whose
// type is recorded in tmpTable. A parenthesized expression or a number simply
// forwards the result of the child visitor.
std::any LLVMIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
    if (!ctx->lValue()) {
        return ctx->exp() ? ctx->exp()->accept(this) : ctx->number()->accept(this);
    }

    auto* lvalue = ctx->lValue();
    const std::string slot = std::any_cast<std::string>(lvalue->accept(this));
    const std::string name = lvalue->Ident()->getText();
    const std::string valueType = symbolTable[name].second;
    const std::string loaded = getNextTemp();
    irStream << " " << loaded << " = load " << valueType << ", " << valueType << "* " << slot << ", align 4\n";
    tmpTable[loaded] = valueType;
    return loaded;
}
// Returns the text of a numeric literal as a std::string.
// Integer literals are normalized to decimal: LLVM reads "017" as seventeen
// and does not accept "0x1F" as an integer, so octal / hexadecimal spellings
// (if the lexer accepts them) must not be passed through verbatim.
// Float literals are returned as written.
std::any LLVMIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
    if (auto* intLit = ctx->ILITERAL()) {
        const std::string text = intLit->getText();
        try {
            std::size_t consumed = 0;
            // Base 0 auto-detects the 0x / 0 prefixes.
            const long long value = std::stoll(text, &consumed, 0);
            if (consumed == text.size()) {
                return std::to_string(value);
            }
        } catch (const std::exception&) {
            // Not parseable: fall back to the literal text below.
        }
        return text;
    }
    if (auto* floatLit = ctx->FLITERAL()) {
        return floatLit->getText();
    }
    // Must be a std::string: callers use any_cast<std::string>, which throws
    // std::bad_any_cast on the `const char*` that `return ""` used to store.
    return std::string{};
}
// Evaluates a unary expression: a unary operator applied to an operand, a
// function call, or a plain primary expression.
//   +x  returns x unchanged.
//   -x  folds the sign into a literal, otherwise emits fneg (float) or
//       `sub i32 0, x`.
//   !x  folds a literal to "1"/"0", otherwise emits a comparison against zero;
//       the result register has type i1.
// Returns the literal or register name holding the result (an empty string for
// a call emitted as `call void`).
// Throws std::runtime_error for an unknown operator or a malformed literal.
std::any LLVMIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
    const auto isRegister = [](const std::string& v) {
        return !v.empty() && (v[0] == '%' || v[0] == '@');
    };
    // Type of an operand. Literals are classified by spelling; registers use
    // the type recorded in tmpTable. The "void" placeholder left by
    // getNextTemp() is treated as i1 because the comparison and logical
    // visitors do not overwrite it; any other unrecognized entry counts as i32.
    const auto typeOf = [this, &isRegister](const std::string& v) -> std::string {
        if (!isRegister(v)) {
            const bool hex = v.find("0x") != std::string::npos || v.find("0X") != std::string::npos;
            const bool looksFloat = v.find_first_of(".pP") != std::string::npos ||
                                    (!hex && v.find_first_of("eE") != std::string::npos);
            return looksFloat ? "float" : "i32";
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return "i32";
        if (entry->second == "float") return "float";
        if (entry->second == "i1" || entry->second == "void") return "i1";
        return "i32";
    };

    if (ctx->unaryOp()) {
        std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
        const std::string op = ctx->unaryOp()->getText();

        if (op == "+") {
            return operand;  // no instruction needed
        }
        if (op != "-" && op != "!") {
            throw std::runtime_error("Unsupported unary operator: " + op);
        }

        if (!isRegister(operand)) {
            // Literal operand: fold at compile time instead of emitting code.
            if (op == "-") {
                if (!operand.empty() && operand[0] == '-') {
                    return operand.substr(1);
                }
                return "-" + operand;
            }
            try {
                return std::string(std::stod(operand) == 0.0 ? "1" : "0");
            } catch (const std::exception&) {
                throw std::runtime_error("Invalid numeric literal: " + operand);
            }
        }

        // The operand's type comes from tmpTable; it used to be derived from
        // the operand's own name, which produced e.g. `sub %.3 0, %.3`.
        const std::string type = typeOf(operand);
        if (op == "-") {
            if (type == "float") {
                const std::string temp = getNextTemp();
                irStream << " " << temp << " = fneg float " << operand << "\n";
                tmpTable[temp] = "float";
                return temp;
            }
            if (type == "i1") {
                // Widen a boolean before doing integer arithmetic on it.
                const std::string widened = getNextTemp();
                irStream << " " << widened << " = zext i1 " << operand << " to i32\n";
                tmpTable[widened] = "i32";
                operand = widened;
            }
            const std::string temp = getNextTemp();
            irStream << " " << temp << " = sub i32 0, " << operand << "\n";
            tmpTable[temp] = "i32";
            return temp;
        }

        // Logical not: true exactly when the operand is zero / false.
        const std::string temp = getNextTemp();
        if (type == "i1") {
            irStream << " " << temp << " = xor i1 " << operand << ", true\n";
        } else if (type == "float") {
            irStream << " " << temp << " = fcmp oeq float " << operand << ", 0.0\n";
        } else {
            irStream << " " << temp << " = icmp eq i32 " << operand << ", 0\n";
        }
        tmpTable[temp] = "i1";
        return temp;
    }

    if (ctx->Ident()) {
        const std::string funcName = ctx->Ident()->getText();
        std::vector<std::string> args;
        if (ctx->funcRParams()) {
            for (auto* argCtx : ctx->funcRParams()->exp()) {
                args.push_back(std::any_cast<std::string>(argCtx->accept(this)));
            }
        }
        std::string argList;
        for (std::size_t i = 0; i < args.size(); ++i) {
            if (i > 0) argList += ", ";
            // typeOf() also yields a type for literal arguments, which used
            // to be printed with an empty type.
            // NOTE(review): float literal arguments are still passed as
            // written; LLVM rejects decimal floats that are not exact.
            argList += typeOf(args[i]) + " noundef " + args[i];
        }

        // NOTE(review): the callee's return type is not recorded anywhere
        // visible here, so the enclosing function's return type is still used
        // as a stand-in - this is wrong whenever the two differ.
        if (currentReturnType == "void") {
            // `%t = call void ...` is invalid IR: a void call yields no value.
            irStream << " call void @" << funcName << "(" << argList << ")\n";
            return std::string{};
        }
        const std::string temp = getNextTemp();
        irStream << " " << temp << " = call " << currentReturnType << " @" << funcName << "(" << argList << ")\n";
        tmpTable[temp] = currentReturnType;
        return temp;
    }

    return ctx->primaryExp()->accept(this);
}
// Evaluates a left-associative chain of `*`, `/` and `%`.
// Emits mul/sdiv/srem on i32, or fmul/fdiv/frem when either operand is a
// float, and records the result type in tmpTable.
// Returns the literal or register name holding the final value.
// Throws std::runtime_error for an unknown operator.
// NOTE(review): implicit int<->float conversion of mixed operands and the
// rewriting of inexact float literals are still open (see the file's TODOs).
std::any LLVMIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
    // Type of an operand. Literals are classified by spelling; registers use
    // the type recorded in tmpTable (find(), so nothing is inserted). The
    // "void" placeholder left by getNextTemp() is treated as i1 because the
    // comparison and logical visitors do not overwrite it.
    const auto typeOf = [this](const std::string& v) -> std::string {
        if (v.empty() || (v[0] != '%' && v[0] != '@')) {
            const bool hex = v.find("0x") != std::string::npos || v.find("0X") != std::string::npos;
            const bool looksFloat = v.find_first_of(".pP") != std::string::npos ||
                                    (!hex && v.find_first_of("eE") != std::string::npos);
            return looksFloat ? "float" : "i32";
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return "i32";
        if (entry->second == "float") return "float";
        if (entry->second == "i1" || entry->second == "void") return "i1";
        return "i32";
    };
    // Widens a boolean register to i32 so it can take part in arithmetic.
    const auto asInt = [this, &typeOf](const std::string& v) -> std::string {
        if (typeOf(v) != "i1") return v;
        const std::string wide = getNextTemp();
        irStream << " " << wide << " = zext i1 " << v << " to i32\n";
        tmpTable[wide] = "i32";
        return wide;
    };

    auto unaryExps = ctx->unaryExp();
    std::string left = std::any_cast<std::string>(unaryExps[0]->accept(this));
    for (std::size_t i = 1; i < unaryExps.size(); ++i) {
        std::string right = std::any_cast<std::string>(unaryExps[i]->accept(this));
        // children alternate operand, operator, operand, ...
        const std::string op = ctx->children[2 * i - 1]->getText();

        // The type used to come from tmpTable[left] alone, which is empty for
        // a literal left operand (e.g. `2 * a`).
        const bool isFloat = typeOf(left) == "float" || typeOf(right) == "float";
        if (!isFloat) {
            left = asInt(left);
            right = asInt(right);
        }

        std::string opcode;
        if (op == "*") {
            opcode = isFloat ? "fmul" : "mul nsw";
        } else if (op == "/") {
            opcode = isFloat ? "fdiv" : "sdiv";
        } else if (op == "%") {
            opcode = isFloat ? "frem" : "srem";
        } else {
            throw std::runtime_error("Unsupported multiplicative operator: " + op);
        }

        const std::string type = isFloat ? "float" : "i32";
        const std::string temp = getNextTemp();
        irStream << " " << temp << " = " << opcode << " " << type << " " << left << ", " << right << "\n";
        tmpTable[temp] = type;
        left = temp;
    }
    return left;
}
// Evaluates a left-associative chain of `+` and `-`.
// Emits add/sub on i32, or fadd/fsub when either operand is a float, and
// records the result type in tmpTable.
// Returns the literal or register name holding the final value.
// Throws std::runtime_error for an unknown operator.
// NOTE(review): implicit int<->float conversion of mixed operands and the
// rewriting of inexact float literals are still open (see the file's TODOs).
std::any LLVMIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
    // Type of an operand. Literals are classified by spelling; registers use
    // the type recorded in tmpTable (find(), so nothing is inserted). The
    // "void" placeholder left by getNextTemp() is treated as i1 because the
    // comparison and logical visitors do not overwrite it.
    const auto typeOf = [this](const std::string& v) -> std::string {
        if (v.empty() || (v[0] != '%' && v[0] != '@')) {
            const bool hex = v.find("0x") != std::string::npos || v.find("0X") != std::string::npos;
            const bool looksFloat = v.find_first_of(".pP") != std::string::npos ||
                                    (!hex && v.find_first_of("eE") != std::string::npos);
            return looksFloat ? "float" : "i32";
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return "i32";
        if (entry->second == "float") return "float";
        if (entry->second == "i1" || entry->second == "void") return "i1";
        return "i32";
    };
    // Widens a boolean register to i32 so it can take part in arithmetic.
    const auto asInt = [this, &typeOf](const std::string& v) -> std::string {
        if (typeOf(v) != "i1") return v;
        const std::string wide = getNextTemp();
        irStream << " " << wide << " = zext i1 " << v << " to i32\n";
        tmpTable[wide] = "i32";
        return wide;
    };

    auto mulExps = ctx->mulExp();
    std::string left = std::any_cast<std::string>(mulExps[0]->accept(this));
    for (std::size_t i = 1; i < mulExps.size(); ++i) {
        std::string right = std::any_cast<std::string>(mulExps[i]->accept(this));
        // children alternate operand, operator, operand, ...
        const std::string op = ctx->children[2 * i - 1]->getText();

        // The type used to come from tmpTable[left] alone, which is empty for
        // a literal left operand (e.g. `1 + a`).
        const bool isFloat = typeOf(left) == "float" || typeOf(right) == "float";
        if (!isFloat) {
            left = asInt(left);
            right = asInt(right);
        }

        std::string opcode;
        if (op == "+") {
            opcode = isFloat ? "fadd" : "add nsw";
        } else if (op == "-") {
            opcode = isFloat ? "fsub" : "sub nsw";
        } else {
            throw std::runtime_error("Unsupported additive operator: " + op);
        }

        const std::string type = isFloat ? "float" : "i32";
        const std::string temp = getNextTemp();
        irStream << " " << temp << " = " << opcode << " " << type << " " << left << ", " << right << "\n";
        tmpTable[temp] = type;
        left = temp;
    }
    return left;
}
// Evaluates a left-associative chain of `<`, `>`, `<=` and `>=`.
// Emits a signed icmp on i32, or an ordered fcmp when either operand is a
// float. Every comparison result is recorded in tmpTable as i1; a boolean fed
// into a further comparison is widened to i32 first.
// Returns the literal or register name holding the final value.
// Throws std::runtime_error for an unknown operator.
// NOTE(review): implicit int<->float conversion of mixed operands and the
// rewriting of inexact float literals are still open (see the file's TODOs).
std::any LLVMIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
    // Type of an operand. Literals are classified by spelling; registers use
    // the type recorded in tmpTable (find(), so nothing is inserted). The
    // "void" placeholder left by getNextTemp() is treated as i1 because
    // comparison results were previously left with it.
    const auto typeOf = [this](const std::string& v) -> std::string {
        if (v.empty() || (v[0] != '%' && v[0] != '@')) {
            const bool hex = v.find("0x") != std::string::npos || v.find("0X") != std::string::npos;
            const bool looksFloat = v.find_first_of(".pP") != std::string::npos ||
                                    (!hex && v.find_first_of("eE") != std::string::npos);
            return looksFloat ? "float" : "i32";
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return "i32";
        if (entry->second == "float") return "float";
        if (entry->second == "i1" || entry->second == "void") return "i1";
        return "i32";
    };
    // Widens a boolean register to i32 so it can be compared as an integer.
    const auto asInt = [this, &typeOf](const std::string& v) -> std::string {
        if (typeOf(v) != "i1") return v;
        const std::string wide = getNextTemp();
        irStream << " " << wide << " = zext i1 " << v << " to i32\n";
        tmpTable[wide] = "i32";
        return wide;
    };

    auto addExps = ctx->addExp();
    std::string left = std::any_cast<std::string>(addExps[0]->accept(this));
    for (std::size_t i = 1; i < addExps.size(); ++i) {
        std::string right = std::any_cast<std::string>(addExps[i]->accept(this));
        // children alternate operand, operator, operand, ...
        const std::string op = ctx->children[2 * i - 1]->getText();

        // The type used to come from tmpTable[left] alone, which is empty for
        // a literal left operand (e.g. `0 < a`).
        const bool isFloat = typeOf(left) == "float" || typeOf(right) == "float";
        if (!isFloat) {
            left = asInt(left);
            right = asInt(right);
        }

        std::string opcode;
        if (op == "<") {
            opcode = isFloat ? "fcmp olt" : "icmp slt";
        } else if (op == ">") {
            opcode = isFloat ? "fcmp ogt" : "icmp sgt";
        } else if (op == "<=") {
            opcode = isFloat ? "fcmp ole" : "icmp sle";
        } else if (op == ">=") {
            opcode = isFloat ? "fcmp oge" : "icmp sge";
        } else {
            throw std::runtime_error("Unsupported relational operator: " + op);
        }

        const std::string temp = getNextTemp();
        irStream << " " << temp << " = " << opcode << " " << (isFloat ? "float" : "i32")
                 << " " << left << ", " << right << "\n";
        tmpTable[temp] = "i1";
        left = temp;
    }
    return left;
}
// Evaluates a left-associative chain of `==` and `!=`.
// Emits icmp eq/ne on i32, or fcmp oeq/une when either operand is a float.
// Every comparison result is recorded in tmpTable as i1; a boolean fed into a
// further comparison is widened to i32 first.
// Returns the literal or register name holding the final value.
// Throws std::runtime_error for an unknown operator.
// NOTE(review): implicit int<->float conversion of mixed operands and the
// rewriting of inexact float literals are still open (see the file's TODOs).
std::any LLVMIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
    // Type of an operand. Literals are classified by spelling; registers use
    // the type recorded in tmpTable (find(), so nothing is inserted). The
    // "void" placeholder left by getNextTemp() is treated as i1 because
    // comparison results were previously left with it.
    const auto typeOf = [this](const std::string& v) -> std::string {
        if (v.empty() || (v[0] != '%' && v[0] != '@')) {
            const bool hex = v.find("0x") != std::string::npos || v.find("0X") != std::string::npos;
            const bool looksFloat = v.find_first_of(".pP") != std::string::npos ||
                                    (!hex && v.find_first_of("eE") != std::string::npos);
            return looksFloat ? "float" : "i32";
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return "i32";
        if (entry->second == "float") return "float";
        if (entry->second == "i1" || entry->second == "void") return "i1";
        return "i32";
    };
    // Widens a boolean register to i32 so it can be compared as an integer.
    const auto asInt = [this, &typeOf](const std::string& v) -> std::string {
        if (typeOf(v) != "i1") return v;
        const std::string wide = getNextTemp();
        irStream << " " << wide << " = zext i1 " << v << " to i32\n";
        tmpTable[wide] = "i32";
        return wide;
    };

    auto relExps = ctx->relExp();
    std::string left = std::any_cast<std::string>(relExps[0]->accept(this));
    for (std::size_t i = 1; i < relExps.size(); ++i) {
        std::string right = std::any_cast<std::string>(relExps[i]->accept(this));
        // children alternate operand, operator, operand, ...
        const std::string op = ctx->children[2 * i - 1]->getText();

        // The type used to come from tmpTable[left] alone, which is empty for
        // a literal left operand (e.g. `0 == a`).
        const bool isFloat = typeOf(left) == "float" || typeOf(right) == "float";
        if (!isFloat) {
            left = asInt(left);
            right = asInt(right);
        }

        std::string opcode;
        if (op == "==") {
            opcode = isFloat ? "fcmp oeq" : "icmp eq";
        } else if (op == "!=") {
            // `une` so that, as in C, NaN != x is true.
            opcode = isFloat ? "fcmp une" : "icmp ne";
        } else {
            throw std::runtime_error("Unsupported equality operator: " + op);
        }

        const std::string temp = getNextTemp();
        irStream << " " << temp << " = " << opcode << " " << (isFloat ? "float" : "i32")
                 << " " << left << ", " << right << "\n";
        tmpTable[temp] = "i1";
        left = temp;
    }
    return left;
}
// Evaluates a chain of `&&` with short-circuit semantics.
// For every additional operand the code emitted is:
//     br label %land.lhs.N
//   land.lhs.N:                      ; known predecessor for the phi
//     br i1 <left>, label %land.rhs.N, label %land.end.N
//   land.rhs.N:
//     ...right operand...            ; may itself open further blocks
//     br label %land.rhs.end.N
//   land.rhs.end.N:                  ; known predecessor for the phi
//     br label %land.end.N
//   land.end.N:
//     %t = phi i1 [ false, %land.lhs.N ], [ <right>, %land.rhs.end.N ]
// The previous code branched unconditionally and always evaluated the right
// operand. A single operand is returned unchanged.
std::any LLVMIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
    // Coerces an operand to i1. Literals are folded to true/false, registers
    // recorded as i32 or float are compared against zero, anything else
    // (comparison and logical results) is passed through unchanged.
    const auto toBool = [this](const std::string& v) -> std::string {
        if (v.empty()) return v;
        if (v[0] != '%' && v[0] != '@') {
            try {
                return std::stod(v) != 0.0 ? "true" : "false";
            } catch (const std::exception&) {
                return v;
            }
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return v;
        const std::string type = entry->second;  // copy: getNextTemp() modifies tmpTable
        if (type != "i32" && type != "float") return v;
        const std::string flag = getNextTemp();
        if (type == "float") {
            irStream << " " << flag << " = fcmp une float " << v << ", 0.0\n";
        } else {
            irStream << " " << flag << " = icmp ne i32 " << v << ", 0\n";
        }
        tmpTable[flag] = "i1";
        return flag;
    };

    auto eqExps = ctx->eqExp();
    std::string left = std::any_cast<std::string>(eqExps[0]->accept(this));
    for (std::size_t i = 1; i < eqExps.size(); ++i) {
        const std::string id = std::to_string(tempCounter++);
        const std::string lhsLabel = "land.lhs." + id;
        const std::string rhsLabel = "land.rhs." + id;
        const std::string rhsEndLabel = "land.rhs.end." + id;
        const std::string endLabel = "land.end." + id;

        const std::string leftBool = toBool(left);
        irStream << " br label %" << lhsLabel << "\n";
        irStream << lhsLabel << ":\n";
        // A false left operand decides the result: skip the right operand.
        irStream << " br i1 " << leftBool << ", label %" << rhsLabel << ", label %" << endLabel << "\n";

        irStream << rhsLabel << ":\n";
        const std::string rightBool = toBool(std::any_cast<std::string>(eqExps[i]->accept(this)));
        irStream << " br label %" << rhsEndLabel << "\n";
        irStream << rhsEndLabel << ":\n";
        irStream << " br label %" << endLabel << "\n";

        irStream << endLabel << ":\n";
        const std::string result = getNextTemp();
        irStream << " " << result << " = phi i1 [ false, %" << lhsLabel << " ], [ "
                 << rightBool << ", %" << rhsEndLabel << " ]\n";
        tmpTable[result] = "i1";
        left = result;
    }
    return left;
}
// Evaluates a chain of `||` with short-circuit semantics.
// For every additional operand the code emitted is:
//     br label %lor.lhs.N
//   lor.lhs.N:                       ; known predecessor for the phi
//     br i1 <left>, label %lor.end.N, label %lor.rhs.N
//   lor.rhs.N:
//     ...right operand...            ; may itself open further blocks
//     br label %lor.rhs.end.N
//   lor.rhs.end.N:                   ; known predecessor for the phi
//     br label %lor.end.N
//   lor.end.N:
//     %t = phi i1 [ true, %lor.lhs.N ], [ <right>, %lor.rhs.end.N ]
// The previous code branched unconditionally and always evaluated the right
// operand. A single operand is returned unchanged.
std::any LLVMIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
    // Coerces an operand to i1. Literals are folded to true/false, registers
    // recorded as i32 or float are compared against zero, anything else
    // (comparison and logical results) is passed through unchanged.
    const auto toBool = [this](const std::string& v) -> std::string {
        if (v.empty()) return v;
        if (v[0] != '%' && v[0] != '@') {
            try {
                return std::stod(v) != 0.0 ? "true" : "false";
            } catch (const std::exception&) {
                return v;
            }
        }
        const auto entry = tmpTable.find(v);
        if (entry == tmpTable.end()) return v;
        const std::string type = entry->second;  // copy: getNextTemp() modifies tmpTable
        if (type != "i32" && type != "float") return v;
        const std::string flag = getNextTemp();
        if (type == "float") {
            irStream << " " << flag << " = fcmp une float " << v << ", 0.0\n";
        } else {
            irStream << " " << flag << " = icmp ne i32 " << v << ", 0\n";
        }
        tmpTable[flag] = "i1";
        return flag;
    };

    auto lAndExps = ctx->lAndExp();
    std::string left = std::any_cast<std::string>(lAndExps[0]->accept(this));
    for (std::size_t i = 1; i < lAndExps.size(); ++i) {
        const std::string id = std::to_string(tempCounter++);
        const std::string lhsLabel = "lor.lhs." + id;
        const std::string rhsLabel = "lor.rhs." + id;
        const std::string rhsEndLabel = "lor.rhs.end." + id;
        const std::string endLabel = "lor.end." + id;

        const std::string leftBool = toBool(left);
        irStream << " br label %" << lhsLabel << "\n";
        irStream << lhsLabel << ":\n";
        // A true left operand decides the result: skip the right operand.
        irStream << " br i1 " << leftBool << ", label %" << endLabel << ", label %" << rhsLabel << "\n";

        irStream << rhsLabel << ":\n";
        const std::string rightBool = toBool(std::any_cast<std::string>(lAndExps[i]->accept(this)));
        irStream << " br label %" << rhsEndLabel << "\n";
        irStream << rhsEndLabel << ":\n";
        irStream << " br label %" << endLabel << "\n";

        irStream << endLabel << ":\n";
        const std::string result = getNextTemp();
        irStream << " " << result << " = phi i1 [ true, %" << lhsLabel << " ], [ "
                 << rightBool << ", %" << rhsEndLabel << " ]\n";
        tmpTable[result] = "i1";
        left = result;
    }
    return left;
}