[lab2] testfile01 finished

This commit is contained in:
Lixuanwang
2025-03-10 21:43:20 +08:00
parent b0b03ff55b
commit 3d60a94894
3 changed files with 93 additions and 64 deletions

View File

@@ -1,3 +1,4 @@
// SysYIRGenerator.cpp
#include "SysYIRGenerator.h" #include "SysYIRGenerator.h"
#include <iomanip> #include <iomanip>
@@ -14,7 +15,7 @@ std::string SysYIRGenerator::getLLVMType(const std::string& type) {
if (type == "int") return "i32"; if (type == "int") return "i32";
if (type == "float") return "float"; if (type == "float") return "float";
if (type.find("[]") != std::string::npos) if (type.find("[]") != std::string::npos)
return getLLVMType(type.substr(0, type.size() - 2)) + "*"; return getLLVMType(type.substr(0, type.size()-2)) + "*";
return "i32"; return "i32";
} }
@@ -29,26 +30,40 @@ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext* ctx) {
} }
std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
// 常量声明暂不处理LLVM IR 中常量通常内联)
return nullptr; return nullptr;
} }
std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) { std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext* ctx) {
std::string type = ctx->bType()->getText();
for (auto varDef : ctx->varDef()) { for (auto varDef : ctx->varDef()) {
symbolTable[varDef->Ident()->getText()].second = type;
varDef->accept(this); varDef->accept(this);
} }
return nullptr; return nullptr;
} }
std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) {
std::string varName = ctx->Ident()->getText();
std::string type = symbolTable[varName].second;
std::string llvmType = getLLVMType(type);
std::string allocaName = getNextTemp();
symbolTable[varName] = {allocaName, llvmType};
irStream << " " << allocaName << " = alloca " << llvmType << ", align " << (type == "float" ? "4" : "4") << "\n";
if (ctx->ASSIGN()) {
std::string value = std::any_cast<std::string>(ctx->initVal()->accept(this));
irStream << " store " << llvmType << " " << value << ", " << llvmType << "* " << allocaName << ", align " << (type == "float" ? "4" : "4") << "\n";
}
return nullptr;
}
std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
currentFunction = ctx->Ident()->getText(); currentFunction = ctx->Ident()->getText();
currentReturnType = getLLVMType(ctx->funcType()->getText());
symbolTable.clear(); symbolTable.clear();
hasReturn = false;
// 函数头 irStream << "define " << currentReturnType << " @" << currentFunction << "(";
std::string returnType = getLLVMType(ctx->funcType()->getText());
irStream << "define " << returnType << " @" << currentFunction << "(";
// 参数
auto paramsCtx = ctx->funcFParams(); auto paramsCtx = ctx->funcFParams();
if (paramsCtx) { if (paramsCtx) {
auto params = paramsCtx->funcFParam(); auto params = paramsCtx->funcFParam();
@@ -58,26 +73,22 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) {
std::string paramName = "%" + std::to_string(i); std::string paramName = "%" + std::to_string(i);
std::string paramType = getLLVMType(param->bType()->getText()); std::string paramType = getLLVMType(param->bType()->getText());
irStream << paramType << " " << paramName; irStream << paramType << " " << paramName;
// 分配参数
std::string allocaName = getNextTemp(); std::string allocaName = getNextTemp();
symbolTable[param->Ident()->getText()] = allocaName; symbolTable[param->Ident()->getText()] = {allocaName, paramType};
irStream << "\n " << allocaName << " = alloca " << paramType; irStream << "\n " << allocaName << " = alloca " << paramType << ", align " << (paramType == "float" ? "4" : "4");
irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName; irStream << "\n store " << paramType << " %" << i << ", " << paramType << "* " << allocaName << ", align " << (paramType == "float" ? "4" : "4");
} }
} }
irStream << ") {\nentry:\n"; irStream << ") {\nentry:\n";
// 函数体
ctx->blockStmt()->accept(this); ctx->blockStmt()->accept(this);
if (!hasReturn) {
// 默认返回值 if (currentReturnType == "void") {
if (returnType == "void") { irStream << " ret void\n";
irStream << " ret void\n"; } else {
} else { irStream << " ret " << currentReturnType << " 0\n";
irStream << " ret " << returnType << " 0\n"; }
} }
irStream << "}\n\n"; irStream << "}\n";
return nullptr; return nullptr;
} }
@@ -89,16 +100,17 @@ std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
} }
std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
if (ctx->lValue() && ctx->exp()) { if (ctx->lValue() && ctx->ASSIGN()) {
// 赋值语句 std::string lhsAlloca = std::any_cast<std::string>(ctx->lValue()->accept(this));
std::string lhs = std::any_cast<std::string>(ctx->lValue()->accept(this)); std::string varName = ctx->lValue()->Ident()->getText();
std::string lhsType = symbolTable[varName].second;
std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this)); std::string rhs = std::any_cast<std::string>(ctx->exp()->accept(this));
irStream << " store " << getLLVMType("") << " " << rhs << ", " << getLLVMType("") << "* " << lhs << "\n"; irStream << " store " << lhsType << " " << rhs << ", " << lhsType << "* " << lhsAlloca << ", align " << (lhsType == "float" ? "4" : "4") << "\n";
} else if (ctx->RETURN()) { } else if (ctx->RETURN()) {
// 返回语句 hasReturn = true;
if (ctx->exp()) { if (ctx->exp()) {
std::string value = std::any_cast<std::string>(ctx->exp()->accept(this)); std::string value = std::any_cast<std::string>(ctx->exp()->accept(this));
irStream << " ret " << getLLVMType("") << " " << value << "\n"; irStream << " ret " << currentReturnType << " " << value << "\n";
} else { } else {
irStream << " ret void\n"; irStream << " ret void\n";
} }
@@ -108,12 +120,22 @@ std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) {
std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext* ctx) {
std::string varName = ctx->Ident()->getText(); std::string varName = ctx->Ident()->getText();
if (symbolTable.find(varName) == symbolTable.end()) { return symbolTable[varName].first;
std::string allocaName = getNextTemp(); }
symbolTable[varName] = allocaName;
irStream << " " << allocaName << " = alloca " << getLLVMType("") << "\n"; std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (ctx->lValue()) {
std::string allocaPtr = std::any_cast<std::string>(ctx->lValue()->accept(this));
std::string varName = ctx->lValue()->Ident()->getText();
std::string type = symbolTable[varName].second;
std::string temp = getNextTemp();
irStream << " " << temp << " = load " << type << ", " << type << "* " << allocaPtr << ", align " << (type == "float" ? "4" : "4") << "\n";
return temp;
} else if (ctx->exp()) {
return ctx->exp()->accept(this);
} else {
return ctx->number()->accept(this);
} }
return symbolTable[varName];
} }
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext* ctx) { std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext* ctx) {
@@ -130,10 +152,11 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this)); std::string operand = std::any_cast<std::string>(ctx->unaryExp()->accept(this));
std::string op = ctx->unaryOp()->getText(); std::string op = ctx->unaryOp()->getText();
std::string temp = getNextTemp(); std::string temp = getNextTemp();
std::string type = operand.substr(0, operand.find(' '));
if (op == "-") { if (op == "-") {
irStream << " " << temp << " = sub " << getLLVMType("") << " 0, " << operand << "\n"; irStream << " " << temp << " = sub " << type << " 0, " << operand << "\n";
} else if (op == "!") { } else if (op == "!") {
irStream << " " << temp << " = xor " << getLLVMType("") << " " << operand << ", 1\n"; irStream << " " << temp << " = xor " << type << " " << operand << ", 1\n";
} }
return temp; return temp;
} }
@@ -145,14 +168,15 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext* ctx) {
std::string left = std::any_cast<std::string>(unaryExps[0]->accept(this)); std::string left = std::any_cast<std::string>(unaryExps[0]->accept(this));
for (size_t i = 1; i < unaryExps.size(); ++i) { for (size_t i = 1; i < unaryExps.size(); ++i) {
std::string right = std::any_cast<std::string>(unaryExps[i]->accept(this)); std::string right = std::any_cast<std::string>(unaryExps[i]->accept(this));
std::string op = ctx->children[2 * i - 1]->getText(); std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp(); std::string temp = getNextTemp();
std::string type = left.substr(0, left.find(' '));
if (op == "*") { if (op == "*") {
irStream << " " << temp << " = mul " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = mul " << type << " " << left << ", " << right << "\n";
} else if (op == "/") { } else if (op == "/") {
irStream << " " << temp << " = sdiv " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = sdiv " << type << " " << left << ", " << right << "\n";
} else if (op == "%") { } else if (op == "%") {
irStream << " " << temp << " = srem " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = srem " << type << " " << left << ", " << right << "\n";
} }
left = temp; left = temp;
} }
@@ -164,12 +188,13 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext* ctx) {
std::string left = std::any_cast<std::string>(mulExps[0]->accept(this)); std::string left = std::any_cast<std::string>(mulExps[0]->accept(this));
for (size_t i = 1; i < mulExps.size(); ++i) { for (size_t i = 1; i < mulExps.size(); ++i) {
std::string right = std::any_cast<std::string>(mulExps[i]->accept(this)); std::string right = std::any_cast<std::string>(mulExps[i]->accept(this));
std::string op = ctx->children[2 * i - 1]->getText(); std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp(); std::string temp = getNextTemp();
std::string type = left.substr(0, left.find(' '));
if (op == "+") { if (op == "+") {
irStream << " " << temp << " = add " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = add " << type << " " << left << ", " << right << "\n";
} else if (op == "-") { } else if (op == "-") {
irStream << " " << temp << " = sub " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = sub " << type << " " << left << ", " << right << "\n";
} }
left = temp; left = temp;
} }
@@ -181,16 +206,17 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) {
std::string left = std::any_cast<std::string>(addExps[0]->accept(this)); std::string left = std::any_cast<std::string>(addExps[0]->accept(this));
for (size_t i = 1; i < addExps.size(); ++i) { for (size_t i = 1; i < addExps.size(); ++i) {
std::string right = std::any_cast<std::string>(addExps[i]->accept(this)); std::string right = std::any_cast<std::string>(addExps[i]->accept(this));
std::string op = ctx->children[2 * i - 1]->getText(); std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp(); std::string temp = getNextTemp();
std::string type = left.substr(0, left.find(' '));
if (op == "<") { if (op == "<") {
irStream << " " << temp << " = icmp slt " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n";
} else if (op == ">") { } else if (op == ">") {
irStream << " " << temp << " = icmp sgt " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp sgt " << type << " " << left << ", " << right << "\n";
} else if (op == "<=") { } else if (op == "<=") {
irStream << " " << temp << " = icmp sle " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp sle " << type << " " << left << ", " << right << "\n";
} else if (op == ">=") { } else if (op == ">=") {
irStream << " " << temp << " = icmp sge " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp sge " << type << " " << left << ", " << right << "\n";
} }
left = temp; left = temp;
} }
@@ -202,12 +228,13 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) {
std::string left = std::any_cast<std::string>(relExps[0]->accept(this)); std::string left = std::any_cast<std::string>(relExps[0]->accept(this));
for (size_t i = 1; i < relExps.size(); ++i) { for (size_t i = 1; i < relExps.size(); ++i) {
std::string right = std::any_cast<std::string>(relExps[i]->accept(this)); std::string right = std::any_cast<std::string>(relExps[i]->accept(this));
std::string op = ctx->children[2 * i - 1]->getText(); std::string op = ctx->children[2*i-1]->getText();
std::string temp = getNextTemp(); std::string temp = getNextTemp();
std::string type = left.substr(0, left.find(' '));
if (op == "==") { if (op == "==") {
irStream << " " << temp << " = icmp eq " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n";
} else if (op == "!=") { } else if (op == "!=") {
irStream << " " << temp << " = icmp ne " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = icmp ne " << type << " " << left << ", " << right << "\n";
} }
left = temp; left = temp;
} }
@@ -220,7 +247,7 @@ std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext* ctx) {
for (size_t i = 1; i < eqExps.size(); ++i) { for (size_t i = 1; i < eqExps.size(); ++i) {
std::string right = std::any_cast<std::string>(eqExps[i]->accept(this)); std::string right = std::any_cast<std::string>(eqExps[i]->accept(this));
std::string temp = getNextTemp(); std::string temp = getNextTemp();
irStream << " " << temp << " = and " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = and i1 " << left << ", " << right << "\n";
left = temp; left = temp;
} }
return left; return left;
@@ -232,7 +259,7 @@ std::any SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext* ctx) {
for (size_t i = 1; i < lAndExps.size(); ++i) { for (size_t i = 1; i < lAndExps.size(); ++i) {
std::string right = std::any_cast<std::string>(lAndExps[i]->accept(this)); std::string right = std::any_cast<std::string>(lAndExps[i]->accept(this));
std::string temp = getNextTemp(); std::string temp = getNextTemp();
irStream << " " << temp << " = or " << getLLVMType("") << " " << left << ", " << right << "\n"; irStream << " " << temp << " = or i1 " << left << ", " << right << "\n";
left = temp; left = temp;
} }
return left; return left;

View File

@@ -1,3 +1,4 @@
// SysYIRGenerator.h
#pragma once #pragma once
#include "SysYBaseVisitor.h" #include "SysYBaseVisitor.h"
#include "SysYParser.h" #include "SysYParser.h"
@@ -7,29 +8,30 @@
class SysYIRGenerator : public SysYBaseVisitor { class SysYIRGenerator : public SysYBaseVisitor {
public: public:
std::string generateIR(SysYParser::CompUnitContext* unit); // 公共接口,用于生成 IR std::string generateIR(SysYParser::CompUnitContext* unit);
std::string getIR() const { return irStream.str(); } // 获取生成的 IR std::string getIR() const { return irStream.str(); }
private: private:
std::stringstream irStream; std::stringstream irStream;
int tempCounter = 0; int tempCounter = 0;
std::map<std::string, std::string> symbolTable; // 符号表 std::map<std::string, std::pair<std::string, std::string>> symbolTable; // {varName: {allocaName, type}}
std::vector<std::string> globalVars; // 全局变量 std::vector<std::string> globalVars;
std::string currentFunction; // 当前函数名 std::string currentFunction;
std::vector<std::string> breakStack; // break 目标标签栈 std::string currentReturnType;
std::vector<std::string> continueStack; // continue 目标标签栈 std::vector<std::string> breakStack;
std::vector<std::string> continueStack;
bool hasReturn = false;
std::string getNextTemp(); // 获取下一个临时变量名 std::string getNextTemp();
std::string getLLVMType(const std::string& type); // 获取 LLVM 类型 std::string getLLVMType(const std::string& type);
// 访问方法
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override; std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override; std::any visitConstDecl(SysYParser::ConstDeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override; std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override; std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override; std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override; std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitLValue(SysYParser::LValueContext* ctx) override; std::any visitLValue(SysYParser::LValueContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override; std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override; std::any visitMulExp(SysYParser::MulExpContext* ctx) override;

View File

@@ -1,7 +1,7 @@
//test file for backend lab //test file for backend lab
int main() { int main() {
const int a = 1; int a;
const int b = 2; const int b = 2;
int c; int c;