diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d68d41a..8851182 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -16,6 +16,7 @@ add_executable(sysyc IR.cpp SysYIRGenerator.cpp # Backend.cpp + SysYIRPrinter.cpp RISCv32Backend.cpp ) target_include_directories(sysyc PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/include) diff --git a/src/SysYIRPrinter.cpp b/src/SysYIRPrinter.cpp new file mode 100644 index 0000000..5be7777 --- /dev/null +++ b/src/SysYIRPrinter.cpp @@ -0,0 +1,481 @@ +#include "SysYIRPrinter.h" +#include +#include +#include +#include +#include "IR.h" + +namespace sysy { + +void SysYPrinter::printIR() { + + const auto &functions = pModule->getFunctions(); + + // Print target datalayout and triple (minimal required by LLVM) + std::cout << "target datalayout = \"e-m:e-p270:32:32-p271:32:32-p272:64:64-f64:32:64-f80:32-n8:16:32-S128\"\n"; + std::cout << "target triple = \"i386-pc-linux-gnu\"\n\n"; + + printGlobalVariable(); + + for (const auto &iter : functions) { + if (iter.second->getName() == "main") { + printFunction(iter.second.get()); + break; + } + } + + for (const auto &iter : functions) { + if (iter.second->getName() != "main") { + printFunction(iter.second.get()); + } + } +} + +std::string SysYPrinter::getTypeString(Type *type) { + if (type->isVoid()) { + return "void"; + } else if (type->isInt()) { + return "i32"; + } else if (type->isFloat()) { + return "float"; + + } else if (auto ptrType = dynamic_cast(type)) { + return getTypeString(ptrType->getBaseType()) + "*"; + } else if (auto ptrType = dynamic_cast(type)) { + return getTypeString(ptrType->getReturnType()); + } + assert(false && "Unsupported type"); + return ""; +} + +std::string SysYPrinter::getValueName(Value *value) { + if (auto global = dynamic_cast(value)) { + return "@" + global->getName(); + } else if (auto inst = dynamic_cast(value)) { + return "%" + inst->getName(); + } else if (auto constVal = dynamic_cast(value)) { + if (constVal->isFloat()) { + return std::to_string(constVal->getFloat()); + } + return std::to_string(constVal->getInt()); + } else if (auto constVar = dynamic_cast(value)) { + return constVar->getName(); + } + assert(false && "Unknown value type"); + return ""; +} + +void SysYPrinter::printType(Type *type) { + std::cout << getTypeString(type); +} + +void SysYPrinter::printValue(Value *value) { + std::cout << getValueName(value); +} + +void SysYPrinter::printGlobalVariable() { + auto &globals = pModule->getGlobals(); + + for (const auto &global : globals) { + std::cout << "@" << global->getName() << " = global "; + + auto baseType = dynamic_cast(global->getType())->getBaseType(); + printType(baseType); + + if (global->getNumDims() > 0) { + // Array type + std::cout << " ["; + for (unsigned i = 0; i < global->getNumDims(); i++) { + if (i > 0) std::cout << " x "; + std::cout << getValueName(global->getDim(i)); + } + std::cout << "]"; + } + + std::cout << " "; + + if (global->getNumDims() > 0) { + // Array initializer + std::cout << "["; + auto values = global->getInitValues(); + auto counterValues = values.getValues(); + auto counterNumbers = values.getNumbers(); + + for (size_t i = 0; i < counterNumbers.size(); i++) { + if (i > 0) std::cout << ", "; + if (baseType->isFloat()) { + std::cout << "float " << dynamic_cast(counterValues[i])->getFloat(); + } else { + std::cout << "i32 " << dynamic_cast(counterValues[i])->getInt(); + } + } + std::cout << "]"; + } else { + // Scalar initializer + if (baseType->isFloat()) { + std::cout << "float " << dynamic_cast(global->getByIndex(0))->getFloat(); + } else { + std::cout << "i32 " << dynamic_cast(global->getByIndex(0))->getInt(); + } + } + + std::cout << ", align 4" << std::endl; + } +} + +void SysYPrinter::printFunction(Function *function) { + // Function signature + std::cout << "define "; + printType(function->getReturnType()); + std::cout << " @" << function->getName() << "("; + + auto entryBlock = function->getEntryBlock(); + auto &args = entryBlock->getArguments(); + + for (size_t i = 0; i < args.size(); i++) { + if (i > 0) std::cout << ", "; + printType(args[i]->getType()); + std::cout << " %" << args[i]->getName(); + } + + std::cout << ") {" << std::endl; + + // Function body + for (const auto &blockIter : function->getBasicBlocks()) { + // Basic block label + BasicBlock* blockPtr = blockIter.get(); + if (blockPtr == function->getEntryBlock()) { + std::cout << "entry:" << std::endl; + } else if (!blockPtr->getName().empty()) { + std::cout << blockPtr->getName() << ":" << std::endl; + } + + // Instructions + for (const auto &instIter : blockIter->getInstructions()) { + auto inst = instIter.get(); + std::cout << " "; + printInst(inst); + } + } + + std::cout << "}" << std::endl << std::endl; +} + +void SysYPrinter::printInst(Instruction *pInst) { + using Kind = Instruction::Kind; + + switch (pInst->getKind()) { + case Kind::kAdd: + case Kind::kSub: + case Kind::kMul: + case Kind::kDiv: + case Kind::kRem: + case Kind::kFAdd: + case Kind::kFSub: + case Kind::kFMul: + case Kind::kFDiv: + case Kind::kICmpEQ: + case Kind::kICmpNE: + case Kind::kICmpLT: + case Kind::kICmpGT: + case Kind::kICmpLE: + case Kind::kICmpGE: + case Kind::kFCmpEQ: + case Kind::kFCmpNE: + case Kind::kFCmpLT: + case Kind::kFCmpGT: + case Kind::kFCmpLE: + case Kind::kFCmpGE: + case Kind::kAnd: + case Kind::kOr: { + auto binInst = dynamic_cast(pInst); + + // Print result variable if exists + if (!binInst->getName().empty()) { + std::cout << "%" << binInst->getName() << " = "; + } + + // Operation name + switch (pInst->getKind()) { + case Kind::kAdd: std::cout << "add"; break; + case Kind::kSub: std::cout << "sub"; break; + case Kind::kMul: std::cout << "mul"; break; + case Kind::kDiv: std::cout << "sdiv"; break; + case Kind::kRem: std::cout << "srem"; break; + case Kind::kFAdd: std::cout << "fadd"; break; + case Kind::kFSub: std::cout << "fsub"; break; + case Kind::kFMul: std::cout << "fmul"; break; + case Kind::kFDiv: std::cout << "fdiv"; break; + case Kind::kICmpEQ: std::cout << "icmp eq"; break; + case Kind::kICmpNE: std::cout << "icmp ne"; break; + case Kind::kICmpLT: std::cout << "icmp slt"; break; + case Kind::kICmpGT: std::cout << "icmp sgt"; break; + case Kind::kICmpLE: std::cout << "icmp sle"; break; + case Kind::kICmpGE: std::cout << "icmp sge"; break; + case Kind::kFCmpEQ: std::cout << "fcmp oeq"; break; + case Kind::kFCmpNE: std::cout << "fcmp one"; break; + case Kind::kFCmpLT: std::cout << "fcmp olt"; break; + case Kind::kFCmpGT: std::cout << "fcmp ogt"; break; + case Kind::kFCmpLE: std::cout << "fcmp ole"; break; + case Kind::kFCmpGE: std::cout << "fcmp oge"; break; + case Kind::kAnd: std::cout << "and"; break; + case Kind::kOr: std::cout << "or"; break; + default: break; + } + + // Types and operands + std::cout << " "; + printType(binInst->getType()); + std::cout << " "; + printValue(binInst->getLhs()); + std::cout << ", "; + printValue(binInst->getRhs()); + + std::cout << std::endl; + } break; + + case Kind::kNeg: + case Kind::kNot: + case Kind::kFNeg: + case Kind::kFNot: + case Kind::kFtoI: + case Kind::kBitFtoI: + case Kind::kItoF: + case Kind::kBitItoF: { + auto unyInst = dynamic_cast(pInst); + + if (!unyInst->getName().empty()) { + std::cout << "%" << unyInst->getName() << " = "; + } + + switch (pInst->getKind()) { + case Kind::kNeg: std::cout << "sub "; break; + case Kind::kNot: std::cout << "xor "; break; + case Kind::kFNeg: std::cout << "fneg "; break; + case Kind::kFNot: std::cout << "fneg "; break; // FNot not standard, map to fneg + case Kind::kFtoI: std::cout << "fptosi "; break; + case Kind::kBitFtoI: std::cout << "bitcast "; break; + case Kind::kItoF: std::cout << "sitofp "; break; + case Kind::kBitItoF: std::cout << "bitcast "; break; + default: break; + } + + printType(unyInst->getType()); + std::cout << " "; + + // Special handling for negation + if (pInst->getKind() == Kind::kNeg || pInst->getKind() == Kind::kNot) { + std::cout << "i32 0, "; + } + + printValue(pInst->getOperand(0)); + + // For bitcast, need to specify destination type + if (pInst->getKind() == Kind::kBitFtoI || pInst->getKind() == Kind::kBitItoF) { + std::cout << " to "; + printType(unyInst->getType()); + } + + std::cout << std::endl; + } break; + + case Kind::kCall: { + auto callInst = dynamic_cast(pInst); + auto function = callInst->getCallee(); + + if (!callInst->getName().empty()) { + std::cout << "%" << callInst->getName() << " = "; + } + + std::cout << "call "; + printType(callInst->getType()); + std::cout << " @" << function->getName() << "("; + + auto params = callInst->getArguments(); + bool first = true; + for (auto ¶m : params) { + if (!first) std::cout << ", "; + first = false; + printType(param->getValue()->getType()); + std::cout << " "; + printValue(param->getValue()); + } + + std::cout << ")" << std::endl; + } break; + + case Kind::kCondBr: { + auto condBrInst = dynamic_cast(pInst); + std::cout << "br i1 "; + printValue(condBrInst->getCondition()); + std::cout << ", label %" << condBrInst->getThenBlock()->getName(); + std::cout << ", label %" << condBrInst->getElseBlock()->getName(); + std::cout << std::endl; + } break; + + case Kind::kBr: { + auto brInst = dynamic_cast(pInst); + std::cout << "br label %" << brInst->getBlock()->getName(); + std::cout << std::endl; + } break; + + case Kind::kReturn: { + auto retInst = dynamic_cast(pInst); + std::cout << "ret "; + if (retInst->getNumOperands() != 0) { + printType(retInst->getOperand(0)->getType()); + std::cout << " "; + printValue(retInst->getOperand(0)); + } else { + std::cout << "void"; + } + std::cout << std::endl; + } break; + + case Kind::kAlloca: { + auto allocaInst = dynamic_cast(pInst); + std::cout << "%" << allocaInst->getName() << " = alloca "; + + auto baseType = dynamic_cast(allocaInst->getType())->getBaseType(); + printType(baseType); + + if (allocaInst->getNumDims() > 0) { + std::cout << ", "; + for (size_t i = 0; i < allocaInst->getNumDims(); i++) { + if (i > 0) std::cout << ", "; + printType(Type::getIntType()); + std::cout << " "; + printValue(allocaInst->getDim(i)); + } + } + + std::cout << ", align 4" << std::endl; + } break; + + case Kind::kLoad: { + auto loadInst = dynamic_cast(pInst); + std::cout << "%" << loadInst->getName() << " = load "; + printType(loadInst->getType()); + std::cout << ", "; + printType(loadInst->getPointer()->getType()); + std::cout << " "; + printValue(loadInst->getPointer()); + + if (loadInst->getNumIndices() > 0) { + std::cout << ", "; + for (size_t i = 0; i < loadInst->getNumIndices(); i++) { + if (i > 0) std::cout << ", "; + printType(Type::getIntType()); + std::cout << " "; + printValue(loadInst->getIndex(i)); + } + } + + std::cout << ", align 4" << std::endl; + } break; + + case Kind::kLa: { + auto laInst = dynamic_cast(pInst); + std::cout << "%" << laInst->getName() << " = getelementptr inbounds "; + + auto ptrType = dynamic_cast(laInst->getPointer()->getType()); + printType(ptrType->getBaseType()); + std::cout << ", "; + printType(laInst->getPointer()->getType()); + std::cout << " "; + printValue(laInst->getPointer()); + std::cout << ", "; + + for (size_t i = 0; i < laInst->getNumIndices(); i++) { + if (i > 0) std::cout << ", "; + printType(Type::getIntType()); + std::cout << " "; + printValue(laInst->getIndex(i)); + } + + std::cout << std::endl; + } break; + + case Kind::kStore: { + auto storeInst = dynamic_cast(pInst); + std::cout << "store "; + printType(storeInst->getValue()->getType()); + std::cout << " "; + printValue(storeInst->getValue()); + std::cout << ", "; + printType(storeInst->getPointer()->getType()); + std::cout << " "; + printValue(storeInst->getPointer()); + + if (storeInst->getNumIndices() > 0) { + std::cout << ", "; + for (size_t i = 0; i < storeInst->getNumIndices(); i++) { + if (i > 0) std::cout << ", "; + printType(Type::getIntType()); + std::cout << " "; + printValue(storeInst->getIndex(i)); + } + } + + std::cout << ", align 4" << std::endl; + } break; + + case Kind::kMemset: { + auto memsetInst = dynamic_cast(pInst); + std::cout << "call void @llvm.memset.p0."; + printType(memsetInst->getPointer()->getType()); + std::cout << "("; + printType(memsetInst->getPointer()->getType()); + std::cout << " "; + printValue(memsetInst->getPointer()); + std::cout << ", i8 "; + printValue(memsetInst->getValue()); + std::cout << ", i32 "; + printValue(memsetInst->getSize()); + std::cout << ", i1 false)" << std::endl; + } break; + + case Kind::kPhi: { + auto phiInst = dynamic_cast(pInst); + std::cout << "%" << phiInst->getName() << " = phi "; + printType(phiInst->getType()); + + for (unsigned i = 0; i < phiInst->getNumOperands(); i += 2) { + if (i > 0) std::cout << ", "; + std::cout << "[ "; + printValue(phiInst->getOperand(i)); + std::cout << ", %" << dynamic_cast(phiInst->getOperand(i+1))->getName() << " ]"; + } + std::cout << std::endl; + } break; + + case Kind::kGetSubArray: { + auto getSubArrayInst = dynamic_cast(pInst); + std::cout << "%" << getSubArrayInst->getName() << " = getelementptr inbounds "; + + auto ptrType = dynamic_cast(getSubArrayInst->getFatherArray()->getType()); + printType(ptrType->getBaseType()); + std::cout << ", "; + printType(getSubArrayInst->getFatherArray()->getType()); + std::cout << " "; + printValue(getSubArrayInst->getFatherArray()); + std::cout << ", "; + bool firstIndex = true; + for (auto &index : getSubArrayInst->getIndices()) { + if (!firstIndex) std::cout << ", "; + firstIndex = false; + printType(Type::getIntType()); + std::cout << " "; + printValue(index->getValue()); + } + + std::cout << std::endl; + } break; + + default: + assert(false && "Unsupported instruction kind"); + break; + } +} + +} // namespace sysy diff --git a/src/include/IR.h b/src/include/IR.h index 3182a9a..b8cd1ba 100644 --- a/src/include/IR.h +++ b/src/include/IR.h @@ -1106,7 +1106,7 @@ public: return make_range(std::next(operand_begin()), operand_end()); } Value* getIndex(int index) const { return getOperand(index + 1); } - std::list getAncestorIndices() const { + std::list getAncestorIndices() const { std::list indices; for (const auto &index : getIndices()) { indices.emplace_back(index->getValue()); diff --git a/src/include/SysYIRPrinter.h b/src/include/SysYIRPrinter.h new file mode 100644 index 0000000..114fb05 --- /dev/null +++ b/src/include/SysYIRPrinter.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include "IR.h" + +namespace sysy { + +class SysYPrinter { +private: + Module *pModule; + +public: + explicit SysYPrinter(Module *pModule) : pModule(pModule) {} + +public: + void printIR(); + void printGlobalVariable(); + void printFunction(Function *function); + void printInst(Instruction *pInst); + void printType(Type *type); + void printValue(Value *value); + +public: + static std::string getOperandName(Value *operand); + std::string getTypeString(Type *type); + std::string getValueName(Value *value); +}; + +} // namespace sysy diff --git a/src/sysyc.cpp b/src/sysyc.cpp index ac13a68..dc1015b 100644 --- a/src/sysyc.cpp +++ b/src/sysyc.cpp @@ -8,7 +8,8 @@ using namespace std; using namespace antlr4; // #include "Backend.h" #include "SysYIRGenerator.h" -#include "RISCv32Backend.h" +#include "SysYIRPrinter.h" +// #include "LLVMIRGenerator.h" using namespace sysy; static string argStopAfter; @@ -71,22 +72,14 @@ int main(int argc, char **argv) { // visit AST to generate IR - SysYIRGenerator generator; - generator.visitCompUnit(moduleAST); - if (argStopAfter == "ir") { - // auto module = generator.get(); - // module->print(cout); - return EXIT_SUCCESS; - } - // generate assembly - auto module = generator.get(); - sysy::RISCv32CodeGen codegen(module); - string asmCode = codegen.code_gen(); - if (argStopAfter == "asm") { - cout << asmCode << endl; + if (argStopAfter == "ir") { + SysYIRGenerator generator; + generator.visitCompUnit(moduleAST); + auto moduleIR = generator.get(); + SysYPrinter printer(moduleIR); + printer.printIR(); return EXIT_SUCCESS; } - return EXIT_SUCCESS; } \ No newline at end of file