diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index f221205..967ee0d 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -1,7 +1,7 @@ // SysYIRGenerator.cpp #include "SysYIRGenerator.h" #include - +// #TODO浮点数精度还是有问题 std::string SysYIRGenerator::generateIR(SysYParser::CompUnitContext* unit) { visitCompUnit(unit); return irStream.str(); @@ -42,6 +42,17 @@ std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext* ctx) { if (constDef->constInitVal()) { std::string value = std::any_cast(constDef->constInitVal()->accept(this)); + if (llvmType == "float") { + try { + double floatValue = std::stod(value); + uint64_t hexValue = reinterpret_cast(floatValue); + std::stringstream ss; + ss << "0x" << std::hex << std::uppercase << hexValue; + value = ss.str(); + } catch (...) { + throw std::runtime_error("Invalid float literal: " + value); + } + } irStream << " store " << llvmType << " " << value << ", " << llvmType << "* " << allocaName << ", align 4\n"; } @@ -76,7 +87,7 @@ std::any SysYIRGenerator::visitVarDef(SysYParser::VarDefContext* ctx) { double floatValue = std::stod(value); uint64_t hexValue = reinterpret_cast(floatValue); std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); value = ss.str(); } catch (...) { throw std::runtime_error("Invalid float literal: " + value); @@ -94,6 +105,8 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { currentFunction = ctx->Ident()->getText(); currentReturnType = getLLVMType(ctx->funcType()->getText()); symbolTable.clear(); + tmpTable.clear(); + tempCounter = 0; hasReturn = false; irStream << "define dso_local " << currentReturnType << " @" << currentFunction << "("; @@ -110,7 +123,22 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext* ctx) { } tempCounter++; irStream << ") #0 {\n"; - + + if (ctx->funcFParams()) { + auto params = ctx->funcFParams()->funcFParam(); + for (size_t i = 0; i < params.size(); ++i) { + std::string varName = params[i]->Ident()->getText(); + std::string type = params[i]->bType()->getText(); + std::string llvmType = getLLVMType(type); + std::string allocaName = getNextTemp(); + tmpTable[allocaName] = llvmType; + + irStream << " " << allocaName << " = alloca " << llvmType << ", align 4\n"; + irStream << " store " << llvmType << " " << symbolTable[varName].first << ", " << llvmType + << "* " << allocaName << ", align 4\n"; + + } + } ctx->blockStmt()->accept(this); if (!hasReturn) { @@ -141,13 +169,13 @@ std::any SysYIRGenerator::visitStmt(SysYParser::StmtContext* ctx) { double floatValue = std::stod(rhs); uint64_t hexValue = reinterpret_cast(floatValue); std::stringstream ss; - ss << "0x" << std::hex << std::uppercase << hexValue; + ss << "0x" << std::hex << std::uppercase << (hexValue & (0xffffffffUL << 32)); rhs = ss.str(); } catch (...) { throw std::runtime_error("Invalid float literal: " + rhs); } } - irStream << " store " << lhsType << " " << rhs << ", " << lhsType + irStream << " store1 " << lhsType << " " << rhs << ", " << lhsType << "* " << lhsAlloca << ", align 4\n"; } else if (ctx->RETURN()) { hasReturn = true; @@ -289,7 +317,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext* ctx) { std::string right = std::any_cast(addExps[i]->accept(this)); std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); - std::string type = left.substr(0, left.find(' ')); + std::string type = tmpTable[left]; if (op == "<") { irStream << " " << temp << " = icmp slt " << type << " " << left << ", " << right << "\n"; } else if (op == ">") { @@ -311,7 +339,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext* ctx) { std::string right = std::any_cast(relExps[i]->accept(this)); std::string op = ctx->children[2*i-1]->getText(); std::string temp = getNextTemp(); - std::string type = left.substr(0, left.find(' ')); + std::string type = tmpTable[left]; if (op == "==") { irStream << " " << temp << " = icmp eq " << type << " " << left << ", " << right << "\n"; } else if (op == "!=") {