diff --git a/src/DeadCodeElimination.cpp b/src/DeadCodeElimination.cpp index 9abca1c..ffe6022 100644 --- a/src/DeadCodeElimination.cpp +++ b/src/DeadCodeElimination.cpp @@ -1,8 +1,9 @@ #include "DeadCodeElimination.h" +#include +extern int DEBUG; namespace sysy { - void DeadCodeElimination::runDCEPipeline() { const auto& functions = pModule->getFunctions(); for (const auto& function : functions) { @@ -58,6 +59,10 @@ void DeadCodeElimination::eliminateDeadStores(Function* func, bool& changed) { if (changetag) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Store Found ===\n"; + SysYPrinter::printInst(storeInst); + } usedelete(storeInst); iter = instrs.erase(iter); } else { @@ -76,6 +81,10 @@ void DeadCodeElimination::eliminateDeadLoads(Function* func, bool& changed) { if (inst->isBinary() || inst->isUnary() || inst->isLoad()) { if (inst->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Load Binary Unary Found ===\n"; + SysYPrinter::printInst(inst); + } usedelete(inst); iter = instrs.erase(iter); continue; @@ -101,6 +110,10 @@ void DeadCodeElimination::eliminateDeadAllocas(Function* func, bool& changed) { func->getEntryBlock()->getArguments().end(), allocaInst) == func->getEntryBlock()->getArguments().end()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Alloca Found ===\n"; + SysYPrinter::printInst(inst); + } usedelete(inst); iter = instrs.erase(iter); continue; @@ -116,8 +129,12 @@ void DeadCodeElimination::eliminateDeadIndirectiveAllocas(Function* func, bool& FunctionAnalysisInfo* funcInfo = pCFA->getFunctionAnalysisInfo(func); for (auto it = funcInfo->getIndirectAllocas().begin(); it != funcInfo->getIndirectAllocas().end();) { auto &allocaInst = *it; - if (allocaInst->getUses().empty()) { + if (allocaInst->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Indirect Alloca Found ===\n"; + SysYPrinter::printInst(allocaInst.get()); + } it = funcInfo->getIndirectAllocas().erase(it); } else { ++it; @@ -132,6 +149,10 @@ void DeadCodeElimination::eliminateDeadGlobals(bool& changed) { auto& global = *it; if (global->getUses().empty()) { changed = true; + if(DEBUG){ + std::cout << "=== Dead Global Found ===\n"; + SysYPrinter::printValue(global.get()); + } it = globals.erase(it); } else { ++it; @@ -207,6 +228,12 @@ void DeadCodeElimination::eliminateDeadRedundantLoadStore(Function* func, bool& // 可以优化直接把prevStorePointer的值存到nextStorePointer changed = true; nextStore->setOperand(0, prevStoreValue); + if(DEBUG){ + std::cout << "=== Dead Store Load Store Found(now only del Load) ===\n"; + SysYPrinter::printInst(prevStore); + SysYPrinter::printInst(loadInst); + SysYPrinter::printInst(nextStore); + } usedelete(loadInst); iter = instrs.erase(iter); // 删除 prevStore 这里是不是可以留给删除无用store处理? diff --git a/src/SysYIRGenerator.cpp b/src/SysYIRGenerator.cpp index 86e19e7..7520891 100644 --- a/src/SysYIRGenerator.cpp +++ b/src/SysYIRGenerator.cpp @@ -204,6 +204,7 @@ std::any SysYIRGenerator::visitFuncType(SysYParser::FuncTypeContext *ctx) { std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ // 更新作用域 module->enterNewScope(); + HasReturnInst = false; auto name = ctx->Ident()->getText(); std::vector paramTypes; @@ -243,6 +244,18 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){ visitBlockItem(item); } + if(HasReturnInst == false) { + // 如果没有return语句,则默认返回0 + if (returnType != Type::getVoidType()) { + Value* returnValue = ConstantValue::get(0); + if (returnType == Type::getFloatType()) { + returnValue = ConstantValue::get(0.0f); + } + builder.createReturnInst(returnValue); + } else { + builder.createReturnInst(); + } + } module->leaveScope(); return std::any(); @@ -478,6 +491,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) { } } builder.createReturnInst(returnValue); + HasReturnInst = true; return std::any(); } diff --git a/src/SysYIROptPre.cpp b/src/SysYIROptPre.cpp index 41af234..fb05cb7 100644 --- a/src/SysYIROptPre.cpp +++ b/src/SysYIROptPre.cpp @@ -4,6 +4,7 @@ #include #include #include +#include #include "IR.h" #include "IRBuilder.h" @@ -458,11 +459,13 @@ void SysYOptPre::SysYAddReturn() { // 如果基本块没有后继块,则添加一个返回指令 if (block->getNumInstructions() == 0) { pBuilder->setPosition(block.get(), block->end()); - pBuilder->createReturnInst({}); + pBuilder->createReturnInst(); } auto thelastinst = block->getInstructions().end(); --thelastinst; if (thelastinst->get()->getKind() != Instruction::kReturn) { + // std::cout << "Warning: Function " << func->getName() << " has no return instruction, adding default return." << std::endl; + pBuilder->setPosition(block.get(), block->end()); // TODO: 如果int float函数缺少返回值是否需要报错 if (func->getReturnType()->isInt()) { @@ -470,7 +473,7 @@ void SysYOptPre::SysYAddReturn() { } else if (func->getReturnType()->isFloat()) { pBuilder->createReturnInst(ConstantValue::get(0.0F)); } else { - pBuilder->createReturnInst({}); + pBuilder->createReturnInst(); } } } diff --git a/src/include/DeadCodeElimination.h b/src/include/DeadCodeElimination.h index 2d614bd..72b9935 100644 --- a/src/include/DeadCodeElimination.h +++ b/src/include/DeadCodeElimination.h @@ -2,6 +2,8 @@ #include "IR.h" #include "SysYIRAnalyser.h" +#include "SysYIRPrinter.h" + namespace sysy { class DeadCodeElimination { diff --git a/src/include/SysYIRGenerator.h b/src/include/SysYIRGenerator.h index 445a856..fe309e8 100644 --- a/src/include/SysYIRGenerator.h +++ b/src/include/SysYIRGenerator.h @@ -62,6 +62,8 @@ private: public: SysYIRGenerator() = default; + bool HasReturnInst; + public: Module *get() const { return module.get(); } IRBuilder *getBuilder(){ return &builder; } diff --git a/src/include/SysYIRPrinter.h b/src/include/SysYIRPrinter.h index 114fb05..bfd78bd 100644 --- a/src/include/SysYIRPrinter.h +++ b/src/include/SysYIRPrinter.h @@ -15,15 +15,16 @@ public: public: void printIR(); void printGlobalVariable(); - void printFunction(Function *function); - void printInst(Instruction *pInst); - void printType(Type *type); - void printValue(Value *value); + public: + static void printFunction(Function *function); + static void printInst(Instruction *pInst); + static void printType(Type *type); + static void printValue(Value *value); static std::string getOperandName(Value *operand); - std::string getTypeString(Type *type); - std::string getValueName(Value *value); + static std::string getTypeString(Type *type); + static std::string getValueName(Value *value); }; } // namespace sysy