From 0a04c816cf409db2abb154b9a7b950760387f788 Mon Sep 17 00:00:00 2001 From: rain2133 <1370973498@qq.com> Date: Sat, 21 Jun 2025 18:06:29 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0IR=EF=BC=8C.g4=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/SysY.g4 | 8 +- src/SysYIRGenerator.cpp | 457 ++++++++++++++++++++++++++++++++-- src/include/IR.h | 1 + src/include/SysYIRGenerator.h | 1 + 4 files changed, 446 insertions(+), 21 deletions(-) diff --git a/src/SysY.g4 b/src/SysY.g4 index a9e4208..d614ec4 100644 --- a/src/SysY.g4 +++ b/src/SysY.g4 @@ -153,10 +153,10 @@ cond: lOrExp; lValue: Ident (LBRACK exp RBRACK)*; // 为了方便测试 primaryExp 可以是一个string -primaryExp: LPAREN exp RPAREN #parenExp - | lValue #lVal - | number #num - | string #str; +primaryExp: LPAREN exp RPAREN + | lValue + | number + | string; number: ILITERAL | FLITERAL; unaryExp: primaryExp #primExp diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index d23bf28..6abdeff 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -382,7 +382,8 @@ std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) { } std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { - + // while structure: + // curblock -> headBlock -> bodyBlock -> exitBlock BasicBlock* curBlock = builder.getBasicBlock(); Function* function = builder.getBasicBlock()->getParent(); @@ -390,18 +391,16 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { labelstring << "head.L" << builder.getLabelIndex(); BasicBlock *headBlock = function->addBasicBlock(labelstring.str()); labelstring.str(""); - BasicBlock::conectBlocks(curBlock, headBlock); - builder.setPosition(headBlock, headBlock->end()) + builder.setPosition(headBlock, headBlock->end()); - function->addBasicBlock(condBlock); - builder.setPosition(condBlock, condBlock->end()); + BasicBlock* bodyBlock = new BasicBlock(function); + BasicBlock* exitBlock = new BasicBlock(function); builder.pushTrueBlock(bodyBlock); builder.pushFalseBlock(exitBlock); - + // 访问条件表达式 visitCond(ctx->cond()); - builder.popTrueBlock(); builder.popFalseBlock(); @@ -411,27 +410,451 @@ std::any SysYIRGenerator::visitWhileStmt(SysYParser::WhileStmtContext *ctx) { function->addBasicBlock(bodyBlock); builder.setPosition(bodyBlock, bodyBlock->end()); - module->enterNewScope(); - for (auto item : ctx->blockStmt()->blockItem()) { - visitBlockItem(item); - } - module->leaveScope(); - - builder.createUncondBrInst(condBlock, {}); + builder.pushBreakBlock(exitBlock); + builder.pushContinueBlock(headBlock); - BasicBlock::conectBlocks(builder.getBasicBlock(), condBlock); + auto block = dynamic_cast(ctx->stmt()); + + if( block != nullptr) { + visitBlockStmt(block); + } else { + module->enterNewScope(); + ctx->stmt()->accept(this); + module->leaveScope(); + } + + builder.createUncondBrInst(headBlock, {}); + BasicBlock::conectBlocks(builder.getBasicBlock(), exitBlock); + builder.popBreakBlock(); + builder.popContinueBlock(); labelstring << "exit.L" << builder.getLabelIndex(); exitBlock->setName(labelstring.str()); labelstring.str(""); - function->addBasicBlock(exitBlock); - builder.setPosition(exitBlock, exitBlock->end()); return std::any(); } +std::any SysYIRGenerator::visitBreakStmt(SysYParser::BreakStmtContext *ctx) { + BasicBlock* breakBlock = builder.getBreakBlock(); + builder.pushBreakBlock(breakBlock); + BasicBlock::conectBlocks(builder.getBasicBlock(), breakBlock); + return std::any(); +} + +std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx) { + BasicBlock* continueBlock = builder.getContinueBlock(); + builder.createUncondBrInst(continueBlock, {}); + return std::any(); +} + +std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { + Value* returnValue = nullptr; + if (ctx->exp() != nullptr) { + returnValue = std::any_cast(visitExp(ctx->exp())); + } + + Type* funcType = builder.getBasicBlock()->getParent()->getType(); + if (funcType!= returnValue->getType() && returnValue != nullptr) { + ConstantValue * constValue = dynamic_cast(returnValue); + if (constValue != nullptr) { + if (funcType == Type::getFloatType()) { + returnValue = ConstantValue::get(static_cast(constValue->getInt())); + } else { + returnValue = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (funcType == Type::getFloatType()) { + returnValue = builder.createIToFInst(returnValue); + } else { + returnValue = builder.createFtoIInst(returnValue); + } + } + } + builder.createRetInst(returnValue); + return std::any(); +} + + +std::any SysYIRGenerator::visitLVal(SysYParser::LValContext *ctx) { + std::string name = ctx->Ident()->getText(); + User* variable = module->getVariable(name); + + Value* value = nullptr; + std::vector dims; + for (const auto &exp : ctx->exp()) { + dims.push_back(std::any_cast(visitExp(exp))); + } + + if (variable == nullptr) { + throw std::runtime_error("Variable " + name + " not found."); + } + + bool indicesConstant = true; + for (const auto &index : indices) { + if (dynamic_cast(index) == nullptr) { + indicesConstant = false; + break; + } + } + + ConstantVariable* constVar = dynamic_cast(variable); + GlobalValue* globalVar = dynamic_cast(variable); + AllocaInst* localVar = dynamic_cast(variable); + if (constVar != nullptr && indicesConstant) { + // 如果是常量变量,且索引是常量,则直接获取子数组 + value = constVar->getByIndices(indices); + } else if (module->isInGlobalArea() && (globalVar != nullptr)) { + assert(indicesConstant); + value = globalVar->getByIndices(indices); + } else { + if ((globalVar != nullptr && globalVar->getNumDims() > indices.size()) || + (localVar != nullptr && localVar->getNumDims() > indices.size()) || + (constVar != nullptr && constVar->getNumDims() > indices.size())) { + // value = builder.createLaInst(variable, indices); + // 如果变量是全局变量或局部变量,且索引数量小于维度数量,则创建createGetSubArray获取子数组 + auto getArrayInst = + builder.createGetSubArray(dynamic_cast(variable), indices); + value = getArrayInst->getChildArray(); + } else { + value = builder.createLoadInst(variable, indices); + } + } + + return value; +} + +std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) { + if (ctx->exp() != nullptr) + return visitExp(ctx->exp()); + if (ctx->lVal() != nullptr) + return visitLVal(ctx->lVal()); + if (ctx->number() != nullptr) + return visitNumber(ctx->number()); + // if (ctx->string() != nullptr) { + // std::string str = ctx->string()->getText(); + // str = str.substr(1, str.size() - 2); // 去掉双引号 + // return ConstantValue::get(str); + // } + return visitNumber(ctx->number()); +} + +std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { + if (ctx->ILITERAL() != nullptr) { + int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); + return static_cast(ConstantValue::get(Type::getIntType(), value)); + } else if (ctx->FLITERAL() != nullptr) { + float value = std::stof(ctx->FLITERAL()->getText()); + return static_cast(ConstantValue::get(Type::getFloatType(), value)); + } + throw std::runtime_error("Unknown number type."); + return std::any(); // 不会到达这里 +} + +std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) { + std::string funcName = ctx->Ident()->getText(); + Function *function = module->getFunction(funcName); + if (function == nullptr) { + function = module->getExternalFunction(name); + if (function == nullptr) { + std::cout << "The function " << name << " no defined." << std::endl; + assert(function); + } + } + + std::vector args = {}; + if (name == "starttime" || name == "stoptime") { + // 如果是starttime或stoptime函数 + // TODO: 这里需要处理starttime和stoptime函数的参数 + // args.emplace_back() + } else { + if (ctx->funcRParams() != nullptr) { + args = std::any_cast>(visitFuncRParams(ctx->funcRParams())); + } + + auto params = function->getEntryBlock()->getArguments(); + for (size_t i = 0; i < args.size(); i++) { + // 参数类型转换 + if (params[i]->getType() != args[i]->getType() && + (params[i]->getNumDims() != 0 || + params[i]->getType()->as()->getBaseType() != args[i]->getType())) { + ConstantValue * constValue = dynamic_cast(args[i]); + if (constValue != nullptr) { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = ConstantValue::get(static_cast(constValue->getInt())); + } else { + args[i] = ConstantValue::get(static_cast(constValue->getFloat())); + } + } else { + if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { + args[i] = builder.createIToFInst(args[i]); + } else { + args[i] = builder.createFtoIInst(args[i]); + } + } + } + } + } + + return static_cast(builder.createCallInst(function, args)); +} + +std::any SysYIRGenerator::visitUnExp(SysYParser::UnExpContext *ctx) { + Value* value = std::any_cast(visitUnaryExp(ctx->unaryExp())); + Value* result = value; + if (ctx->unaryOp()->SUB() != nullptr) { + ConstantValue * constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = ConstantValue::get(-constValue->getFloat()); + } else { + result = ConstantValue::get(-constValue->getInt()); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNegInst(value); + } else { + result = builder.createFNegInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } + } else if (ctx->unaryOp()->NOT() != nullptr) { + auto constValue = dynamic_cast(value); + if (constValue != nullptr) { + if (constValue->isFloat()) { + result = + ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); + } else { + result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); + } + } else if (value != nullptr) { + if (value->getType() == Type::getIntType()) { + result = builder.createNotInst(value); + } else { + result = builder.createFNotInst(value); + } + } else { + std::cout << "UnExp: value is nullptr." << std::endl; + assert(false); + } + } + return result; +} + +std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) { + std::vector params; + for (const auto &exp : ctx->exp()) + params.push_back(std::any_cast(visitExp(exp))); + return params; +} + + +std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) { + auto result = std::any_cast(visitUnaryExp(ctx->unaryExp(0))); + + for (size_t i = 1; i < ctx->unaryExp().size(); i++) { + auto op = ctx->mulOp(i - 1); + Value* operand = std::any_cast(visitUnaryExp(ctx->unaryExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + // 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数 + if (operandType != Type::getFloatType()) { + ConstantValue * constValue = dynamic_cast(operand); + if (constValue != nullptr) + operand = ConstantValue::get(static_cast(constValue->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != Type::getFloatType()) { + ConstantValue* constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + if (op->MUL() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() * + constOperand->getFloat()); + } else { + result = builder.createFMulInst(result, operand); + } + } else if (op->DIV() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) { + result = ConstantValue::get(constResult->getFloat() / + constOperand->getFloat()); + } else { + result = builder.createFDivInst(result, operand); + } + } else { + // float类型的取模操作不允许 + std::cout << "MulExp: float type mod operation is not allowed." << std::endl; + assert(false); + } + } else { + ConstantValue * constResult = dynamic_cast(result); + ConstantValue * constOperand = dynamic_cast(operand); + if (op->MUL() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); + else + result = builder.createMulInst(result, operand); + } else if (op->DIV() != nullptr) { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); + else + result = builder.createDivInst(result, operand); + } else { + if ((constOperand != nullptr) && (constResult != nullptr)) + result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); + else + result = builder.createRemInst(result, operand); + } + } + } + + return result; +} + + +std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) { + Value* result = std::any_cast(visitMulExp(ctx->mulExp(0))); + + for (size_t i = 1; i < ctx->mulExp().size(); i++) { + auto op = ctx->addOp(i - 1); + + Value* operand = std::any_cast(visitMulExp(ctx->mulExp(i))); + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + if (resultType == Type::getFloatType() || operandType == Type::getFloatType()) { + // 类型转换 + if (operandType != Type::getFloatType()) { + Value* constOperand = dynamic_cast(operand); + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + } else if (resultType != Type::getFloatType()) { + Value* constResult = dynamic_cast(result); + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + } + + Value* constResult = dynamic_cast(result); + Value* constOperand = dynamic_cast(operand); + if (op->ADD() != nullptr) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); + else + result = builder.createFAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); + else + result = builder.createFSubInst(result, operand); + } + } else { + Value* constResult = dynamic_cast(result); + Value* constOperand = dynamic_cast(operand); + if (op->ADD() != nullptr) { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); + else + result = builder.createAddInst(result, operand); + } else { + if ((constResult != nullptr) && (constOperand != nullptr)) + result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); + else + result = builder.createSubInst(result, operand); + } + } + } + + return result; +} + +std:any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) { + Value* result = std::any_cast(visitAddExp(ctx->addExp(0))); + + for (size_t i = 1; i < ctx->addExp().size(); i++) { + auto op = ctx->relOp(i - 1); + Value* operand = std::any_cast(visitAddExp(ctx->addExp(i))); + + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + + ConstantValue* constResult = dynamic_cast(result); + ConstantValue* constOperand = dynamic_cast(operand); + + // 常量比较 + if ((constResult != nullptr) && (constOperand != nullptr)) { + auto operand1 = constResult->isFloat() ? constResult->getFloat() + : constResult->getInt(); + auto operand2 = constOperand->isFloat() ? constOperand->getFloat() + : constOperand->getInt(); + + if (op->LT() != nullptr) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); + else if (op->GT() != nullptr) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); + else if (op->LE() != nullptr) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); + else if (op->GE() != nullptr) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); + else assert(false); + + } else { + Type* resultType = result->getType(); + Type* operandType = operand->getType(); + Type* floatType = Type::getFloatType(); + + // 浮点数处理 + if (resultType == floatType || operandType == floatType) { + if (resultType != floatType) { + if (constResult != nullptr) + result = ConstantValue::get(static_cast(constResult->getInt())); + else + result = builder.createIToFInst(result); + + } + if (operandType != floatType) { + if (constOperand != nullptr) + operand = ConstantValue::get(static_cast(constOperand->getInt())); + else + operand = builder.createIToFInst(operand); + + } + + if (op->LT() != nullptr) result = builder.createFCmpLTInst(result, operand); + else if (op->GT() != nullptr) result = builder.createFCmpGTInst(result, operand); + else if (op->LE() != nullptr) result = builder.createFCmpLEInst(result, operand); + else if (op->GE() != nullptr) result = builder.createFCmpGEInst(result, operand); + else assert(false); + + } else { + // 整数处理 + if (op->LT() != nullptr) result = builder.createICmpLTInst(result, operand); + else if (op->GT() != nullptr) result = builder.createICmpGTInst(result, operand); + else if (op->LE() != nullptr) result = builder.createICmpLEInst(result, operand); + else if (op->GE() != nullptr) result = builder.createICmpGEInst(result, operand); + else assert(false); + + } + } + } + + return result; +} + + void Utils::tree2Array(Type *type, ArrayValueTree *root, const std::vector &dims, unsigned numDims, ValueCounter &result, IRBuilder *builder) { diff --git a/src/include/IR.h b/src/include/IR.h index 1de2e23..3182a9a 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -1575,6 +1575,7 @@ class ConstantVariable : public User, public LVal { Value* getByIndex(unsigned index) const { return initValues.getValue(index); } ///< 通过一维位置index获取值 Value* getByIndices(const std::vector &indices) const { int index = 0; + // 计算偏移量 for (size_t i = 0; i < indices.size(); i++) { index = dynamic_cast(getDim(i))->getInt() * index + dynamic_cast(indices[i])->getInt(); diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index a5f5a91..e203016 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -120,6 +120,7 @@ public: // std::any visitUnaryOp(SysYParser::UnaryOpContext *ctx) override; std::any visitUnExp(SysYParser::UnExpContext *ctx) override; + std::any visitFuncRParams(SysYParser::FuncRParamsContext *ctx) override; std::any visitMulExp(SysYParser::MulExpContext *ctx) override; std::any visitAddExp(SysYParser::AddExpContext *ctx) override;