[backend-llir]修复了许多重构的bug

This commit is contained in:
Lixuanwang
2025-07-19 17:50:14 +08:00
parent d4a6996d74
commit 9528335a04
11 changed files with 513 additions and 497 deletions

View File

@@ -1,44 +1,71 @@
#include "RISCv64AsmPrinter.h" #include "RISCv64AsmPrinter.h"
#include "RISCv64ISel.h"
#include <stdexcept> #include <stdexcept>
namespace sysy { namespace sysy {
void RISCv64AsmPrinter::runOnMachineFunction(MachineFunction* mfunc, std::ostream& os) { // 检查是否为内存加载/存储指令,以处理特殊的打印格式
bool isMemoryOp(RVOpcodes opcode) {
switch (opcode) {
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
return true;
default:
return false;
}
}
RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {}
void RISCv64AsmPrinter::run(std::ostream& os) {
OS = &os; OS = &os;
// 打印函数声明和全局符号 *OS << ".globl " << MFunc->getName() << "\n";
*OS << ".text\n"; *OS << MFunc->getName() << ":\n";
*OS << ".globl " << mfunc->getName() << "\n";
*OS << mfunc->getName() << ":\n";
// 打印函数序言 printPrologue();
printPrologue(mfunc);
// 遍历并打印所有基本块 for (auto& mbb : MFunc->getBlocks()) {
for (auto& mbb : mfunc->getBlocks()) {
printBasicBlock(mbb.get()); printBasicBlock(mbb.get());
} }
} }
void RISCv64AsmPrinter::printPrologue(MachineFunction* mfunc) { void RISCv64AsmPrinter::printPrologue() {
int stack_size = mfunc->getFrameInfo().frame_size; StackFrameInfo& frame_info = MFunc->getFrameInfo();
// 序言需要为保存ra和s0预留16字节
// 确保栈大小是16字节对齐 int total_stack_size = frame_info.locals_size + frame_info.spill_size + 16;
int aligned_stack_size = (stack_size + 15) & ~15; int aligned_stack_size = (total_stack_size + 15) & ~15;
frame_info.total_size = aligned_stack_size;
if (aligned_stack_size > 0) { if (aligned_stack_size > 0) {
*OS << " addi sp, sp, -" << aligned_stack_size << "\n"; *OS << " addi sp, sp, -" << aligned_stack_size << "\n";
// RV64中ra和s0都是8字节
*OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n";
*OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n"; *OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n";
*OS << " mv s0, sp\n"; *OS << " mv s0, sp\n";
} }
// 忠实还原保存函数入口参数的逻辑
Function* F = MFunc->getFunc();
if (F && F->getEntryBlock()) {
int arg_idx = 0;
RISCv64ISel* isel = MFunc->getISel();
for (AllocaInst* alloca_for_param : F->getEntryBlock()->getArguments()) {
if (arg_idx >= 8) break;
unsigned vreg = isel->getVReg(alloca_for_param);
if (frame_info.alloca_offsets.count(vreg)) {
int offset = frame_info.alloca_offsets.at(vreg);
auto arg_reg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + arg_idx);
*OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n";
}
arg_idx++;
}
}
} }
void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) { void RISCv64AsmPrinter::printEpilogue() {
int stack_size = mfunc->getFrameInfo().frame_size; int aligned_stack_size = MFunc->getFrameInfo().total_size;
int aligned_stack_size = (stack_size + 15) & ~15;
if (aligned_stack_size > 0) { if (aligned_stack_size > 0) {
*OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n"; *OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n";
*OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n"; *OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n";
@@ -46,129 +73,85 @@ void RISCv64AsmPrinter::printEpilogue(MachineFunction* mfunc) {
} }
} }
void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) { void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) {
// 打印基本块标签
if (!mbb->getName().empty()) { if (!mbb->getName().empty()) {
*OS << mbb->getName() << ":\n"; *OS << mbb->getName() << ":\n";
} }
// 打印指令
for (auto& instr : mbb->getInstructions()) { for (auto& instr : mbb->getInstructions()) {
printInstruction(instr.get(), mbb); printInstruction(instr.get());
} }
} }
void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb) { void RISCv64AsmPrinter::printInstruction(MachineInstr* instr) {
*OS << " "; // 指令缩进
auto opcode = instr->getOpcode(); auto opcode = instr->getOpcode();
// RET指令需要特殊处理在打印ret之前先打印函数尾声
if (opcode == RVOpcodes::RET) { if (opcode == RVOpcodes::RET) {
printEpilogue(parent_bb->getParent()); printEpilogue();
} }
if (opcode != RVOpcodes::LABEL) {
// 使用switch将Opcode转换为汇编助记符 *OS << " ";
}
switch (opcode) { switch (opcode) {
// Arithmatic case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
case RVOpcodes::ADDI: *OS << "addi "; break; case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
case RVOpcodes::ADDIW: *OS << "addiw "; break; case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
case RVOpcodes::SUBW: *OS << "subw "; break; case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::ORI: *OS << "ori "; break;
case RVOpcodes::MULW: *OS << "mulw "; break; case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::ANDI: *OS << "andi "; break;
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SLLI: *OS << "slli "; break;
case RVOpcodes::DIVW: *OS << "divw "; break; case RVOpcodes::SLLW: *OS << "sllw "; break; case RVOpcodes::SLLIW: *OS << "slliw "; break;
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::SRL: *OS << "srl "; break; case RVOpcodes::SRLI: *OS << "srli "; break;
case RVOpcodes::REMW: *OS << "remw "; break; case RVOpcodes::SRLW: *OS << "srlw "; break; case RVOpcodes::SRLIW: *OS << "srliw "; break;
// Logical case RVOpcodes::SRA: *OS << "sra "; break; case RVOpcodes::SRAI: *OS << "srai "; break;
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::SRAW: *OS << "sraw "; break; case RVOpcodes::SRAIW: *OS << "sraiw "; break;
case RVOpcodes::XORI: *OS << "xori "; break; case RVOpcodes::SLT: *OS << "slt "; break; case RVOpcodes::SLTI: *OS << "slti "; break;
case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::SLTU: *OS << "sltu "; break; case RVOpcodes::SLTIU: *OS << "sltiu "; break;
case RVOpcodes::ORI: *OS << "ori "; break; case RVOpcodes::LW: *OS << "lw "; break; case RVOpcodes::LH: *OS << "lh "; break;
case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::LB: *OS << "lb "; break; case RVOpcodes::LWU: *OS << "lwu "; break;
case RVOpcodes::ANDI: *OS << "andi "; break; case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break;
// Shift case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break;
case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break;
case RVOpcodes::SLLI: *OS << "slli "; break;
case RVOpcodes::SLLW: *OS << "sllw "; break;
case RVOpcodes::SLLIW: *OS << "slliw "; break;
case RVOpcodes::SRL: *OS << "srl "; break;
case RVOpcodes::SRLI: *OS << "srli "; break;
case RVOpcodes::SRLW: *OS << "srlw "; break;
case RVOpcodes::SRLIW: *OS << "srliw "; break;
case RVOpcodes::SRA: *OS << "sra "; break;
case RVOpcodes::SRAI: *OS << "srai "; break;
case RVOpcodes::SRAW: *OS << "sraw "; break;
case RVOpcodes::SRAIW: *OS << "sraiw "; break;
// Compare
case RVOpcodes::SLT: *OS << "slt "; break;
case RVOpcodes::SLTI: *OS << "slti "; break;
case RVOpcodes::SLTU: *OS << "sltu "; break;
case RVOpcodes::SLTIU: *OS << "sltiu "; break;
// Memory
case RVOpcodes::LW: *OS << "lw "; break;
case RVOpcodes::LH: *OS << "lh "; break;
case RVOpcodes::LB: *OS << "lb "; break;
case RVOpcodes::LWU: *OS << "lwu "; break;
case RVOpcodes::LHU: *OS << "lhu "; break;
case RVOpcodes::LBU: *OS << "lbu "; break;
case RVOpcodes::SW: *OS << "sw "; break;
case RVOpcodes::SH: *OS << "sh "; break;
case RVOpcodes::SB: *OS << "sb "; break;
case RVOpcodes::LD: *OS << "ld "; break;
case RVOpcodes::SD: *OS << "sd "; break; case RVOpcodes::SD: *OS << "sd "; break;
// Control Flow case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break;
case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break;
case RVOpcodes::JAL: *OS << "jal "; break; case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break;
case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::BGE: *OS << "bge "; break;
case RVOpcodes::RET: *OS << "ret"; break; case RVOpcodes::BLTU: *OS << "bltu "; break; case RVOpcodes::BGEU: *OS << "bgeu "; break;
case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break;
case RVOpcodes::BNE: *OS << "bne "; break; case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break;
case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break;
case RVOpcodes::BGE: *OS << "bge "; break;
case RVOpcodes::BLTU: *OS << "bltu "; break;
case RVOpcodes::BGEU: *OS << "bgeu "; break;
// Pseudo-Instructions
case RVOpcodes::LI: *OS << "li "; break;
case RVOpcodes::LA: *OS << "la "; break;
case RVOpcodes::MV: *OS << "mv "; break;
case RVOpcodes::NEG: *OS << "neg "; break;
case RVOpcodes::NEGW: *OS << "negw "; break;
case RVOpcodes::SEQZ: *OS << "seqz "; break;
case RVOpcodes::SNEZ: *OS << "snez "; break; case RVOpcodes::SNEZ: *OS << "snez "; break;
// Call
case RVOpcodes::CALL: *OS << "call "; break; case RVOpcodes::CALL: *OS << "call "; break;
// Special
case RVOpcodes::LABEL: case RVOpcodes::LABEL:
*OS << "\b\b\b\b";
printOperand(instr->getOperands()[0].get()); printOperand(instr->getOperands()[0].get());
*OS << ":"; *OS << ":";
break; break;
case RVOpcodes::FRAME_LOAD:
case RVOpcodes::FRAME_STORE:
// These should have been eliminated by RegAlloc
throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter");
default: default:
throw std::runtime_error("Unknown opcode in AsmPrinter"); throw std::runtime_error("Unknown opcode in AsmPrinter");
} }
// 打印操作数
const auto& operands = instr->getOperands(); const auto& operands = instr->getOperands();
for (size_t i = 0; i < operands.size(); ++i) { if (!operands.empty()) {
// 对于LW/SW, 操作数格式是 rd, offset(rs1) if (isMemoryOp(opcode)) {
if (opcode == RVOpcodes::LW || opcode == RVOpcodes::SW || opcode == RVOpcodes::LD || opcode == RVOpcodes::SD) {
printOperand(operands[0].get()); printOperand(operands[0].get());
*OS << ", "; *OS << ", ";
printOperand(operands[1].get()); printOperand(operands[1].get());
break; // LW/SW只有两个操作数部分 } else {
} for (size_t i = 0; i < operands.size(); ++i) {
printOperand(operands[i].get());
printOperand(operands[i].get()); if (i < operands.size() - 1) {
if (i < operands.size() - 1) { *OS << ", ";
*OS << ", "; }
}
} }
} }
*OS << "\n"; *OS << "\n";
} }
@@ -178,21 +161,18 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) {
case MachineOperand::KIND_REG: { case MachineOperand::KIND_REG: {
auto reg_op = static_cast<RegOperand*>(op); auto reg_op = static_cast<RegOperand*>(op);
if (reg_op->isVirtual()) { if (reg_op->isVirtual()) {
// 在这个阶段不应该再有虚拟寄存器了
*OS << "%vreg" << reg_op->getVRegNum(); *OS << "%vreg" << reg_op->getVRegNum();
} else { } else {
*OS << regToString(reg_op->getPReg()); *OS << regToString(reg_op->getPReg());
} }
break; break;
} }
case MachineOperand::KIND_IMM: { case MachineOperand::KIND_IMM:
*OS << static_cast<ImmOperand*>(op)->getValue(); *OS << static_cast<ImmOperand*>(op)->getValue();
break; break;
} case MachineOperand::KIND_LABEL:
case MachineOperand::KIND_LABEL: {
*OS << static_cast<LabelOperand*>(op)->getName(); *OS << static_cast<LabelOperand*>(op)->getName();
break; break;
}
case MachineOperand::KIND_MEM: { case MachineOperand::KIND_MEM: {
auto mem_op = static_cast<MemOperand*>(op); auto mem_op = static_cast<MemOperand*>(op);
printOperand(mem_op->getOffset()); printOperand(mem_op->getOffset());
@@ -204,41 +184,40 @@ void RISCv64AsmPrinter::printOperand(MachineOperand* op) {
} }
} }
// 物理寄存器到字符串的转换 (从原RISCv64Backend.cpp迁移)
std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) { std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) {
switch (reg) { switch (reg) {
case PhysicalReg::ZERO: return "x0"; case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra";
case PhysicalReg::RA: return "ra"; case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp";
case PhysicalReg::SP: return "sp"; case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0";
case PhysicalReg::GP: return "gp"; case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2";
case PhysicalReg::TP: return "tp"; case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1";
case PhysicalReg::T0: return "t0"; case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1";
case PhysicalReg::T1: return "t1"; case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3";
case PhysicalReg::T2: return "t2"; case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5";
case PhysicalReg::S0: return "s0"; case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7";
case PhysicalReg::S1: return "s1"; case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3";
case PhysicalReg::A0: return "a0"; case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5";
case PhysicalReg::A1: return "a1"; case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7";
case PhysicalReg::A2: return "a2"; case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9";
case PhysicalReg::A3: return "a3"; case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11";
case PhysicalReg::A4: return "a4"; case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4";
case PhysicalReg::A5: return "a5"; case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6";
case PhysicalReg::A6: return "a6"; case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1";
case PhysicalReg::A7: return "a7"; case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3";
case PhysicalReg::S2: return "s2"; case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5";
case PhysicalReg::S3: return "s3"; case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7";
case PhysicalReg::S4: return "s4"; case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9";
case PhysicalReg::S5: return "s5"; case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11";
case PhysicalReg::S6: return "s6"; case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13";
case PhysicalReg::S7: return "s7"; case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15";
case PhysicalReg::S8: return "s8"; case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17";
case PhysicalReg::S9: return "s9"; case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19";
case PhysicalReg::S10: return "s10"; case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21";
case PhysicalReg::S11: return "s11"; case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23";
case PhysicalReg::T3: return "t3"; case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25";
case PhysicalReg::T4: return "t4"; case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27";
case PhysicalReg::T5: return "t5"; case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29";
case PhysicalReg::T6: return "t6"; case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31";
default: return "UNKNOWN_REG"; default: return "UNKNOWN_REG";
} }
} }

View File

@@ -3,7 +3,6 @@
#include "RISCv64RegAlloc.h" #include "RISCv64RegAlloc.h"
#include "RISCv64AsmPrinter.h" #include "RISCv64AsmPrinter.h"
#include <sstream> #include <sstream>
#include <stdexcept>
namespace sysy { namespace sysy {
@@ -12,13 +11,12 @@ std::string RISCv64CodeGen::code_gen() {
return module_gen(); return module_gen();
} }
// module_gen 的逻辑基本不变,它负责处理.data段和驱动每个函数生成 // 模块级代码生成 (移植自原文件,处理.data段和驱动函数生成)
std::string RISCv64CodeGen::module_gen() { std::string RISCv64CodeGen::module_gen() {
std::stringstream ss; std::stringstream ss;
// 1. 处理全局变量 (.data段) // 1. 处理全局变量 (.data段)
bool has_globals = !module->getGlobals().empty(); if (!module->getGlobals().empty()) {
if (has_globals) {
ss << ".data\n"; ss << ".data\n";
for (const auto& global : module->getGlobals()) { for (const auto& global : module->getGlobals()) {
ss << ".globl " << global->getName() << "\n"; ss << ".globl " << global->getName() << "\n";
@@ -45,9 +43,9 @@ std::string RISCv64CodeGen::module_gen() {
// 2. 处理函数 (.text段) // 2. 处理函数 (.text段)
if (!module->getFunctions().empty()) { if (!module->getFunctions().empty()) {
ss << ".text\n"; ss << ".text\n";
for (const auto& func : module->getFunctions()) { for (const auto& func_pair : module->getFunctions()) {
if (func.second.get()) { if (func_pair.second.get()) {
ss << function_gen(func.second.get()); ss << function_gen(func_pair.second.get());
} }
} }
} }
@@ -56,31 +54,18 @@ std::string RISCv64CodeGen::module_gen() {
// function_gen 现在是新的、模块化的处理流水线 // function_gen 现在是新的、模块化的处理流水线
std::string RISCv64CodeGen::function_gen(Function* func) { std::string RISCv64CodeGen::function_gen(Function* func) {
// === 新的、完整的流水线 ===
// 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers) // 阶段 1: 指令选择 (sysy::IR -> LLIR with virtual registers)
RISCv64ISel isel; RISCv64ISel isel;
std::unique_ptr<MachineFunction> mfunc = isel.runOnFunction(func); std::unique_ptr<MachineFunction> mfunc = isel.runOnFunction(func);
// 阶段 2: 寄存器分配前优化 (未来扩展点) // 阶段 2: 寄存器分配 (包含栈帧布局, 活跃性分析, 图着色, spill/rewrite)
// 例如:
// auto pre_ra_scheduler = std::make_unique<PreRAScheduler>();
// pre_ra_scheduler->runOnMachineFunction(mfunc.get());
// 阶段 3: 物理寄存器分配 (virtual regs -> physical regs + spill code)
RISCv64RegAlloc reg_alloc(mfunc.get()); RISCv64RegAlloc reg_alloc(mfunc.get());
reg_alloc.run(); reg_alloc.run();
// 阶段 4: 寄存器分配后优化 (未来扩展点) // 阶段 3: 代码发射 (LLIR with physical regs -> Assembly Text)
// 例如:
// auto post_ra_peephole = std::make_unique<PeepholeOptimizer>();
// post_ra_peephole->runOnMachineFunction(mfunc.get());
// 阶段 5: 代码发射 (LLIR with physical regs -> Assembly Text)
std::stringstream ss; std::stringstream ss;
RISCv64AsmPrinter printer; RISCv64AsmPrinter printer(mfunc.get());
printer.runOnMachineFunction(mfunc.get(), ss); printer.run(ss);
return ss.str(); return ss.str();
} }

View File

@@ -1,22 +1,32 @@
#include "RISCv64ISel.h" #include "RISCv64ISel.h"
#include <stdexcept> #include <stdexcept>
#include <iostream>
#include <functional>
#include <set> #include <set>
#include <functional>
#include <cmath> // For std::fabs
#include <limits> // For std::numeric_limits
namespace sysy { namespace sysy {
// DAG节点定义 (内部实现)
struct RISCv64ISel::DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET };
NodeKind kind;
Value* value = nullptr;
std::vector<DAGNode*> operands;
std::vector<DAGNode*> users;
DAGNode(NodeKind k) : kind(k) {}
};
RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {} RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {}
// 为一个IR Value获取或分配一个新的虚拟寄存器 // 为一个IR Value获取或分配一个新的虚拟寄存器
unsigned RISCv64ISel::getVReg(Value* val) { unsigned RISCv64ISel::getVReg(Value* val) {
if (!val) { // 安全检查 if (!val) {
throw std::runtime_error("Cannot get vreg for a null Value."); throw std::runtime_error("Cannot get vreg for a null Value.");
} }
if (vreg_map.find(val) == vreg_map.end()) { if (vreg_map.find(val) == vreg_map.end()) {
if (vreg_counter == 0) { if (vreg_counter == 0) {
// vreg 0 通常保留给物理寄存器x0(zero)我们从1开始分配 vreg_counter = 1; // vreg 0 保留
vreg_counter = 1;
} }
vreg_map[val] = vreg_counter++; vreg_map[val] = vreg_counter++;
} }
@@ -27,7 +37,7 @@ unsigned RISCv64ISel::getVReg(Value* val) {
std::unique_ptr<MachineFunction> RISCv64ISel::runOnFunction(Function* func) { std::unique_ptr<MachineFunction> RISCv64ISel::runOnFunction(Function* func) {
F = func; F = func;
if (!F) return nullptr; if (!F) return nullptr;
MFunc = std::make_unique<MachineFunction>(F->getName()); MFunc = std::make_unique<MachineFunction>(F, this);
vreg_map.clear(); vreg_map.clear();
bb_map.clear(); bb_map.clear();
vreg_counter = 0; vreg_counter = 0;
@@ -40,37 +50,28 @@ std::unique_ptr<MachineFunction> RISCv64ISel::runOnFunction(Function* func) {
// 指令选择主流程 // 指令选择主流程
void RISCv64ISel::select() { void RISCv64ISel::select() {
// 1. 为所有基本块创建对应的MachineBasicBlock
for (const auto& bb_ptr : F->getBasicBlocks()) { for (const auto& bb_ptr : F->getBasicBlocks()) {
BasicBlock* bb = bb_ptr.get(); auto mbb = std::make_unique<MachineBasicBlock>(bb_ptr->getName(), MFunc.get());
auto mbb = std::make_unique<MachineBasicBlock>(bb->getName(), MFunc.get()); bb_map[bb_ptr.get()] = mbb.get();
bb_map[bb] = mbb.get();
MFunc->addBlock(std::move(mbb)); MFunc->addBlock(std::move(mbb));
} }
// 2. 为函数参数创建虚拟寄存器
// ====================== 已修正 ======================
// 根据 IR.h, 参数列表存储在入口基本块中
if (F->getEntryBlock()) { if (F->getEntryBlock()) {
for (auto* arg_alloca : F->getEntryBlock()->getArguments()) { for (auto* arg_alloca : F->getEntryBlock()->getArguments()) {
getVReg(arg_alloca); getVReg(arg_alloca);
} }
} }
// =====================================================
// 3. 遍历每个基本块,生成指令
for (const auto& bb_ptr : F->getBasicBlocks()) { for (const auto& bb_ptr : F->getBasicBlocks()) {
selectBasicBlock(bb_ptr.get()); selectBasicBlock(bb_ptr.get());
} }
// 4. 设置基本块的前驱后继关系
for (const auto& bb_ptr : F->getBasicBlocks()) { for (const auto& bb_ptr : F->getBasicBlocks()) {
BasicBlock* bb = bb_ptr.get(); CurMBB = bb_map.at(bb_ptr.get());
CurMBB = bb_map.at(bb); for (auto succ : bb_ptr->getSuccessors()) {
for (auto succ : bb->getSuccessors()) {
CurMBB->successors.push_back(bb_map.at(succ)); CurMBB->successors.push_back(bb_map.at(succ));
} }
for (auto pred : bb->getPredecessors()) { for (auto pred : bb_ptr->getPredecessors()) {
CurMBB->predecessors.push_back(bb_map.at(pred)); CurMBB->predecessors.push_back(bb_map.at(pred));
} }
} }
@@ -87,29 +88,23 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) {
value_to_node[node->value] = node.get(); value_to_node[node->value] = node.get();
} }
} }
std::set<DAGNode*> selected_nodes; std::set<DAGNode*> selected_nodes;
std::function<void(DAGNode*)> select_recursive = std::function<void(DAGNode*)> select_recursive =
[&](DAGNode* node) { [&](DAGNode* node) {
if (!node || selected_nodes.count(node)) return; if (!node || selected_nodes.count(node)) return;
for (auto operand : node->operands) { for (auto operand : node->operands) {
select_recursive(operand); select_recursive(operand);
} }
// 只有当所有操作数都选择完毕后,才选择当前节点
selectNode(node); selectNode(node);
selected_nodes.insert(node); selected_nodes.insert(node);
}; };
// 按照IR指令的原始顺序来驱动指令选择
for (const auto& inst_ptr : bb->getInstructions()) { for (const auto& inst_ptr : bb->getInstructions()) {
DAGNode* node_to_select = nullptr; DAGNode* node_to_select = nullptr;
// 查找当前IR指令对应的DAG节点
if (value_to_node.count(inst_ptr.get())) { if (value_to_node.count(inst_ptr.get())) {
node_to_select = value_to_node.at(inst_ptr.get()); node_to_select = value_to_node.at(inst_ptr.get());
} else { } else {
// 对于没有返回值的指令或某些特殊情况
for(const auto& node : dag) { for(const auto& node : dag) {
if(node->value == inst_ptr.get()) { if(node->value == inst_ptr.get()) {
node_to_select = node.get(); node_to_select = node.get();
@@ -123,88 +118,105 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) {
} }
} }
// 核心函数为DAG节点选择并生成MachineInstr (忠实移植版)
void RISCv64ISel::selectNode(DAGNode* node) { void RISCv64ISel::selectNode(DAGNode* node) {
// 注意不再生成字符串而是创建MachineInstr对象并加入到CurMBB
switch (node->kind) { switch (node->kind) {
case DAGNode::CONSTANT: case DAGNode::CONSTANT:
case DAGNode::ALLOCA_ADDR: case DAGNode::ALLOCA_ADDR:
// 这些节点本身不生成指令。使用它们的指令会按需处理。
// 为Alloca地址分配一个vreg是必要的代表地址。
if (node->value) getVReg(node->value); if (node->value) getVReg(node->value);
break; break;
case DAGNode::LOAD: { case DAGNode::LOAD: {
// lw rd, offset(base)
auto dest_vreg = getVReg(node->value); auto dest_vreg = getVReg(node->value);
auto ptr_vreg = getVReg(node->operands[0]->value); Value* ptr_val = node->operands[0]->value;
auto instr = std::make_unique<MachineInstr>(RVOpcodes::LW); if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); auto instr = std::make_unique<MachineInstr>(RVOpcodes::FRAME_LOAD);
// 暂时生成0(ptr)后续pass会将其优化为 offset(s0) instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<MemOperand>( instr->addOperand(std::make_unique<RegOperand>(getVReg(alloca)));
std::make_unique<RegOperand>(ptr_vreg), CurMBB->addInstruction(std::move(instr));
std::make_unique<ImmOperand>(0) } else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
)); auto addr_vreg = getNewVReg();
CurMBB->addInstruction(std::move(instr)); auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
CurMBB->addInstruction(std::move(la));
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(lw));
} else {
auto ptr_vreg = getVReg(ptr_val);
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(ptr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(lw));
}
break; break;
} }
case DAGNode::STORE: { case DAGNode::STORE: {
// sw rs2, offset(rs1) Value* val_to_store = node->operands[0]->value;
// 先加载常量 Value* ptr_val = node->operands[1]->value;
if (auto val_const = dynamic_cast<ConstantValue*>(node->operands[0]->value)) {
if (auto val_const = dynamic_cast<ConstantValue*>(val_to_store)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI); auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(val_const))); li->addOperand(std::make_unique<RegOperand>(getVReg(val_const)));
li->addOperand(std::make_unique<ImmOperand>(val_const->getInt())); li->addOperand(std::make_unique<ImmOperand>(val_const->getInt()));
CurMBB->addInstruction(std::move(li)); CurMBB->addInstruction(std::move(li));
} }
auto val_vreg = getVReg(val_to_store);
auto val_vreg = getVReg(node->operands[0]->value); if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
auto ptr_vreg = getVReg(node->operands[1]->value); auto instr = std::make_unique<MachineInstr>(RVOpcodes::FRAME_STORE);
instr->addOperand(std::make_unique<RegOperand>(val_vreg));
instr->addOperand(std::make_unique<RegOperand>(getVReg(alloca)));
CurMBB->addInstruction(std::move(instr));
} else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
auto addr_vreg = getNewVReg();
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
CurMBB->addInstruction(std::move(la));
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SW); auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
instr->addOperand(std::make_unique<RegOperand>(val_vreg)); // value to store sw->addOperand(std::make_unique<RegOperand>(val_vreg));
instr->addOperand(std::make_unique<MemOperand>( sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(ptr_vreg), // base address std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0) // offset std::make_unique<ImmOperand>(0)
)); ));
CurMBB->addInstruction(std::move(instr)); CurMBB->addInstruction(std::move(sw));
} else {
auto ptr_vreg = getVReg(ptr_val);
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(val_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(ptr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(sw));
}
break; break;
} }
case DAGNode::BINARY: { case DAGNode::BINARY: {
auto bin = dynamic_cast<BinaryInst*>(node->value); auto bin = dynamic_cast<BinaryInst*>(node->value);
if (!bin) break;
Value* lhs = bin->getLhs(); Value* lhs = bin->getLhs();
Value* rhs = bin->getRhs(); Value* rhs = bin->getRhs();
// 检查是否为 addi 优化
if (bin->getKind() == BinaryInst::kAdd) {
if (auto rhs_const = dynamic_cast<ConstantValue*>(rhs)) {
if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
instr->addOperand(std::make_unique<RegOperand>(getVReg(bin)));
instr->addOperand(std::make_unique<RegOperand>(getVReg(lhs)));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
return; // 指令已生成,提前返回
}
}
}
// 为操作数加载立即数或地址
auto load_val_if_const = [&](Value* val) { auto load_val_if_const = [&](Value* val) {
if (auto c = dynamic_cast<ConstantValue*>(val)) { if (auto c = dynamic_cast<ConstantValue*>(val)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI); auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(c))); li->addOperand(std::make_unique<RegOperand>(getVReg(c)));
li->addOperand(std::make_unique<ImmOperand>(c->getInt())); li->addOperand(std::make_unique<ImmOperand>(c->getInt()));
CurMBB->addInstruction(std::move(li)); CurMBB->addInstruction(std::move(li));
} else if (auto g = dynamic_cast<GlobalValue*>(val)) {
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(getVReg(g)));
la->addOperand(std::make_unique<LabelOperand>(g->getName()));
CurMBB->addInstruction(std::move(la));
} }
}; };
load_val_if_const(lhs); load_val_if_const(lhs);
@@ -214,7 +226,19 @@ void RISCv64ISel::selectNode(DAGNode* node) {
auto lhs_vreg = getVReg(lhs); auto lhs_vreg = getVReg(lhs);
auto rhs_vreg = getVReg(rhs); auto rhs_vreg = getVReg(rhs);
// 生成二元运算指令 if (bin->getKind() == BinaryInst::kAdd) {
if (auto rhs_const = dynamic_cast<ConstantValue*>(rhs)) {
if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
return;
}
}
}
switch (bin->getKind()) { switch (bin->getKind()) {
case BinaryInst::kAdd: { case BinaryInst::kAdd: {
RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW; RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW;
@@ -294,16 +318,16 @@ void RISCv64ISel::selectNode(DAGNode* node) {
case BinaryInst::kICmpGT: { case BinaryInst::kICmpGT: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SLT); auto instr = std::make_unique<MachineInstr>(RVOpcodes::SLT);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg)); // Swapped instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg)); // Swapped instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(instr)); CurMBB->addInstruction(std::move(instr));
break; break;
} }
case BinaryInst::kICmpLE: { case BinaryInst::kICmpLE: {
auto slt = std::make_unique<MachineInstr>(RVOpcodes::SLT); auto slt = std::make_unique<MachineInstr>(RVOpcodes::SLT);
slt->addOperand(std::make_unique<RegOperand>(dest_vreg)); slt->addOperand(std::make_unique<RegOperand>(dest_vreg));
slt->addOperand(std::make_unique<RegOperand>(rhs_vreg)); // Swapped slt->addOperand(std::make_unique<RegOperand>(rhs_vreg));
slt->addOperand(std::make_unique<RegOperand>(lhs_vreg)); // Swapped slt->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(slt)); CurMBB->addInstruction(std::move(slt));
auto xori = std::make_unique<MachineInstr>(RVOpcodes::XORI); auto xori = std::make_unique<MachineInstr>(RVOpcodes::XORI);
@@ -335,8 +359,6 @@ void RISCv64ISel::selectNode(DAGNode* node) {
case DAGNode::UNARY: { case DAGNode::UNARY: {
auto unary = dynamic_cast<UnaryInst*>(node->value); auto unary = dynamic_cast<UnaryInst*>(node->value);
if (!unary) break;
auto dest_vreg = getVReg(unary); auto dest_vreg = getVReg(unary);
auto src_vreg = getVReg(unary->getOperand()); auto src_vreg = getVReg(unary->getOperand());
@@ -344,7 +366,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
case UnaryInst::kNeg: { case UnaryInst::kNeg: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SUBW); auto instr = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO)); // x0 instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); instr->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(instr)); CurMBB->addInstruction(std::move(instr));
break; break;
@@ -364,51 +386,67 @@ void RISCv64ISel::selectNode(DAGNode* node) {
case DAGNode::CALL: { case DAGNode::CALL: {
auto call = dynamic_cast<CallInst*>(node->value); auto call = dynamic_cast<CallInst*>(node->value);
if (!call) break; for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
DAGNode* arg_node = node->operands[i];
auto arg_preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + i);
if (arg_node->kind == DAGNode::CONSTANT) {
if (auto const_val = dynamic_cast<ConstantValue*>(arg_node->value)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(arg_preg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
}
} else {
auto src_vreg = getVReg(arg_node->value);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(arg_preg));
mv->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(mv));
}
}
// 在此阶段,我们只处理函数调用本身和返回值的移动
// 参数的传递将在一个专门的 Calling Convention Pass 中处理
auto call_instr = std::make_unique<MachineInstr>(RVOpcodes::CALL); auto call_instr = std::make_unique<MachineInstr>(RVOpcodes::CALL);
call_instr->addOperand(std::make_unique<LabelOperand>(call->getCallee()->getName())); call_instr->addOperand(std::make_unique<LabelOperand>(call->getCallee()->getName()));
CurMBB->addInstruction(std::move(call_instr)); CurMBB->addInstruction(std::move(call_instr));
if (!call->getType()->isVoid()) { if (!call->getType()->isVoid()) {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV); auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(call))); // dest mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(call)));
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0)); // src mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
CurMBB->addInstruction(std::move(mv_instr)); CurMBB->addInstruction(std::move(mv_instr));
} }
break; break;
} }
case DAGNode::RETURN: { case DAGNode::RETURN: {
auto ret_inst = dynamic_cast<ReturnInst*>(node->value); auto ret_inst_ir = dynamic_cast<ReturnInst*>(node->value);
if (ret_inst && ret_inst->hasReturnValue()) { if (ret_inst_ir && ret_inst_ir->hasReturnValue()) {
// 如果有返回值生成一条mv指令将其放入a0 Value* ret_val = ret_inst_ir->getReturnValue();
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV); if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0)); auto li_instr = std::make_unique<MachineInstr>(RVOpcodes::LI);
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(ret_inst->getReturnValue()))); li_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
CurMBB->addInstruction(std::move(mv_instr)); li_instr->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li_instr));
} else {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(ret_val)));
CurMBB->addInstruction(std::move(mv_instr));
}
} }
// 生成ret伪指令 auto ret_mi = std::make_unique<MachineInstr>(RVOpcodes::RET);
auto instr = std::make_unique<MachineInstr>(RVOpcodes::RET); CurMBB->addInstruction(std::move(ret_mi));
CurMBB->addInstruction(std::move(instr));
break; break;
} }
case DAGNode::BRANCH: { case DAGNode::BRANCH: {
if (auto cond_br = dynamic_cast<CondBrInst*>(node->value)) { if (auto cond_br = dynamic_cast<CondBrInst*>(node->value)) {
// bne cond, x0, then_block
auto br_instr = std::make_unique<MachineInstr>(RVOpcodes::BNE); auto br_instr = std::make_unique<MachineInstr>(RVOpcodes::BNE);
br_instr->addOperand(std::make_unique<RegOperand>(getVReg(cond_br->getCondition()))); br_instr->addOperand(std::make_unique<RegOperand>(getVReg(cond_br->getCondition())));
br_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO)); br_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
br_instr->addOperand(std::make_unique<LabelOperand>(cond_br->getThenBlock()->getName())); br_instr->addOperand(std::make_unique<LabelOperand>(cond_br->getThenBlock()->getName()));
CurMBB->addInstruction(std::move(br_instr)); CurMBB->addInstruction(std::move(br_instr));
// j else_block
// 注意这里会产生一个fallthrough问题后续的分支优化pass会解决它
// 一个更健壮的生成方式是 bne -> j else; then: ...; else: ...
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(node->value)) { } else if (auto uncond_br = dynamic_cast<UncondBrInst*>(node->value)) {
auto j_instr = std::make_unique<MachineInstr>(RVOpcodes::J); auto j_instr = std::make_unique<MachineInstr>(RVOpcodes::J);
j_instr->addOperand(std::make_unique<LabelOperand>(uncond_br->getBlock()->getName())); j_instr->addOperand(std::make_unique<LabelOperand>(uncond_br->getBlock()->getName()));
@@ -416,20 +454,16 @@ void RISCv64ISel::selectNode(DAGNode* node) {
} }
break; break;
} }
case DAGNode::MEMSET: {
// 这是对原memset逻辑的完整LLIR翻译
auto memset = dynamic_cast<MemsetInst*>(node->value);
if (!memset) break;
case DAGNode::MEMSET: {
auto memset = dynamic_cast<MemsetInst*>(node->value);
auto r_dest_addr = getVReg(memset->getPointer()); auto r_dest_addr = getVReg(memset->getPointer());
auto r_num_bytes = getVReg(memset->getSize()); auto r_num_bytes = getVReg(memset->getSize());
auto r_value_byte = getVReg(memset->getValue()); auto r_value_byte = getVReg(memset->getValue());
auto r_counter = getNewVReg();
// 为临时值创建虚拟寄存器 auto r_end_addr = getNewVReg();
auto r_counter = vreg_counter++; auto r_current_addr = getNewVReg();
auto r_end_addr = vreg_counter++; auto r_temp_val = getNewVReg();
auto r_current_addr = vreg_counter++;
auto r_temp_val = vreg_counter++;
auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) { auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) {
auto i = std::make_unique<MachineInstr>(op); auto i = std::make_unique<MachineInstr>(op);
@@ -470,12 +504,11 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}; };
int unique_id = this->local_label_counter++; int unique_id = this->local_label_counter++;
std::string loop_start_label = "memset_loop_start_" + std::to_string(unique_id); std::string loop_start_label = MFunc->getName() + "_memset_loop_start_" + std::to_string(unique_id);
std::string loop_end_label = "memset_loop_end_" + std::to_string(unique_id); std::string loop_end_label = MFunc->getName() + "_memset_loop_end_" + std::to_string(unique_id);
std::string remainder_label = "memset_remainder_" + std::to_string(unique_id); std::string remainder_label = MFunc->getName() + "_memset_remainder_" + std::to_string(unique_id);
std::string done_label = "memset_done_" + std::to_string(unique_id); std::string done_label = MFunc->getName() + "_memset_done_" + std::to_string(unique_id);
// 构造64位的填充值
addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255); addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255);
addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8);
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
@@ -483,8 +516,6 @@ void RISCv64ISel::selectNode(DAGNode* node) {
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32); addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32);
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte); add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
// 设置循环变量
add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes); add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV); auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(r_current_addr)); mv->addOperand(std::make_unique<RegOperand>(r_current_addr));
@@ -492,16 +523,12 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(mv)); CurMBB->addInstruction(std::move(mv));
addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8); addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8);
add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter); add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter);
// 64位写入循环
label_instr(loop_start_label); label_instr(loop_start_label);
branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label);
store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0); store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0);
addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8); addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8);
jump_instr(loop_start_label); jump_instr(loop_start_label);
label_instr(loop_end_label); label_instr(loop_end_label);
// 剩余字节写入循环
label_instr(remainder_label); label_instr(remainder_label);
branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label); branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label);
store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0); store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0);
@@ -512,13 +539,13 @@ void RISCv64ISel::selectNode(DAGNode* node) {
} }
default: default:
throw std::runtime_error("Unsupported DAGNode kind in ISel: " + std::to_string(node->kind)); throw std::runtime_error("Unsupported DAGNode kind in ISel");
} }
} }
// 以下是忠实移植的DAG构建函数
// --- DAG构建函数 (从原RISCv64Backend.cpp几乎原样迁移, 保持不变) --- RISCv64ISel::DAGNode* RISCv64ISel::create_node(int kind_int, Value* val, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage) {
RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* val, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage) { auto kind = static_cast<DAGNode::NodeKind>(kind_int);
if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) { if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) {
return value_to_node[val]; return value_to_node[val];
} }
@@ -526,10 +553,7 @@ RISCv64ISel::DAGNode* RISCv64ISel::create_node(DAGNode::NodeKind kind, Value* va
node->value = val; node->value = val;
DAGNode* raw_node_ptr = node.get(); DAGNode* raw_node_ptr = node.get();
nodes_storage.push_back(std::move(node)); nodes_storage.push_back(std::move(node));
// 只有产生值的节点才应该被记录,以备复用 if (val && !val->getType()->isVoid() && (dynamic_cast<Instruction*>(val) || dynamic_cast<GlobalValue*>(val))) {
if (val && !val->getType()->isVoid() && dynamic_cast<Instruction*>(val)) {
value_to_node[val] = raw_node_ptr;
} else if (val && dynamic_cast<GlobalValue*>(val)) {
value_to_node[val] = raw_node_ptr; value_to_node[val] = raw_node_ptr;
} }
return raw_node_ptr; return raw_node_ptr;
@@ -545,7 +569,6 @@ RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map<Valu
} else if (dynamic_cast<AllocaInst*>(val_ir)) { } else if (dynamic_cast<AllocaInst*>(val_ir)) {
return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage); return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage);
} }
// Fallback: Assume it needs to be loaded if not found (might be a parameter or a value from another block)
return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage); return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage);
} }
@@ -567,12 +590,20 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage));
memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage));
memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage)); memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage));
} } else if (auto load = dynamic_cast<LoadInst*>(inst)) {
else if (auto load = dynamic_cast<LoadInst*>(inst)) {
auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage); auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage);
load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage)); load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage));
} else if (auto bin = dynamic_cast<BinaryInst*>(inst)) { } else if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
if(value_to_node.count(bin)) continue; if(value_to_node.count(bin)) continue;
if (bin->getKind() == BinaryInst::kSub) {
if (auto const_lhs = dynamic_cast<ConstantValue*>(bin->getLhs())) {
if (const_lhs->getInt() == 0) {
auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
continue;
}
}
}
auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage); auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage);
bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage)); bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage)); bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
@@ -580,8 +611,7 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
if(value_to_node.count(un)) continue; if(value_to_node.count(un)) continue;
auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage); auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage)); unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage));
} } else if (auto call = dynamic_cast<CallInst*>(inst)) {
else if (auto call = dynamic_cast<CallInst*>(inst)) {
if(value_to_node.count(call)) continue; if(value_to_node.count(call)) continue;
auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage); auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage);
for (auto arg : call->getArguments()) { for (auto arg : call->getArguments()) {

8
src/RISCv64Passes.cpp Normal file
View File

@@ -0,0 +1,8 @@
// RISCv64Passes.cpp
#include "RISCv64Passes.h"
namespace sysy {
// 此处为未来优化Pass的实现
} // namespace sysy

View File

@@ -1,11 +1,11 @@
#include "RISCv64RegAlloc.h" #include "RISCv64RegAlloc.h"
#include "RISCv64ISel.h"
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
namespace sysy { namespace sysy {
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) { RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
// 初始化可分配的整数寄存器池 (排除特殊用途的)
allocable_int_regs = { allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
@@ -18,23 +18,113 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
} }
void RISCv64RegAlloc::run() { void RISCv64RegAlloc::run() {
eliminateFrameIndices();
analyzeLiveness(); analyzeLiveness();
buildInterferenceGraph(); buildInterferenceGraph();
colorGraph(); colorGraph();
rewriteFunction(); rewriteFunction();
} }
void RISCv64RegAlloc::eliminateFrameIndices() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = 0;
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
int size = 4;
if (!alloca->getDims().empty()) {
int num_elements = 1;
for (const auto& dim_use : alloca->getDims()) {
if (auto const_dim = dynamic_cast<ConstantValue*>(dim_use->getValue())) {
num_elements *= const_dim->getInt();
}
}
size *= num_elements;
}
current_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
frame_info.alloca_offsets[alloca_vreg] = -current_offset;
}
}
}
frame_info.locals_size = current_offset;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(lw));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) {
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(src_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(sw));
} else {
new_instructions.push_back(std::move(instr_ptr));
}
}
mbb->getInstructions() = std::move(new_instructions);
}
}
void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) { void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) {
// 这是一个简化的版本实际需要根据RVOpcodes精确定义
// 通常第一个RegOperand是def其余是use
bool is_def = true; bool is_def = true;
auto opcode = instr->getOpcode();
// 预定义def和use规则
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::RET || opcode == RVOpcodes::J) {
is_def = false;
}
if (opcode == RVOpcodes::CALL) {
// CALL会杀死所有调用者保存寄存器这是一个简化处理
// 同时也使用了传入a0-a7的参数
}
for (const auto& op : instr->getOperands()) { for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) { if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get()); auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual()) { if (reg_op->isVirtual()) {
if (is_def) { if (is_def) {
def.insert(reg_op->getVRegNum()); def.insert(reg_op->getVRegNum());
is_def = false; // 假设每条指令最多一个def is_def = false;
} else { } else {
use.insert(reg_op->getVRegNum()); use.insert(reg_op->getVRegNum());
} }
@@ -46,35 +136,16 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
} }
} }
} }
// 特殊处理store和branch指令它们没有显式的def
auto opcode = instr->getOpcode();
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::BNE || opcode == RVOpcodes::BEQ) {
def.clear(); // 清空错误的def
use.clear();
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum());
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op.get());
if(mem_op->getBase()->isVirtual()) use.insert(mem_op->getBase()->getVRegNum());
}
}
}
} }
void RISCv64RegAlloc::analyzeLiveness() { void RISCv64RegAlloc::analyzeLiveness() {
bool changed = true; bool changed = true;
while (changed) { while (changed) {
changed = false; changed = false;
// 逆序遍历基本块
for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) { for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) {
auto& mbb = *it; auto& mbb = *it;
LiveSet live_out; LiveSet live_out;
for (auto succ : mbb->successors) { for (auto succ : mbb->successors) {
// live_out[B] = Union(live_in[S]) for all S in succ(B)
if (!succ->getInstructions().empty()) { if (!succ->getInstructions().empty()) {
auto first_instr = succ->getInstructions().front().get(); auto first_instr = succ->getInstructions().front().get();
if (live_in_map.count(first_instr)) { if (live_in_map.count(first_instr)) {
@@ -83,19 +154,14 @@ void RISCv64RegAlloc::analyzeLiveness() {
} }
} }
// 逆序遍历指令
for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) { for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) {
MachineInstr* instr = instr_it->get(); MachineInstr* instr = instr_it->get();
LiveSet old_live_in = live_in_map[instr]; LiveSet old_live_in = live_in_map[instr];
LiveSet old_live_out = live_out_map[instr];
// 更新 live_out
live_out_map[instr] = live_out; live_out_map[instr] = live_out;
LiveSet use, def; LiveSet use, def;
getInstrUseDef(instr, use, def); getInstrUseDef(instr, use, def);
// live_in[i] = use[i] U (live_out[i] - def[i])
LiveSet live_in = use; LiveSet live_in = use;
LiveSet diff = live_out; LiveSet diff = live_out;
for (auto vreg : def) { for (auto vreg : def) {
@@ -104,10 +170,9 @@ void RISCv64RegAlloc::analyzeLiveness() {
live_in.insert(diff.begin(), diff.end()); live_in.insert(diff.begin(), diff.end());
live_in_map[instr] = live_in; live_in_map[instr] = live_in;
// 为下一次迭代准备live_out
live_out = live_in; live_out = live_in;
if (live_in_map[instr] != old_live_in || live_out_map[instr] != old_live_out) { if (live_in_map[instr] != old_live_in) {
changed = true; changed = true;
} }
} }
@@ -117,21 +182,21 @@ void RISCv64RegAlloc::analyzeLiveness() {
void RISCv64RegAlloc::buildInterferenceGraph() { void RISCv64RegAlloc::buildInterferenceGraph() {
std::set<unsigned> all_vregs; std::set<unsigned> all_vregs;
// 收集所有虚拟寄存器 for (auto& mbb : MFunc->getBlocks()) {
for (auto const& [instr, live_set] : live_out_map) { for(auto& instr : mbb->getInstructions()) {
all_vregs.insert(live_set.begin(), live_set.end()); LiveSet use, def;
getInstrUseDef(instr.get(), use, def);
for(auto u : use) all_vregs.insert(u);
for(auto d : def) all_vregs.insert(d);
}
} }
// 初始化图 for (auto vreg : all_vregs) { interference_graph[vreg] = {}; }
for (auto vreg : all_vregs) {
interference_graph[vreg] = {};
}
for (auto& mbb : MFunc->getBlocks()) { for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) { for (auto& instr : mbb->getInstructions()) {
LiveSet def, use; LiveSet def, use;
getInstrUseDef(instr.get(), use, def); getInstrUseDef(instr.get(), use, def);
const LiveSet& live_out = live_out_map.at(instr.get()); const LiveSet& live_out = live_out_map.at(instr.get());
for (unsigned d : def) { for (unsigned d : def) {
@@ -152,21 +217,18 @@ void RISCv64RegAlloc::colorGraph() {
sorted_vregs.push_back(vreg); sorted_vregs.push_back(vreg);
} }
// 按度数降序排序 (简单贪心策略)
std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) { std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) {
return interference_graph[a].size() > interference_graph[b].size(); return interference_graph[a].size() > interference_graph[b].size();
}); });
for (unsigned vreg : sorted_vregs) { for (unsigned vreg : sorted_vregs) {
std::set<PhysicalReg> used_colors; std::set<PhysicalReg> used_colors;
// 查找邻居已用的颜色
for (unsigned neighbor : interference_graph.at(vreg)) { for (unsigned neighbor : interference_graph.at(vreg)) {
if (color_map.count(neighbor)) { if (color_map.count(neighbor)) {
used_colors.insert(color_map.at(neighbor)); used_colors.insert(color_map.at(neighbor));
} }
} }
// 寻找一个可用的颜色
bool colored = false; bool colored = false;
for (PhysicalReg preg : allocable_int_regs) { for (PhysicalReg preg : allocable_int_regs) {
if (used_colors.find(preg) == used_colors.end()) { if (used_colors.find(preg) == used_colors.end()) {
@@ -175,54 +237,47 @@ void RISCv64RegAlloc::colorGraph() {
break; break;
} }
} }
if (!colored) { if (!colored) {
// 无法分配,需要溢出
spilled_vregs.insert(vreg); spilled_vregs.insert(vreg);
} }
} }
} }
void RISCv64RegAlloc::rewriteFunction() { void RISCv64RegAlloc::rewriteFunction() {
// 1. 为所有溢出的vreg分配栈槽
StackFrameInfo& frame_info = MFunc->getFrameInfo(); StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.frame_size; // 假设从现有栈大小后开始分配 int current_offset = frame_info.locals_size;
for (unsigned vreg : spilled_vregs) { for (unsigned vreg : spilled_vregs) {
current_offset += 4; // 假设所有溢出变量都占4字节 current_offset += 4;
frame_info.spill_slots[vreg] = -current_offset; // 栈向下增长,所以是负偏移 frame_info.spill_offsets[vreg] = -current_offset;
} }
frame_info.frame_size = current_offset; frame_info.spill_size = current_offset - frame_info.locals_size;
// 2. 遍历所有指令替换vreg并插入spill代码
for (auto& mbb : MFunc->getBlocks()) { for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions; std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) { for (auto& instr_ptr : mbb->getInstructions()) {
LiveSet use, def; LiveSet use, def;
getInstrUseDef(instr_ptr.get(), use, def); getInstrUseDef(instr_ptr.get(), use, def);
// 为use的溢出变量插入LOAD
for (unsigned vreg : use) { for (unsigned vreg : use) {
if (spilled_vregs.count(vreg)) { if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_slots.at(vreg); int offset = frame_info.spill_offsets.at(vreg);
auto load = std::make_unique<MachineInstr>(RVOpcodes::LW); auto load = std::make_unique<MachineInstr>(RVOpcodes::LW);
load->addOperand(std::make_unique<RegOperand>(vreg)); // 临时用vreg号代表稍后替换 load->addOperand(std::make_unique<RegOperand>(vreg));
load->addOperand(std::make_unique<MemOperand>( load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), // 基址用帧指针s0 std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset) std::make_unique<ImmOperand>(offset)
)); ));
new_instructions.push_back(std::move(load)); new_instructions.push_back(std::move(load));
} }
} }
// 添加原始指令
new_instructions.push_back(std::move(instr_ptr)); new_instructions.push_back(std::move(instr_ptr));
// 为def的溢出变量插入STORE
for (unsigned vreg : def) { for (unsigned vreg : def) {
if (spilled_vregs.count(vreg)) { if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_slots.at(vreg); int offset = frame_info.spill_offsets.at(vreg);
auto store = std::make_unique<MachineInstr>(RVOpcodes::SW); auto store = std::make_unique<MachineInstr>(RVOpcodes::SW);
store->addOperand(std::make_unique<RegOperand>(vreg)); // 临时用vreg号代表 store->addOperand(std::make_unique<RegOperand>(vreg));
store->addOperand(std::make_unique<MemOperand>( store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset) std::make_unique<ImmOperand>(offset)
@@ -234,7 +289,6 @@ void RISCv64RegAlloc::rewriteFunction() {
mbb->getInstructions() = std::move(new_instructions); mbb->getInstructions() = std::move(new_instructions);
} }
// 3. 最后一遍扫描将所有RegOperand从vreg替换为preg
for (auto& mbb : MFunc->getBlocks()) { for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) { for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) { for (auto& op_ptr : instr_ptr->getOperands()) {
@@ -245,8 +299,7 @@ void RISCv64RegAlloc::rewriteFunction() {
if (color_map.count(vreg)) { if (color_map.count(vreg)) {
reg_op->setPReg(color_map.at(vreg)); reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) { } else if (spilled_vregs.count(vreg)) {
// 对于spill的vreg, 使用一个固定的临时寄存器, 比如t6 reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6
reg_op->setPReg(PhysicalReg::T6);
} }
} }
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) { } else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
@@ -254,7 +307,11 @@ void RISCv64RegAlloc::rewriteFunction() {
auto base_reg_op = mem_op->getBase(); auto base_reg_op = mem_op->getBase();
if(base_reg_op->isVirtual()){ if(base_reg_op->isVirtual()){
unsigned vreg = base_reg_op->getVRegNum(); unsigned vreg = base_reg_op->getVRegNum();
if(color_map.count(vreg)) base_reg_op->setPReg(color_map.at(vreg)); if(color_map.count(vreg)) {
base_reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
base_reg_op->setPReg(PhysicalReg::T6);
}
} }
} }
} }

View File

@@ -8,29 +8,23 @@ namespace sysy {
class RISCv64AsmPrinter { class RISCv64AsmPrinter {
public: public:
// 主入口将整个MachineFunction打印到指定的输出流 RISCv64AsmPrinter(MachineFunction* mfunc);
void runOnMachineFunction(MachineFunction* mfunc, std::ostream& os); // 主入口
void run(std::ostream& os);
private: private:
// 打印单个基本块 // 打印各个部分
void printPrologue();
void printEpilogue();
void printBasicBlock(MachineBasicBlock* mbb); void printBasicBlock(MachineBasicBlock* mbb);
void printInstruction(MachineInstr* instr);
// 打印单条指令
void printInstruction(MachineInstr* instr, MachineBasicBlock* parent_bb);
// 打印函数序言
void printPrologue(MachineFunction* mfunc);
// 打印函数尾声 // 辅助函数
void printEpilogue(MachineFunction* mfunc);
// 将物理寄存器枚举转换为字符串 (从原RISCv64Backend迁移)
std::string regToString(PhysicalReg reg); std::string regToString(PhysicalReg reg);
// 打印单个操作数
void printOperand(MachineOperand* op); void printOperand(MachineOperand* op);
std::ostream* OS; // 指向当前输出流 MachineFunction* MFunc;
std::ostream* OS;
}; };
} // namespace sysy } // namespace sysy

View File

@@ -1,10 +1,8 @@
#ifndef RISCV64_BACKEND_H #ifndef RISCV64_BACKEND_H
#define RISCV64_BACKEND_H #define RISCV64_BACKEND_H
#include "IR.h" // 只需包含高层IR定义 #include "IR.h"
#include <string> #include <string>
#include <vector>
#include <memory>
namespace sysy { namespace sysy {
@@ -12,14 +10,12 @@ namespace sysy {
class RISCv64CodeGen { class RISCv64CodeGen {
public: public:
RISCv64CodeGen(Module* mod) : module(mod) {} RISCv64CodeGen(Module* mod) : module(mod) {}
// 唯一的公共入口点 // 唯一的公共入口点
std::string code_gen(); std::string code_gen();
private: private:
// 模块级代码生成 (处理全局变量和驱动函数生成) // 模块级代码生成
std::string module_gen(); std::string module_gen();
// 函数级代码生成 (实现新的流水线) // 函数级代码生成 (实现新的流水线)
std::string function_gen(Function* func); std::string function_gen(Function* func);

View File

@@ -1,10 +1,7 @@
#ifndef RISCV64_ISEL_H #ifndef RISCV64_ISEL_H
#define RISCV64_ISEL_H #define RISCV64_ISEL_H
#include "IR.h"
#include "RISCv64LLIR.h" #include "RISCv64LLIR.h"
#include <memory>
#include <map>
namespace sysy { namespace sysy {
@@ -14,43 +11,34 @@ public:
// 模块主入口将一个高层IR函数转换为底层LLIR函数 // 模块主入口将一个高层IR函数转换为底层LLIR函数
std::unique_ptr<MachineFunction> runOnFunction(Function* func); std::unique_ptr<MachineFunction> runOnFunction(Function* func);
// 公开接口以便后续模块如RegAlloc可以查询或创建vreg
unsigned getVReg(Value* val);
unsigned getNewVReg() { return vreg_counter++; }
private: private:
// DAG节点定义作为ISel的内部实现细节 // DAG节点定义作为ISel的内部实现细节
struct DAGNode { struct DAGNode;
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET };
NodeKind kind; // 指令选择主流程
Value* value = nullptr;
std::vector<DAGNode*> operands;
DAGNode(NodeKind k) : kind(k) {}
};
// 为当前函数生成LLIR
void select(); void select();
// 为单个基本块生成指令 // 为单个基本块生成指令
void selectBasicBlock(BasicBlock* bb); void selectBasicBlock(BasicBlock* bb);
// 核心函数为DAG节点选择并生成MachineInstr // 核心函数为DAG节点选择并生成MachineInstr
void selectNode(DAGNode* node); void selectNode(DAGNode* node);
// --- DAG 构建相关函数 (从原RISCv64Backend迁移) --- // DAG 构建相关函数 (从原RISCv64Backend迁移)
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb); std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
DAGNode* get_operand_node(Value* val_ir, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage); DAGNode* get_operand_node(Value* val_ir, std::map<Value*, DAGNode*>&, std::vector<std::unique_ptr<DAGNode>>&);
DAGNode* create_node(DAGNode::NodeKind kind, Value* val, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage); DAGNode* create_node(int kind, Value* val, std::map<Value*, DAGNode*>&, std::vector<std::unique_ptr<DAGNode>>&);
// --- 辅助函数 ---
// 为一个IR Value获取/分配一个虚拟寄存器号
unsigned getVReg(Value* val);
// 状态
Function* F; // 当前处理的高层IR函数 Function* F; // 当前处理的高层IR函数
std::unique_ptr<MachineFunction> MFunc; // 正在构建的底层LLIR函数 std::unique_ptr<MachineFunction> MFunc; // 正在构建的底层LLIR函数
MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块 MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块
// 映射关系 // 映射关系
std::map<Value*, unsigned> vreg_map; std::map<Value*, unsigned> vreg_map;
std::map<const BasicBlock*, MachineBasicBlock*> bb_map; std::map<const BasicBlock*, MachineBasicBlock*> bb_map;
std::map<Value*, DAGNode*> value_to_node_map; // 用于selectNode中查找
unsigned vreg_counter; unsigned vreg_counter;
int local_label_counter; int local_label_counter;

View File

@@ -1,66 +1,51 @@
#ifndef RISCV64_LLIR_H #ifndef RISCV64_LLIR_H
#define RISCV64_LLIR_H #define RISCV64_LLIR_H
#include "IR.h" // 确保包含了您自己的IR头文件
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <cstdint> #include <cstdint>
#include <map> #include <map>
// 前向声明,避免循环引用
namespace sysy {
class Function;
class RISCv64ISel;
}
namespace sysy { namespace sysy {
// 物理寄存器定义 (从 RISCv64Backend.h 移至此) // 物理寄存器定义
enum class PhysicalReg { enum class PhysicalReg {
ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6, ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6,
F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31 F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31
}; };
// RISC-V 指令操作码枚举 // RISC-V 指令操作码枚举
enum class RVOpcodes { enum class RVOpcodes {
// 算术指令 // 算术指令
ADD, ADDI, ADDW, ADDIW, ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
SUB, SUBW,
MUL, MULW,
DIV, DIVW,
REM, REMW,
// 逻辑指令 // 逻辑指令
XOR, XORI, XOR, XORI, OR, ORI, AND, ANDI,
OR, ORI,
AND, ANDI,
// 移位指令 // 移位指令
SLL, SLLI, SLLW, SLLIW, SLL, SLLI, SLLW, SLLIW, SRL, SRLI, SRLW, SRLIW, SRA, SRAI, SRAW, SRAIW,
SRL, SRLI, SRLW, SRLIW,
SRA, SRAI, SRAW, SRAIW,
// 比较指令 // 比较指令
SLT, SLTI, SLTU, SLTIU, SLT, SLTI, SLTU, SLTIU,
// 内存访问指令 // 内存访问指令
LW, LH, LB, LWU, LHU, LBU, LW, LH, LB, LWU, LHU, LBU, SW, SH, SB, LD, SD,
SW, SH, SB,
LD, SD, // 64位
// 控制流指令 // 控制流指令
J, JAL, JALR, RET, // RET 是 JALR x0, 0(ra) 的伪指令 J, JAL, JALR, RET,
BEQ, BNE, BLT, BGE, BLTU, BGEU, BEQ, BNE, BLT, BGE, BLTU, BGEU,
// 伪指令
// 伪指令 (方便指令选择) LI, LA, MV, NEG, NEGW, SEQZ, SNEZ,
LI, // Load Immediate
LA, // Load Address
MV, // Move register
NEG, // Negate
NEGW, // Negate Word
SEQZ, // Set if Equal to Zero
SNEZ, // Set if Not Equal to Zero
// 函数调用 // 函数调用
CALL, CALL,
// 特殊标记,非指令 // 特殊标记,非指令
LABEL, // 用于表示一个标签位置 LABEL,
// 新增伪指令,用于解耦栈帧处理
FRAME_LOAD, // 从栈帧加载 (AllocaInst)
FRAME_STORE, // 保存到栈帧 (AllocaInst)
}; };
class MachineOperand; class MachineOperand;
@@ -72,22 +57,13 @@ class MachineInstr;
class MachineBasicBlock; class MachineBasicBlock;
class MachineFunction; class MachineFunction;
// --- 操作数定义 ---
// 操作数基类 // 操作数基类
class MachineOperand { class MachineOperand {
public: public:
enum OperandKind { enum OperandKind { KIND_REG, KIND_IMM, KIND_LABEL, KIND_MEM };
KIND_REG,
KIND_IMM,
KIND_LABEL,
KIND_MEM
};
MachineOperand(OperandKind kind) : kind(kind) {} MachineOperand(OperandKind kind) : kind(kind) {}
virtual ~MachineOperand() = default; virtual ~MachineOperand() = default;
OperandKind getKind() const { return kind; } OperandKind getKind() const { return kind; }
private: private:
OperandKind kind; OperandKind kind;
}; };
@@ -111,7 +87,6 @@ public:
preg = new_preg; preg = new_preg;
is_virtual = false; is_virtual = false;
} }
private: private:
unsigned vreg_num = 0; unsigned vreg_num = 0;
PhysicalReg preg = PhysicalReg::ZERO; PhysicalReg preg = PhysicalReg::ZERO;
@@ -121,9 +96,7 @@ private:
// 立即数操作数 // 立即数操作数
class ImmOperand : public MachineOperand { class ImmOperand : public MachineOperand {
public: public:
ImmOperand(int64_t value) ImmOperand(int64_t value) : MachineOperand(KIND_IMM), value(value) {}
: MachineOperand(KIND_IMM), value(value) {}
int64_t getValue() const { return value; } int64_t getValue() const { return value; }
private: private:
int64_t value; int64_t value;
@@ -132,9 +105,7 @@ private:
// 标签操作数 // 标签操作数
class LabelOperand : public MachineOperand { class LabelOperand : public MachineOperand {
public: public:
LabelOperand(const std::string& name) LabelOperand(const std::string& name) : MachineOperand(KIND_LABEL), name(name) {}
: MachineOperand(KIND_LABEL), name(name) {}
const std::string& getName() const { return name; } const std::string& getName() const { return name; }
private: private:
std::string name; std::string name;
@@ -145,33 +116,25 @@ class MemOperand : public MachineOperand {
public: public:
MemOperand(std::unique_ptr<RegOperand> base, std::unique_ptr<ImmOperand> offset) MemOperand(std::unique_ptr<RegOperand> base, std::unique_ptr<ImmOperand> offset)
: MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {} : MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {}
RegOperand* getBase() const { return base.get(); } RegOperand* getBase() const { return base.get(); }
ImmOperand* getOffset() const { return offset.get(); } ImmOperand* getOffset() const { return offset.get(); }
private: private:
std::unique_ptr<RegOperand> base; std::unique_ptr<RegOperand> base;
std::unique_ptr<ImmOperand> offset; std::unique_ptr<ImmOperand> offset;
}; };
// --- 组织结构定义 ---
// 机器指令 // 机器指令
class MachineInstr { class MachineInstr {
public: public:
MachineInstr(RVOpcodes opcode) : opcode(opcode) {} MachineInstr(RVOpcodes opcode) : opcode(opcode) {}
RVOpcodes getOpcode() const { return opcode; } RVOpcodes getOpcode() const { return opcode; }
// 注意返回const引用因为通常不直接修改指令的操作数列表
const std::vector<std::unique_ptr<MachineOperand>>& getOperands() const { return operands; } const std::vector<std::unique_ptr<MachineOperand>>& getOperands() const { return operands; }
// 提供一个非const版本用于内部修改
std::vector<std::unique_ptr<MachineOperand>>& getOperands() { return operands; } std::vector<std::unique_ptr<MachineOperand>>& getOperands() { return operands; }
void addOperand(std::unique_ptr<MachineOperand> operand) { void addOperand(std::unique_ptr<MachineOperand> operand) {
operands.push_back(std::move(operand)); operands.push_back(std::move(operand));
} }
private: private:
RVOpcodes opcode; RVOpcodes opcode;
std::vector<std::unique_ptr<MachineOperand>> operands; std::vector<std::unique_ptr<MachineOperand>> operands;
@@ -185,8 +148,6 @@ public:
const std::string& getName() const { return name; } const std::string& getName() const { return name; }
MachineFunction* getParent() const { return parent; } MachineFunction* getParent() const { return parent; }
// 同时提供 const 和 non-const 版本
const std::vector<std::unique_ptr<MachineInstr>>& getInstructions() const { return instructions; } const std::vector<std::unique_ptr<MachineInstr>>& getInstructions() const { return instructions; }
std::vector<std::unique_ptr<MachineInstr>>& getInstructions() { return instructions; } std::vector<std::unique_ptr<MachineInstr>>& getInstructions() { return instructions; }
@@ -196,43 +157,44 @@ public:
std::vector<MachineBasicBlock*> successors; std::vector<MachineBasicBlock*> successors;
std::vector<MachineBasicBlock*> predecessors; std::vector<MachineBasicBlock*> predecessors;
private: private:
std::string name; std::string name;
std::vector<std::unique_ptr<MachineInstr>> instructions; std::vector<std::unique_ptr<MachineInstr>> instructions;
MachineFunction* parent; // 指向所属函数 MachineFunction* parent;
}; };
// 栈帧信息 // 栈帧信息
struct StackFrameInfo { struct StackFrameInfo {
int frame_size = 0; int locals_size = 0; // 仅为AllocaInst分配的大小
std::map<int, int> spill_slots; // <虚拟寄存器号, 栈偏移> int spill_size = 0; // 仅为溢出分配的大小
// ... 未来可以添加更多信息 int total_size = 0; // 总大小
std::map<unsigned, int> alloca_offsets; // <AllocaInst的vreg, 栈偏移>
std::map<unsigned, int> spill_offsets; // <溢出vreg, 栈偏移>
}; };
// 机器函数 // 机器函数
class MachineFunction { class MachineFunction {
public: public:
MachineFunction(const std::string& name) : name(name) {} MachineFunction(Function* func, RISCv64ISel* isel) : F(func), name(func->getName()), isel(isel) {}
Function* getFunc() const { return F; }
RISCv64ISel* getISel() const { return isel; }
const std::string& getName() const { return name; } const std::string& getName() const { return name; }
StackFrameInfo& getFrameInfo() { return frame_info; } StackFrameInfo& getFrameInfo() { return frame_info; }
// 同时提供 const 和 non-const 版本
const std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() const { return blocks; } const std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() const { return blocks; }
std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() { return blocks; } std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() { return blocks; }
void addBlock(std::unique_ptr<MachineBasicBlock> block) { void addBlock(std::unique_ptr<MachineBasicBlock> block) {
blocks.push_back(std::move(block)); blocks.push_back(std::move(block));
} }
private: private:
Function* F;
RISCv64ISel* isel; // 指向创建它的ISel用于获取vreg映射等信息
std::string name; std::string name;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks; std::vector<std::unique_ptr<MachineBasicBlock>> blocks;
StackFrameInfo frame_info; StackFrameInfo frame_info;
}; };
} // namespace sysy } // namespace sysy
#endif // RISCV64_LLIR_H #endif // RISCV64_LLIR_H

View File

@@ -0,0 +1,18 @@
// RISCv64Passes.h
#ifndef RISCV64_PASSES_H
#define RISCV64_PASSES_H
#include "RISCv64LLIR.h"
namespace sysy {
// 此处为未来优化Pass的基类或独立类定义
// 例如:
// class PeepholeOptimizer {
// public:
// void runOnMachineFunction(MachineFunction* mfunc);
// };
} // namespace sysy
#endif // RISCV64_PASSES_H

View File

@@ -2,9 +2,6 @@
#define RISCV64_REGALLOC_H #define RISCV64_REGALLOC_H
#include "RISCv64LLIR.h" #include "RISCv64LLIR.h"
#include <map>
#include <set>
#include <vector>
namespace sysy { namespace sysy {
@@ -19,6 +16,9 @@ private:
using LiveSet = std::set<unsigned>; // 活跃虚拟寄存器集合 using LiveSet = std::set<unsigned>; // 活跃虚拟寄存器集合
using InterferenceGraph = std::map<unsigned, std::set<unsigned>>; using InterferenceGraph = std::map<unsigned, std::set<unsigned>>;
// 栈帧管理
void eliminateFrameIndices();
// 活跃性分析 // 活跃性分析
void analyzeLiveness(); void analyzeLiveness();
@@ -28,7 +28,7 @@ private:
// 图着色分配寄存器 // 图着色分配寄存器
void colorGraph(); void colorGraph();
// 重写函数,将虚拟寄存器替换为物理寄存器,并插入溢出代码 // 重写函数,替换vreg并插入溢出代码
void rewriteFunction(); void rewriteFunction();
// 辅助函数获取指令的Use/Def集合 // 辅助函数获取指令的Use/Def集合
@@ -37,8 +37,8 @@ private:
MachineFunction* MFunc; MachineFunction* MFunc;
// 活跃性分析结果 // 活跃性分析结果
std::map<MachineInstr*, LiveSet> live_in_map; std::map<const MachineInstr*, LiveSet> live_in_map;
std::map<MachineInstr*, LiveSet> live_out_map; std::map<const MachineInstr*, LiveSet> live_out_map;
// 干扰图 // 干扰图
InterferenceGraph interference_graph; InterferenceGraph interference_graph;
@@ -49,7 +49,6 @@ private:
// 可用的物理寄存器池 // 可用的物理寄存器池
std::vector<PhysicalReg> allocable_int_regs; std::vector<PhysicalReg> allocable_int_regs;
std::vector<PhysicalReg> allocable_float_regs; // (为未来浮点支持预留)
}; };
} // namespace sysy } // namespace sysy