// SysYIRGenerator.cpp // TODO:类型转换及其检查 // TODO:sysy库函数处理 // TODO:数组处理 // TODO:对while、continue、break的测试 #include "IR.h" #include #include #include #include #include #include using namespace std; #include "SysYIRGenerator.h" namespace sysy { /* * @brief: visit compUnit * @details: * compUnit: (globalDecl | funcDef)+; */ std::any SysYIRGenerator::visitCompUnit(SysYParser::CompUnitContext *ctx) { // create the IR module auto pModule = new Module(); assert(pModule); module.reset(pModule); // SymbolTable::ModuleScope scope(symbols_table); Utils::initExternalFunction(pModule, &builder); pModule->enterNewScope(); visitChildren(ctx); pModule->leaveScope(); return pModule; } std::any SysYIRGenerator::visitGlobalConstDecl(SysYParser::GlobalConstDeclContext *ctx){ auto constDecl = ctx->constDecl(); Type* type = std::any_cast(visitBType(constDecl->bType())); for (const auto &constDef : constDecl->constDef()) { std::vector dims = {}; std::string name = constDef->Ident()->getText(); auto constExps = constDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); ValueCounter values; Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; // 创建全局常量变量,并更新符号表 module->createConstVar(name, Type::getPointerType(type), values, dims); } return std::any(); } std::any SysYIRGenerator::visitGlobalVarDecl(SysYParser::GlobalVarDeclContext *ctx) { auto varDecl = ctx->varDecl(); Type* type = std::any_cast(visitBType(varDecl->bType())); for (const auto &varDef : varDecl->varDef()) { std::vector dims = {}; std::string name = varDef->Ident()->getText(); auto constExps = varDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } ValueCounter values = {}; if (varDef->initVal() != nullptr) { ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; } // 创建全局变量,并更新符号表 module->createGlobalValue(name, Type::getPointerType(type), dims, values); } return std::any(); } std::any SysYIRGenerator::visitConstDecl(SysYParser::ConstDeclContext *ctx){ Type* type = std::any_cast(visitBType(ctx->bType())); for (const auto constDef : ctx->constDef()) { std::vector dims = {}; std::string name = constDef->Ident()->getText(); auto constExps = constDef->constExp(); if (!constExps.empty()) { for (const auto constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } ArrayValueTree* root = std::any_cast(constDef->constInitVal()->accept(this)); ValueCounter values; Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; module->createConstVar(name, Type::getPointerType(type), values, dims); } return 0; } std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) { Type* type = std::any_cast(visitBType(ctx->bType())); for (const auto varDef : ctx->varDef()) { std::vector dims = {}; std::string name = varDef->Ident()->getText(); auto constExps = varDef->constExp(); if (!constExps.empty()) { for (const auto &constExp : constExps) { dims.push_back(std::any_cast(visitConstExp(constExp))); } } AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(type), dims, name); if (varDef->initVal() != nullptr) { ValueCounter values; // 这里的varDef->initVal()可能是ScalarInitValue或ArrayInitValue ArrayValueTree* root = std::any_cast(varDef->initVal()->accept(this)); Utils::tree2Array(type, root, dims, dims.size(), values, &builder); delete root; if (dims.empty()) { builder.createStoreInst(values.getValue(0), alloca); } else { // 对于多维数组,使用memset初始化 // 计算每个维度的大小 // 这里的values.getNumbers()返回的是每个维度的大小 // 这里的values.getValues()返回的是每个维度对应的值 // 例如:对于一个二维数组,values.getNumbers()可能是[3, 4],表示3行4列 // values.getValues()可能是[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] // 对于每个维度,使用memset将对应的值填充到数组中 // 这里的alloca是一个指向数组的指针 const std::vector & counterNumbers = values.getNumbers(); const std::vector & counterValues = values.getValues(); unsigned begin = 0; for (size_t i = 0; i < counterNumbers.size(); i++) { builder.createMemsetInst( alloca, ConstantValue::get(static_cast(begin)), ConstantValue::get(static_cast(counterNumbers[i])), counterValues[i]); begin += counterNumbers[i]; } } } module->addVariable(name, alloca); } return std::any(); } std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) { return ctx->INT() != nullptr ? Type::getIntType() : Type::getFloatType(); } std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) { Value* value = std::any_cast(visitExp(ctx->exp())); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext *ctx) { std::vector children; for (const auto &initVal : ctx->initVal()) children.push_back(std::any_cast(initVal->accept(this))); ArrayValueTree* result = new ArrayValueTree(); result->addChildren(children); return result; } std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) { Value* value = std::any_cast(visitConstExp(ctx->constExp())); ArrayValueTree* result = new ArrayValueTree(); result->setValue(value); return result; } std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) { std::vector children; for (const auto &constInitVal : ctx->constInitVal()) children.push_back(std::any_cast(constInitVal->accept(this))); ArrayValueTree* result = new ArrayValueTree(); result->addChildren(children); return result; } std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { if (ctx->INT() != nullptr) return Type::getIntType(); if (ctx->FLOAT() != nullptr) return Type::getFloatType(); return Type::getVoidType(); } std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ // 更新作用域 module->enterNewScope(); HasReturnInst = false; auto name = ctx->Ident()->getText(); std::vector paramTypes; std::vector paramNames; std::vector> paramDims; if (ctx->funcFParams() != nullptr) { auto params = ctx->funcFParams()->funcFParam(); for (const auto ¶m : params) { paramTypes.push_back(std::any_cast(visitBType(param->bType()))); paramNames.push_back(param->Ident()->getText()); std::vector dims = {}; if (!param->LBRACK().empty()) { dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 for (const auto &exp : param->exp()) { dims.push_back(std::any_cast(visitExp(exp))); } } paramDims.emplace_back(dims); } } Type* returnType = std::any_cast(visitFuncType(ctx->funcType())); Type* funcType = Type::getFunctionType(returnType, paramTypes); Function* function = module->createFunction(name, funcType); BasicBlock* entry = function->getEntryBlock(); builder.setPosition(entry, entry->end()); for (size_t i = 0; i < paramTypes.size(); ++i) { AllocaInst* alloca = builder.createAllocaInst(Type::getPointerType(paramTypes[i]), paramDims[i], paramNames[i]); entry->insertArgument(alloca); module->addVariable(paramNames[i], alloca); } for (auto item : ctx->blockStmt()->blockItem()) { visitBlockItem(item); } if(HasReturnInst == false) { // 如果没有return语句,则默认返回0 if (returnType != Type::getVoidType()) { Value* returnValue = ConstantValue::get(0); if (returnType == Type::getFloatType()) { returnValue = ConstantValue::get(0.0f); } builder.createReturnInst(returnValue); } else { builder.createReturnInst(); } } module->leaveScope(); return std::any(); } std::any SysYIRGenerator::visitBlockStmt(SysYParser::BlockStmtContext *ctx) { module->enterNewScope(); for (auto item : ctx->blockItem()) visitBlockItem(item); module->leaveScope(); return 0; } std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) { auto lVal = ctx->lValue(); std::string name = lVal->Ident()->getText(); std::vector dims; for (const auto &exp : lVal->exp()) { dims.push_back(std::any_cast(visitExp(exp))); } auto variable = module->getVariable(name); Value* value = std::any_cast(visitExp(ctx->exp())); Type* variableType = dynamic_cast(variable->getType())->getBaseType(); // 左值右值类型不同处理 if (variableType != value->getType()) { ConstantValue * constValue = dynamic_cast(value); if (constValue != nullptr) { if (variableType == Type::getFloatType()) { value = ConstantValue::get(static_cast(constValue->getInt())); } else { value = ConstantValue::get(static_cast(constValue->getFloat())); } } else { if (variableType == Type::getFloatType()) { value = builder.createIToFInst(value); } else { value = builder.createFtoIInst(value); } } } builder.createStoreInst(value, variable, dims, variable->getName()); return std::any(); } std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { // labels string stream std::stringstream labelstring; Function * function = builder.getBasicBlock()->getParent(); BasicBlock* thenBlock = new BasicBlock(function); BasicBlock* exitBlock = new BasicBlock(function); if (ctx->stmt().size() > 1) { BasicBlock* elseBlock = new BasicBlock(function); builder.pushTrueBlock(thenBlock); builder.pushFalseBlock(elseBlock); // 访问条件表达式 visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "if_then.L" << builder.getLabelIndex(); thenBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(thenBlock); builder.setPosition(thenBlock, thenBlock->end()); auto block = dynamic_cast(ctx->stmt(0)); // 如果是块语句,直接访问 // 否则访问语句 if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(0)->accept(this); module->leaveScope(); } builder.createUncondBrInst(exitBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_else.L" << builder.getLabelIndex(); elseBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(elseBlock); builder.setPosition(elseBlock, elseBlock->end()); block = dynamic_cast(ctx->stmt(1)); if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(1)->accept(this); module->leaveScope(); } builder.createUncondBrInst(exitBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); } else { builder.pushTrueBlock(thenBlock); builder.pushFalseBlock(exitBlock); visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "if_then.L" << builder.getLabelIndex(); thenBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(thenBlock); builder.setPosition(thenBlock, thenBlock->end()); auto block = dynamic_cast(ctx->stmt(0)); if (block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt(0)->accept(this); module->leaveScope(); } BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); labelstring << "if_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); } return std::any(); } std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { // while structure: // curblock -> headBlock -> bodyBlock -> exitBlock BasicBlock* curBlock = builder.getBasicBlock(); Function* function = builder.getBasicBlock()->getParent(); std::stringstream labelstring; labelstring << "while_head.L" << builder.getLabelIndex(); BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); BasicBlock::conectBlocks(curBlock, headBlock); builder.setPosition(headBlock, headBlock->end()); BasicBlock* bodyBlock = new BasicBlock(function); BasicBlock* exitBlock = new BasicBlock(function); builder.pushTrueBlock(bodyBlock); builder.pushFalseBlock(exitBlock); // 访问条件表达式 visitCond(ctx->cond()); builder.popTrueBlock(); builder.popFalseBlock(); labelstring << "while_body.L" << builder.getLabelIndex(); bodyBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(bodyBlock); builder.setPosition(bodyBlock, bodyBlock->end()); builder.pushBreakBlock(exitBlock); builder.pushContinueBlock(headBlock); auto block = dynamic_cast(ctx->stmt()); if( block != nullptr) { visitBlockStmt(block); } else { module->enterNewScope(); ctx->stmt()->accept(this); module->leaveScope(); } builder.createUncondBrInst(headBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); builder.popBreakBlock(); builder.popContinueBlock(); labelstring << "while_exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); function->addBasicBlock(exitBlock); builder.setPosition(exitBlock, exitBlock->end()); return std::any(); } std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) { BasicBlock* breakBlock = builder.getBreakBlock(); builder.createUncondBrInst(breakBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock); return std::any(); } std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) { BasicBlock* continueBlock = builder.getContinueBlock(); builder.createUncondBrInst(continueBlock, {}); BasicBlock::conectBlocks(builder.getBasicBlock(), continueBlock); return std::any(); } std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { Value* returnValue = nullptr; if (ctx->exp() != nullptr) { returnValue = std::any_cast(visitExp(ctx->exp())); } Type* funcType = builder.getBasicBlock()->getParent()->getReturnType(); if (funcType!= returnValue->getType() && returnValue != nullptr) { ConstantValue * constValue = dynamic_cast(returnValue); if (constValue != nullptr) { if (funcType == Type::getFloatType()) { returnValue = ConstantValue::get(static_cast(constValue->getInt())); } else { returnValue = ConstantValue::get(static_cast(constValue->getFloat())); } } else { if (funcType == Type::getFloatType()) { returnValue = builder.createIToFInst(returnValue); } else { returnValue = builder.createFtoIInst(returnValue); } } } builder.createReturnInst(returnValue); HasReturnInst = true; return std::any(); } std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) { std::string name = ctx->Ident()->getText(); User* variable = module->getVariable(name); Value* value = nullptr; std::vector dims; for (const auto &exp : ctx->exp()) { dims.push_back(std::any_cast(visitExp(exp))); } if (variable == nullptr) { throw std::runtime_error("Variable " + name + " not found."); } bool indicesConstant = true; for (const auto &dim : dims) { if (dynamic_cast(dim) == nullptr) { indicesConstant = false; break; } } ConstantVariable* constVar = dynamic_cast(variable); GlobalValue* globalVar = dynamic_cast(variable); AllocaInst* localVar = dynamic_cast(variable); if (constVar != nullptr && indicesConstant) { // 如果是常量变量,且索引是常量,则直接获取子数组 value = constVar->getByIndices(dims); } else if (module->isInGlobalArea() && (globalVar != nullptr)) { assert(indicesConstant); value = globalVar->getByIndices(dims); } else { if ((globalVar != nullptr && globalVar->getNumDims() > dims.size()) || (localVar != nullptr && localVar->getNumDims() > dims.size()) || (constVar != nullptr && constVar->getNumDims() > dims.size())) { // value = builder.createLaInst(variable, indices); // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 auto getArrayInst = builder.createGetSubArray(dynamic_cast(variable), dims); value = getArrayInst->getChildArray(); } else { value = builder.createLoadInst(variable, dims); } } return value; } std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { if (ctx->exp() != nullptr) return visitExp(ctx->exp()); if (ctx->lValue() != nullptr) return visitLValue(ctx->lValue()); if (ctx->number() != nullptr) return visitNumber(ctx->number()); if (ctx->string() != nullptr) { cout << "String literal not supported in SysYIRGenerator." << endl; } return visitNumber(ctx->number()); } std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { if (ctx->ILITERAL() != nullptr) { int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); return static_cast(ConstantValue::get(value)); } else if (ctx->FLITERAL() != nullptr) { float value = std::stof(ctx->FLITERAL()->getText()); return static_cast(ConstantValue::get(value)); } throw std::runtime_error("Unknown number type."); return std::any(); // 不会到达这里 } std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { std::string funcName = ctx->Ident()->getText(); Function *function = module->getFunction(funcName); if (function == nullptr) { function = module->getExternalFunction(funcName); if (function == nullptr) { std::cout << "The function " << funcName << " no defined." << std::endl; assert(function); } } std::vector args = {}; if (funcName == "starttime" || funcName == "stoptime") { // 如果是starttime或stoptime函数 // TODO: 这里需要处理starttime和stoptime函数的参数 // args.emplace_back() } else { if (ctx->funcRParams() != nullptr) { args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); } auto params = function->getEntryBlock()->getArguments(); for (size_t i = 0; i < args.size(); i++) { // 参数类型转换 if (params[i]->getType() != args[i]->getType() && (params[i]->getNumDims() != 0 || params[i]->getType()->as()->getBaseType() != args[i]->getType())) { ConstantValue * constValue = dynamic_cast(args[i]); if (constValue != nullptr) { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { args[i] = ConstantValue::get(static_cast(constValue->getInt())); } else { args[i] = ConstantValue::get(static_cast(constValue->getFloat())); } } else { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { args[i] = builder.createIToFInst(args[i]); } else { args[i] = builder.createFtoIInst(args[i]); } } } } } return static_cast(builder.createCallInst(function, args)); } std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) { if (ctx->primaryExp() != nullptr) return visitPrimaryExp(ctx->primaryExp()); if (ctx->call() != nullptr) return visitCall(ctx->call()); Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); Value* result = value; if (ctx->unaryOp()->SUB() != nullptr) { ConstantValue * constValue = dynamic_cast(value); if (constValue != nullptr) { if (constValue->isFloat()) { result = ConstantValue::get(-constValue->getFloat()); } else { result = ConstantValue::get(-constValue->getInt()); } } else if (value != nullptr) { if (value->getType() == Type::getIntType()) { result = builder.createNegInst(value); } else { result = builder.createFNegInst(value); } } else { std::cout << "UnExp: value is nullptr." << std::endl; assert(false); } } else if (ctx->unaryOp()->NOT() != nullptr) { auto constValue = dynamic_cast(value); if (constValue != nullptr) { if (constValue->isFloat()) { result = ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); } else { result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); } } else if (value != nullptr) { if (value->getType() == Type::getIntType()) { result = builder.createNotInst(value); } else { result = builder.createFNotInst(value); } } else { std::cout << "UnExp: value is nullptr." << std::endl; assert(false); } } return result; } std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { std::vector params; for (const auto &exp : ctx->exp()) params.push_back(std::any_cast(visitExp(exp))); return params; } std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { Value * result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); for (size_t i = 1; i < ctx->unaryExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); if (resultType == floatType || operandType == floatType) { // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 if (operandType != floatType) { ConstantValue * constValue = dynamic_cast(operand); if (constValue != nullptr) operand = ConstantValue::get(static_cast(constValue->getInt())); else operand = builder.createIToFInst(operand); } else if (resultType != floatType) { ConstantValue* constResult = dynamic_cast(result); if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) { result = ConstantValue::get(constResult->getFloat() * constOperand->getFloat()); } else { result = builder.createFMulInst(result, operand); } } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) { result = ConstantValue::get(constResult->getFloat() / constOperand->getFloat()); } else { result = builder.createFDivInst(result, operand); } } else { // float类型的取模操作不允许 std::cout << "MulExp: float type mod operation is not allowed." << std::endl; assert(false); } } else { ConstantValue * constResult = dynamic_cast(result); ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::MUL) { if ((constOperand != nullptr) && (constResult != nullptr)) result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); else result = builder.createMulInst(result, operand); } else if (opType == SysYParser::DIV) { if ((constOperand != nullptr) && (constResult != nullptr)) result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); else result = builder.createDivInst(result, operand); } else { if ((constOperand != nullptr) && (constResult != nullptr)) result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); else result = builder.createRemInst(result, operand); } } } return result; } std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); for (size_t i = 1; i < ctx->mulExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); if (resultType == floatType || operandType == floatType) { // 类型转换 if (operandType != floatType) { ConstantValue * constOperand = dynamic_cast(operand); if (constOperand != nullptr) operand = ConstantValue::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); } else if (resultType != floatType) { ConstantValue * constResult = dynamic_cast(result); if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } ConstantValue * constResult = dynamic_cast(result); ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); else result = builder.createFAddInst(result, operand); } else { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); else result = builder.createFSubInst(result, operand); } } else { ConstantValue * constResult = dynamic_cast(result); ConstantValue * constOperand = dynamic_cast(operand); if (opType == SysYParser::ADD) { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); else result = builder.createAddInst(result, operand); } else { if ((constResult != nullptr) && (constOperand != nullptr)) result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); else result = builder.createSubInst(result, operand); } } } return result; } std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); for (size_t i = 1; i < ctx->addExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); Type* resultType = result->getType(); Type* operandType = operand->getType(); ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); // 常量比较 if ((constResult != nullptr) && (constOperand != nullptr)) { auto operand1 = constResult->isFloat() ? constResult->getFloat() : constResult->getInt(); auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); if (opType == SysYParser::LT) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); else if (opType == SysYParser::GT) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); else if (opType == SysYParser::LE) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); else if (opType == SysYParser::GE) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); else assert(false); } else { Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); // 浮点数处理 if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) operand = ConstantValue::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); } if (opType == SysYParser::LT) result = builder.createFCmpLTInst(result, operand); else if (opType == SysYParser::GT) result = builder.createFCmpGTInst(result, operand); else if (opType == SysYParser::LE) result = builder.createFCmpLEInst(result, operand); else if (opType == SysYParser::GE) result = builder.createFCmpGEInst(result, operand); else assert(false); } else { // 整数处理 if (opType == SysYParser::LT) result = builder.createICmpLTInst(result, operand); else if (opType == SysYParser::GT) result = builder.createICmpGTInst(result, operand); else if (opType == SysYParser::LE) result = builder.createICmpLEInst(result, operand); else if (opType == SysYParser::GE) result = builder.createICmpGEInst(result, operand); else assert(false); } } } return result; } std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) { Value * result = std::any_cast(visitRelExp(ctx->relExp(0))); for (size_t i = 1; i < ctx->relExp().size(); i++) { auto opNode = dynamic_cast(ctx->children[2*i-1]); int opType = opNode->getSymbol()->getType(); Value * operand = std::any_cast(visitRelExp(ctx->relExp(i))); ConstantValue* constResult = dynamic_cast(result); ConstantValue* constOperand = dynamic_cast(operand); if ((constResult != nullptr) && (constOperand != nullptr)) { auto operand1 = constResult->isFloat() ? constResult->getFloat() : constResult->getInt(); auto operand2 = constOperand->isFloat() ? constOperand->getFloat() : constOperand->getInt(); if (opType == SysYParser::EQ) result = ConstantValue::get(operand1 == operand2 ? 1 : 0); else if (opType == SysYParser::NE) result = ConstantValue::get(operand1 != operand2 ? 1 : 0); else assert(false); } else { Type* resultType = result->getType(); Type* operandType = operand->getType(); Type* floatType = Type::getFloatType(); if (resultType == floatType || operandType == floatType) { if (resultType != floatType) { if (constResult != nullptr) result = ConstantValue::get(static_cast(constResult->getInt())); else result = builder.createIToFInst(result); } if (operandType != floatType) { if (constOperand != nullptr) operand = ConstantValue::get(static_cast(constOperand->getInt())); else operand = builder.createIToFInst(operand); } if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand); else if (opType == SysYParser::NE) result = builder.createFCmpNEInst(result, operand); else assert(false); } else { if (opType == SysYParser::EQ) result = builder.createICmpEQInst(result, operand); else if (opType == SysYParser::NE) result = builder.createICmpNEInst(result, operand); else assert(false); } } } if (ctx->relExp().size() == 1) { ConstantValue * constResult = dynamic_cast(result); // 如果只有一个关系表达式,则将结果转换为0或1 if (constResult != nullptr) { if (constResult->isFloat()) result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0); else result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0); } } return result; } std::any SysYIRGenerator::visitLAndExp(SysYParser::LAndExpContext *ctx){ std::stringstream labelstring; BasicBlock *curBlock = builder.getBasicBlock(); Function *function = builder.getBasicBlock()->getParent(); BasicBlock *trueBlock = builder.getTrueBlock(); BasicBlock *falseBlock = builder.getFalseBlock(); auto conds = ctx->eqExp(); for (size_t i = 0; i < conds.size() - 1; i++) { labelstring << "AND.L" << builder.getLabelIndex(); BasicBlock *newtrueBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); auto cond = std::any_cast(visitEqExp(ctx->eqExp(i))); builder.createCondBrInst(cond, newtrueBlock, falseBlock, {}, {}); BasicBlock::conectBlocks(curBlock, newtrueBlock); BasicBlock::conectBlocks(curBlock, falseBlock); curBlock = newtrueBlock; builder.setPosition(curBlock, curBlock->end()); } auto cond = std::any_cast(visitEqExp(conds.back())); builder.createCondBrInst(cond, trueBlock, falseBlock, {}, {}); BasicBlock::conectBlocks(curBlock, trueBlock); BasicBlock::conectBlocks(curBlock, falseBlock); return std::any(); } auto SysYIRGenerator::visitLOrExp(SysYParser::LOrExpContext *ctx) -> std::any { std::stringstream labelstring; BasicBlock *curBlock = builder.getBasicBlock(); Function *function = curBlock->getParent(); auto conds = ctx->lAndExp(); for (size_t i = 0; i < conds.size() - 1; i++) { labelstring << "OR.L" << builder.getLabelIndex(); BasicBlock *newFalseBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); builder.pushFalseBlock(newFalseBlock); visitLAndExp(ctx->lAndExp(i)); builder.popFalseBlock(); builder.setPosition(newFalseBlock, newFalseBlock->end()); } visitLAndExp(conds.back()); return std::any(); } void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { Value* value = root->getValue(); auto &children = root->getChildren(); if (value != nullptr) { if (type == value->getType()) { result.push_back(value); } else { if (type == Type::getFloatType()) { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr) result.push_back(ConstantValue::get(static_cast(constValue->getInt()))); else result.push_back(builder->createIToFInst(value)); } else { ConstantValue* constValue = dynamic_cast(value); if (constValue != nullptr) result.push_back(ConstantValue::get(static_cast(constValue->getFloat()))); else result.push_back(builder->createFtoIInst(value)); } } return; } auto beforeSize = result.size(); for (const auto &child : children) { int begin = result.size(); int newNumDims = 0; for (unsigned i = 0; i < numDims - 1; i++) { auto dim = dynamic_cast(*(dims.rbegin() + i))->getInt(); if (begin % dim == 0) { newNumDims += 1; begin /= dim; } else { break; } } tree2Array(type, child.get(), dims, newNumDims, result, builder); } auto afterSize = result.size(); int blockSize = 1; for (unsigned i = 0; i < numDims; i++) { blockSize *= dynamic_cast(*(dims.rbegin() + i))->getInt(); } int num = blockSize - afterSize + beforeSize; if (num > 0) { if (type == Type::getFloatType()) result.push_back(ConstantValue::get(0.0F), num); else result.push_back(ConstantValue::get(0), num); } } void Utils::createExternalFunction( const std::vector ¶mTypes, const std::vector ¶mNames, const std::vector> ¶mDims, Type *returnType, const std::string &funcName, Module *pModule, IRBuilder *pBuilder) { auto funcType = Type::getFunctionType(returnType, paramTypes); auto function = pModule->createExternalFunction(funcName, funcType); auto entry = function->getEntryBlock(); pBuilder->setPosition(entry, entry->end()); for (size_t i = 0; i < paramTypes.size(); ++i) { auto alloca = pBuilder->createAllocaInst( Type::getPointerType(paramTypes[i]), paramDims[i], paramNames[i]); entry->insertArgument(alloca); // pModule->addVariable(paramNames[i], alloca); } } void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) { std::vector paramTypes; std::vector paramNames; std::vector> paramDims; Type *returnType; std::string funcName; returnType = Type::getIntType(); funcName = "getint"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); funcName = "getch"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.push_back(Type::getIntType()); paramNames.emplace_back("x"); paramDims.push_back(std::vector{ConstantValue::get(-1)}); funcName = "getarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getFloatType(); paramTypes.clear(); paramNames.clear(); paramDims.clear(); funcName = "getfloat"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getIntType(); paramTypes.push_back(Type::getFloatType()); paramNames.emplace_back("x"); paramDims.push_back(std::vector{ConstantValue::get(-1)}); funcName = "getfarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); returnType = Type::getVoidType(); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); funcName = "putint"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); funcName = "putch"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramDims.push_back(std::vector{ConstantValue::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); funcName = "putarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getFloatType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("a"); funcName = "putfloat"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getFloatType()); paramDims.clear(); paramDims.emplace_back(); paramDims.push_back(std::vector{ConstantValue::get(-1)}); paramNames.clear(); paramNames.emplace_back("n"); paramNames.emplace_back("a"); funcName = "putfarray"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("__LINE__"); funcName = "starttime"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); paramTypes.clear(); paramTypes.push_back(Type::getIntType()); paramDims.clear(); paramDims.emplace_back(); paramNames.clear(); paramNames.emplace_back("__LINE__"); funcName = "stoptime"; Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, funcName, pModule, pBuilder); } } // namespace sysy