diff --git a/src/include/midend/IR.h b/src/include/midend/IR.h index 241f28f..73a93cc 100644 --- a/src/include/midend/IR.h +++ b/src/include/midend/IR.h @@ -98,7 +98,6 @@ class PointerType : public Type { public: Type* getBaseType() const { return baseType; } ///< 获取指向的类型 - void print(std::ostream& os) const override; }; class FunctionType : public Type { @@ -118,7 +117,6 @@ class FunctionType : public Type { Type* getReturnType() const { return returnType; } ///< 获取返回值类信息 auto getParamTypes() const { return make_range(paramTypes); } ///< 获取形参类型列表 unsigned getNumParams() const { return paramTypes.size(); } ///< 获取形参数量 - void print(std::ostream& os) const override; }; class ArrayType : public Type { @@ -135,7 +133,6 @@ class ArrayType : public Type { : Type(Kind::kArray), elementType(elementType), numElements(numElements) {} Type *elementType; unsigned numElements; // 当前维度的大小 - void print(std::ostream& os) const override; }; /*! @@ -980,16 +977,10 @@ class CallInst : public Instruction { friend class IRBuilder; protected: - CallInst(Function *callee, const std::vector &args, BasicBlock *parent = nullptr, const std::string &name = "") - : Instruction(kCall, callee->getReturnType(), parent, name) { - addOperand(callee); - for (auto arg : args) { - addOperand(arg); - } - } + CallInst(Function *callee, const std::vector &args, BasicBlock *parent = nullptr, const std::string &name = ""); public: - Function *getCallee() const { return dynamic_cast(getOperand(0)); } + Function *getCallee() const; auto getArguments() const { return make_range(std::next(operand_begin()), operand_end()); } diff --git a/src/midend/IR.cpp b/src/midend/IR.cpp index 37e1a9a..c0901cf 100644 --- a/src/midend/IR.cpp +++ b/src/midend/IR.cpp @@ -209,6 +209,62 @@ ArrayType *ArrayType::get(Type *elementType, unsigned numElements) { return result.first->get(); } +void Argument::print(std::ostream& os) const { + os << *getType() << " %" << getName(); +} + +void GlobalValue::print(std::ostream& os) const { + // 输出全局变量的LLVM IR格式 + os << "@" << getName() << " = global "; + + // 输出初始化值 + if (initValues.size() > 0) { + os << *getType()->as()->getBaseType() << " "; + if (initValues.size() == 1) { + // 单个初始值 + initValues.getValue(0)->print(os); + } else { + // 数组初始值 + os << "["; + for (unsigned i = 0; i < initValues.size(); ++i) { + if (i > 0) os << ", "; + auto value = initValues.getValue(i); + os << *value->getType() << " "; + value->print(os); + } + os << "]"; + } + } else { + os << *getType()->as()->getBaseType() << " zeroinitializer"; + } +} + +void ConstantVariable::print(std::ostream& os) const { + // 输出常量的LLVM IR格式 + os << "@" << getName() << " = constant "; + + // 输出初始化值 + if (initValues.size() > 0) { + os << *getType()->as()->getBaseType() << " "; + if (initValues.size() == 1) { + // 单个初始值 + initValues.getValue(0)->print(os); + } else { + // 数组初始值 + os << "["; + for (unsigned i = 0; i < initValues.size(); ++i) { + if (i > 0) os << ", "; + auto value = initValues.getValue(i); + os << *value->getType() << " "; + value->print(os); + } + os << "]"; + } + } else { + os << *getType()->as()->getBaseType() << " zeroinitializer"; + } +} + // void Value::replaceAllUsesWith(Value *value) { // for (auto &use : uses) { // auto user = use->getUser(); @@ -349,6 +405,18 @@ void PhiInst::print(std::ostream &os) const { } } +CallInst::CallInst(Function *callee, const std::vector &args, BasicBlock *parent, const std::string &name) + : Instruction(kCall, callee->getReturnType(), parent, name) { + addOperand(callee); + for (auto arg : args) { + addOperand(arg); + } +} + +Function *CallInst::getCallee() const { + return dynamic_cast(getOperand(0)); +} + void CallInst::print(std::ostream &os) const { if(!getType()->isVoid()) { printVarName(os, this) << " = "; @@ -488,9 +556,9 @@ void MemsetInst::print(std::ostream &os) const { os << "call void @llvm.memset.p0i8.i32(i8* "; printOperand(os, getPointer()); os << ", i8 "; - printOperand(os, getOperand(3)); // value + printOperand(os, getValue()); // value os << ", i32 "; - printOperand(os, getOperand(2)); // size + printOperand(os, getSize()); // size os << ", i1 false)"; } @@ -779,4 +847,146 @@ auto BasicBlock::moveInst(iterator sourcePos, iterator targetPos, BasicBlock *bl return instructions.erase(sourcePos); } +/** + * 为Value重命名以符合LLVM IR格式 + */ +void renameValues(Function* function) { + std::unordered_map valueNames; + unsigned tempCounter = 0; + unsigned labelCounter = 0; + + // 检查名字是否需要重命名(只有纯数字或空名字才需要重命名) + auto needsRename = [](const std::string& name) { + if (name.empty()) return true; + + // 检查是否为纯数字 + for (char c : name) { + if (!std::isdigit(c)) { + return false; // 包含非数字字符,不需要重命名 + } + } + return true; // 纯数字或空字符串,需要重命名 + }; + + // 重命名函数参数 + for (auto arg : function->getArguments()) { + if (needsRename(arg->getName())) { + valueNames[arg] = "%" + std::to_string(tempCounter++); + arg->setName(valueNames[arg].substr(1)); // 去掉%前缀,因为printVarName会加上 + } + } + + // 重命名基本块 + for (auto& block : function->getBasicBlocks()) { + if (needsRename(block->getName())) { + valueNames[block.get()] = "label" + std::to_string(labelCounter++); + block->setName(valueNames[block.get()]); + } + } + + // 重命名指令 + for (auto& block : function->getBasicBlocks()) { + for (auto& inst : block->getInstructions()) { + // 只有产生值的指令需要重命名 + if (!inst->getType()->isVoid() && needsRename(inst->getName())) { + valueNames[inst.get()] = "%" + std::to_string(tempCounter++); + inst->setName(valueNames[inst.get()].substr(1)); // 去掉%前缀 + } + } + } +} + +void Function::print(std::ostream& os) const { + // 重命名所有值 + auto* mutableThis = const_cast(this); + renameValues(mutableThis); + + // 打印函数签名 + os << "define " << *getReturnType() << " "; + printFunctionName(os, this); + os << "("; + + // 打印参数列表 + const auto& args = const_cast(this)->getArguments(); + for (size_t i = 0; i < args.size(); ++i) { + if (i > 0) os << ", "; + os << *args[i]->getType() << " "; + printVarName(os, args[i]); + } + os << ") {\n"; + + // 打印基本块 + for (auto& block : const_cast(this)->getBasicBlocks()) { + block->print(os); + } + + os << "}\n"; +} + +void Module::print(std::ostream& os) const { + // 打印全局变量声明 + for (auto& globalVar : const_cast(this)->getGlobals()) { + printVarName(os, globalVar.get()); + os << " = global " << *globalVar->getType()->as()->getBaseType(); + + // 打印初始值 + const auto& initValues = globalVar->getInitValues(); + if (initValues.size() > 0) { + os << " "; + // 简化处理:只打印第一个初始值 + initValues.getValue(0)->print(os); + } else { + // 默认初始化 + if (globalVar->getType()->as()->getBaseType()->isInt()) { + os << " 0"; + } else if (globalVar->getType()->as()->getBaseType()->isFloat()) { + os << " 0.0"; + } else { + os << " zeroinitializer"; + } + } + os << "\n"; + } + + // 打印常量声明 + for (auto& constVar : getConsts()) { + printVarName(os, constVar.get()); + os << " = constant " << *constVar->getType()->as()->getBaseType(); + os << " "; + const auto& initValues = constVar->getInitValues(); + if (initValues.size() > 0) { + initValues.getValue(0)->print(os); + } else { + os << "0"; + } + os << "\n"; + } + + // 打印外部函数声明 + for (auto& extFunc : getExternalFunctions()) { + os << "declare " << *extFunc.second->getReturnType() << " "; + printFunctionName(os, extFunc.second.get()); + os << "("; + + const auto& paramTypes = extFunc.second->getParamTypes(); + bool first = true; + for (auto paramType : paramTypes) { + if (!first) os << ", "; + os << *paramType; + first = false; + } + os << ")\n"; + } + + if (!getExternalFunctions().empty()) { + os << "\n"; // 外部函数和普通函数之间加空行 + } + + // 打印函数定义 + for (auto& func : const_cast(this)->getFunctions()) { + func.second->print(os); + os << "\n"; // 函数之间加空行 + } +} + } // namespace sysy