diff --git a/src/LLVMIRGenerator.cpp b/src/LLVMIRGenerator.cpp index b349207..b6494d9 100644 --- a/src/LLVMIRGenerator.cpp +++ b/src/LLVMIRGenerator.cpp @@ -217,88 +217,172 @@ std::any LLVMIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext* ctx) { } return nullptr; } - -std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { - if (ctx->lValue() && ctx->ASSIGN()) { - std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); - std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; - std::string rhs = std::any_cast(ctx->exp()->accept(this)); - if (lhsType == "float") { - try { - double floatValue = std::stod(rhs); - uint64_t hexValue = reinterpret_cast(floatValue); - std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); - rhs = ss.str(); - } catch (...) { - throw std::runtime_error("Invalid float literal: " + rhs); - } +std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) +{ + std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); + std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; + std::string rhs = std::any_cast(ctx->exp()->accept(this)); + + if (lhsType == "float") { + try { + double floatValue = std::stod(rhs); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); + rhs = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + rhs); } - irStream << " store " << lhsType << " " << rhs << ", " << lhsType - << "* " << lhsAlloca << ", align 4\n"; - } else if (ctx->RETURN()) { - hasReturn = true; - if (ctx->exp()) { - std::string value = std::any_cast(ctx->exp()->accept(this)); - irStream << " ret " << currentReturnType << " " << value << "\n"; - } else { - irStream << " ret void\n"; - } - } else if (ctx->IF()) { - std::string cond = std::any_cast(ctx->cond()->accept(this)); - std::string trueLabel = "if.then." + std::to_string(tempCounter); - std::string falseLabel = "if.else." + std::to_string(tempCounter); - std::string mergeLabel = "if.end." + std::to_string(tempCounter++); - - irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; - - irStream << trueLabel << ":\n"; - ctx->stmt(0)->accept(this); - irStream << " br label %" << mergeLabel << "\n"; - - irStream << falseLabel << ":\n"; - if (ctx->ELSE()) { - ctx->stmt(1)->accept(this); - } - irStream << " br label %" << mergeLabel << "\n"; - - irStream << mergeLabel << ":\n"; - } else if (ctx->WHILE()) { - std::string loop_cond = "while.cond." + std::to_string(tempCounter); - std::string loop_body = "while.body." + std::to_string(tempCounter); - std::string loop_end = "while.end." + std::to_string(tempCounter++); - - loopStack.push({loop_end, loop_cond}); - irStream << " br label %" << loop_cond << "\n"; - irStream << loop_cond << ":\n"; - - std::string cond = std::any_cast(ctx->cond()->accept(this)); - irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; - irStream << loop_body << ":\n"; - ctx->stmt(0)->accept(this); - irStream << " br label %" << loop_cond << "\n"; - irStream << loop_end << ":\n"; - - loopStack.pop(); - - } else if (ctx->BREAK()) { - if (loopStack.empty()) { - throw std::runtime_error("Break statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().breakLabel << "\n"; - } else if (ctx->CONTINUE()) { - if (loopStack.empty()) { - throw std::runtime_error("Continue statement outside of a loop."); - } - irStream << " br label %" << loopStack.top().continueLabel << "\n"; - } else if (ctx->blockStmt()) { - ctx->blockStmt()->accept(this); - } else if (ctx->exp()) { - ctx->exp()->accept(this); } + + irStream << " store " << lhsType << " " << rhs << ", " << lhsType + << "* " << lhsAlloca << ", align 4\n"; return nullptr; } +std::any visitIfStmt(SysYParser::IfStmtContext *ctx) +{ + std::string cond = std::any_cast(ctx->cond()->accept(this)); + std::string trueLabel = "if.then." + std::to_string(tempCounter); + std::string falseLabel = "if.else." + std::to_string(tempCounter); + std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + + irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; + + irStream << trueLabel << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << mergeLabel << "\n"; + + irStream << falseLabel << ":\n"; + if (ctx->ELSE()) { + ctx->stmt(1)->accept(this); + } + irStream << " br label %" << mergeLabel << "\n"; + + irStream << mergeLabel << ":\n"; + return nullptr; +} + +std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) +{ + std::string loop_cond = "while.cond." + std::to_string(tempCounter); + std::string loop_body = "while.body." + std::to_string(tempCounter); + std::string loop_end = "while.end." + std::to_string(tempCounter++); + + loopStack.push({loop_end, loop_cond}); + irStream << " br label %" << loop_cond << "\n"; + irStream << loop_cond << ":\n"; + + std::string cond = std::any_cast(ctx->cond()->accept(this)); + irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; + irStream << loop_body << ":\n"; + ctx->stmt(0)->accept(this); + irStream << " br label %" << loop_cond << "\n"; + irStream << loop_end << ":\n"; + + loopStack.pop(); + return nullptr; +} + +std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) +{ + if (loopStack.empty()) { + throw std::runtime_error("Break statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().breakLabel << "\n"; + return nullptr; +} + +std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) +{ + if (loopStack.empty()) { + throw std::runtime_error("Continue statement outside of a loop."); + } + irStream << " br label %" << loopStack.top().continueLabel << "\n"; + return nullptr; +} + +// std::any LLVMIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { +// if (ctx->lValue() && ctx->ASSIGN()) { +// std::string lhsAlloca = std::any_cast(ctx->lValue()->accept(this)); +// std::string lhsType = symbolTable[ctx->lValue()->Ident()->getText()].second; +// std::string rhs = std::any_cast(ctx->exp()->accept(this)); +// if (lhsType == "float") { +// try { +// double floatValue = std::stod(rhs); +// uint64_t hexValue = reinterpret_cast(floatValue); +// std::stringstream ss; +// ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); +// rhs = ss.str(); +// } catch (...) { +// throw std::runtime_error("Invalid float literal: " + rhs); +// } +// } +// irStream << " store " << lhsType << " " << rhs << ", " << lhsType +// << "* " << lhsAlloca << ", align 4\n"; +// } else if (ctx->RETURN()) { +// hasReturn = true; +// if (ctx->exp()) { +// std::string value = std::any_cast(ctx->exp()->accept(this)); +// irStream << " ret " << currentReturnType << " " << value << "\n"; +// } else { +// irStream << " ret void\n"; +// } +// } else if (ctx->IF()) { +// std::string cond = std::any_cast(ctx->cond()->accept(this)); +// std::string trueLabel = "if.then." + std::to_string(tempCounter); +// std::string falseLabel = "if.else." + std::to_string(tempCounter); +// std::string mergeLabel = "if.end." + std::to_string(tempCounter++); + +// irStream << " br i1 " << cond << ", label %" << trueLabel << ", label %" << falseLabel << "\n"; + +// irStream << trueLabel << ":\n"; +// ctx->stmt(0)->accept(this); +// irStream << " br label %" << mergeLabel << "\n"; + +// irStream << falseLabel << ":\n"; +// if (ctx->ELSE()) { +// ctx->stmt(1)->accept(this); +// } +// irStream << " br label %" << mergeLabel << "\n"; + +// irStream << mergeLabel << ":\n"; +// } else if (ctx->WHILE()) { +// std::string loop_cond = "while.cond." + std::to_string(tempCounter); +// std::string loop_body = "while.body." + std::to_string(tempCounter); +// std::string loop_end = "while.end." + std::to_string(tempCounter++); + +// loopStack.push({loop_end, loop_cond}); +// irStream << " br label %" << loop_cond << "\n"; +// irStream << loop_cond << ":\n"; + +// std::string cond = std::any_cast(ctx->cond()->accept(this)); +// irStream << " br i1 " << cond << ", label %" << loop_body << ", label %" << loop_end << "\n"; +// irStream << loop_body << ":\n"; +// ctx->stmt(0)->accept(this); +// irStream << " br label %" << loop_cond << "\n"; +// irStream << loop_end << ":\n"; + +// loopStack.pop(); + +// } else if (ctx->BREAK()) { +// if (loopStack.empty()) { +// throw std::runtime_error("Break statement outside of a loop."); +// } +// irStream << " br label %" << loopStack.top().breakLabel << "\n"; +// } else if (ctx->CONTINUE()) { +// if (loopStack.empty()) { +// throw std::runtime_error("Continue statement outside of a loop."); +// } +// irStream << " br label %" << loopStack.top().continueLabel << "\n"; +// } else if (ctx->blockStmt()) { +// ctx->blockStmt()->accept(this); +// } else if (ctx->exp()) { +// ctx->exp()->accept(this); +// } +// return nullptr; +// } + std::any LLVMIRGenerator::visitLValue(SysYParser::LValueContext* ctx) { std::string varName = ctx->Ident()->getText(); return symbolTable[varName].first; diff --git a/src/LLVMIRGenerator.h b/src/LLVMIRGenerator.h index 20650ab..f5b9430 100644 --- a/src/LLVMIRGenerator.h +++ b/src/LLVMIRGenerator.h @@ -52,4 +52,10 @@ private: std::any visitEqExp(SysYParser::EqExpContext* ctx); std::any visitLAndExp(SysYParser::LAndExpContext* ctx); std::any visitLOrExp(SysYParser::LOrExpContext* ctx); + std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override; + std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override; + std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override; + std::any visitBreakStmt(SysYParser::BreakStmtContext *ctx) override; + std::any visitContinueStmt(SysYParser::ContinueStmtContext *ctx) override; + std::any visitReturnStmt(SysYParser::ReturnStmtContext *ctx) override; }; \ No newline at end of file