Files
mysysy/src/RISCv64RegAlloc.cpp
2025-07-26 17:36:23 +08:00

530 lines
22 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "RISCv64RegAlloc.h"
#include "RISCv64ISel.h"
#include <algorithm>
#include <vector>
namespace sysy {
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
};
}
void RISCv64RegAlloc::run() {
handleCallingConvention();
eliminateFrameIndices();
analyzeLiveness();
buildInterferenceGraph();
colorGraph();
rewriteFunction();
}
/**
* @brief 处理调用约定,预先为函数参数分配物理寄存器。
*/
void RISCv64RegAlloc::handleCallingConvention() {
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
// 获取函数的Argument对象列表
if (F) {
auto& args = F->getArguments();
// RISC-V RV64G调用约定前8个整型/指针参数通过 a0-a7 传递
int arg_idx = 0;
// 遍历 AllocaInst* 列表
for (Argument* arg : args) {
if (arg_idx >= 8) {
break;
}
// 1. 获取该 Argument 对象对应的虚拟寄存器
unsigned vreg = isel->getVReg(arg);
// 2. 根据参数索引,确定对应的物理寄存器 (a0, a1, ...)
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + arg_idx);
// 3. 在 color_map 中,将 vreg "预着色" 为对应的物理寄存器
color_map[vreg] = preg;
arg_idx++;
}
}
}
void RISCv64RegAlloc::eliminateFrameIndices() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
// 初始偏移量为保存ra和s0留出空间。
// 假设序言是 addi sp, sp, -stack_size; sd ra, stack_size-8(sp); sd s0, stack_size-16(sp);
int current_offset = 16;
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
// 在处理局部变量前,首先为栈参数计算偏移量。
if (F) {
int arg_idx = 0;
for (Argument* arg : F->getArguments()) {
// 我们只关心第8个索引及之后的参数即第9个参数开始
if (arg_idx >= 8) {
// 计算偏移量:第一个栈参数(idx=8)在0(s0),第二个(idx=9)在8(s0),以此类推。
int offset = (arg_idx - 8) * 8;
unsigned vreg = isel->getVReg(arg);
// 将这个vreg和它的栈偏移存入map。
// 我们可以复用alloca_offsets因为它们都代表“vreg到栈偏移”的映射。
frame_info.alloca_offsets[vreg] = offset;
}
arg_idx++;
}
}
// 处理局部变量
// 遍历AllocaInst来计算局部变量所需的总空间
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
// 获取Alloca指令指向的类型 (例如 alloca i32* 中,获取 i32)
Type* allocated_type = alloca->getType()->as<PointerType>()->getBaseType();
int size = getTypeSizeInBytes(allocated_type);
// RISC-V要求栈地址8字节对齐
size = (size + 7) & ~7;
if (size == 0) size = 8; // 至少分配8字节
current_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
// 局部变量使用相对于s0的负向偏移
frame_info.alloca_offsets[alloca_vreg] = -current_offset;
}
}
}
frame_info.locals_size = current_offset;
// 遍历所有机器指令,将伪指令展开为真实指令
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
RVOpcodes opcode = instr_ptr->getOpcode();
// --- MODIFICATION START: 处理区分宽度的伪指令 ---
if (opcode == RVOpcodes::FRAME_LOAD_W || opcode == RVOpcodes::FRAME_LOAD_D) {
// 确定要生成的真实加载指令是 lw 还是 ld
RVOpcodes real_load_op = (opcode == RVOpcodes::FRAME_LOAD_W) ? RVOpcodes::LW : RVOpcodes::LD;
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: lw/ld dest_vreg, 0(addr_vreg)
auto load_instr = std::make_unique<MachineInstr>(real_load_op);
load_instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
load_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(load_instr));
} else if (opcode == RVOpcodes::FRAME_STORE_W || opcode == RVOpcodes::FRAME_STORE_D) {
// 确定要生成的真实存储指令是 sw 还是 sd
RVOpcodes real_store_op = (opcode == RVOpcodes::FRAME_STORE_W) ? RVOpcodes::SW : RVOpcodes::SD;
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: sw/sd src_vreg, 0(addr_vreg)
auto store_instr = std::make_unique<MachineInstr>(real_store_op);
store_instr->addOperand(std::make_unique<RegOperand>(src_vreg));
store_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(store_instr));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_ADDR) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
// 将 `frame_addr rd, rs` 展开为 `addi rd, s0, offset`
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(dest_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
} else {
new_instructions.push_back(std::move(instr_ptr));
}
// --- MODIFICATION END ---
}
mbb->getInstructions() = std::move(new_instructions);
}
}
void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) {
bool is_def = true;
auto opcode = instr->getOpcode();
// --- MODIFICATION START: 细化对指令的 use/def 定义 ---
// 对于没有定义目标寄存器的指令,预先设置 is_def = false
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU ||
opcode == RVOpcodes::RET || opcode == RVOpcodes::J) {
is_def = false;
}
// 对 CALL 指令进行特殊处理
if (opcode == RVOpcodes::CALL) {
// CALL 指令的第一个操作数通常是目标函数标签,不是寄存器。
// 它可能会有一个可选的返回值def以及一系列参数use
// 这里的处理假定 CALL 的机器指令操作数布局是:
// [可选: dest_vreg (def)], [函数标签], [可选: arg1_vreg (use)], [可选: arg2_vreg (use)], ...
// 我们需要一种方法来识别哪些操作数是def哪些是use。
// 一个简单的约定如果第一个操作数是寄存器则它是def返回值
if (!instr->getOperands().empty() && instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(instr->getOperands().front().get());
if (reg_op->isVirtual()) {
def.insert(reg_op->getVRegNum());
}
}
// 遍历所有操作数非第一个寄存器操作数均视为use
bool first_reg_skipped = false;
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
if (!first_reg_skipped) {
first_reg_skipped = true;
continue; // 跳过我们已经作为def处理的返回值
}
auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual()) {
use.insert(reg_op->getVRegNum());
}
}
}
// **重要**: CALL指令还隐式定义杀死了所有调用者保存的寄存器。
// 一个完整的实现会在这里将所有caller-saved寄存器标记为def
// 以确保任何跨调用存活的变量都不会被分配到这些寄存器中。
// 这个简化的实现暂不处理隐式def但这是未来优化的关键点。
return; // CALL 指令处理完毕,直接返回
}
// --- MODIFICATION END ---
// 对其他所有指令的通用处理逻辑
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual()) {
if (is_def) {
def.insert(reg_op->getVRegNum());
is_def = false; // 一条指令通常只有一个目标寄存ator
} else {
use.insert(reg_op->getVRegNum());
}
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
// 内存操作数 `offset(base)` 中的 base 寄存器是 use
auto mem_op = static_cast<MemOperand*>(op.get());
if (mem_op->getBase()->isVirtual()) {
use.insert(mem_op->getBase()->getVRegNum());
}
}
}
}
/**
* @brief 计算一个类型在内存中占用的字节数。
* @param type 需要计算大小的IR类型。
* @return 该类型占用的字节数。
*/
unsigned RISCv64RegAlloc::getTypeSizeInBytes(Type* type) {
if (!type) {
assert(false && "Cannot get size of a null type.");
return 0;
}
switch (type->getKind()) {
// 对于SysY语言基本类型int和float都占用4字节
case Type::kInt:
case Type::kFloat:
return 4;
// 指针类型在RISC-V 64位架构下占用8字节
// 虽然SysY没有'int*'语法但数组变量在IR层面本身就是指针类型
case Type::kPointer:
return 8;
// 数组类型的总大小 = 元素数量 * 单个元素的大小
case Type::kArray: {
auto arrayType = type->as<ArrayType>();
// 递归调用以计算元素大小
return arrayType->getNumElements() * getTypeSizeInBytes(arrayType->getElementType());
}
// 其他类型如Void, Label等不占用栈空间或者不应该出现在这里
default:
// 如果遇到未处理的类型,触发断言,方便调试
assert(false && "Unsupported type for size calculation.");
return 0;
}
}
void RISCv64RegAlloc::analyzeLiveness() {
bool changed = true;
while (changed) {
changed = false;
for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) {
auto& mbb = *it;
LiveSet live_out;
for (auto succ : mbb->successors) {
if (!succ->getInstructions().empty()) {
auto first_instr = succ->getInstructions().front().get();
if (live_in_map.count(first_instr)) {
live_out.insert(live_in_map.at(first_instr).begin(), live_in_map.at(first_instr).end());
}
}
}
for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) {
MachineInstr* instr = instr_it->get();
LiveSet old_live_in = live_in_map[instr];
live_out_map[instr] = live_out;
LiveSet use, def;
getInstrUseDef(instr, use, def);
LiveSet live_in = use;
LiveSet diff = live_out;
for (auto vreg : def) {
diff.erase(vreg);
}
live_in.insert(diff.begin(), diff.end());
live_in_map[instr] = live_in;
live_out = live_in;
if (live_in_map[instr] != old_live_in) {
changed = true;
}
}
}
}
}
void RISCv64RegAlloc::buildInterferenceGraph() {
std::set<unsigned> all_vregs;
for (auto& mbb : MFunc->getBlocks()) {
for(auto& instr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr.get(), use, def);
for(auto u : use) all_vregs.insert(u);
for(auto d : def) all_vregs.insert(d);
}
}
for (auto vreg : all_vregs) { interference_graph[vreg] = {}; }
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
LiveSet def, use;
getInstrUseDef(instr.get(), use, def);
const LiveSet& live_out = live_out_map.at(instr.get());
for (unsigned d : def) {
for (unsigned l : live_out) {
if (d != l) {
interference_graph[d].insert(l);
interference_graph[l].insert(d);
}
}
}
}
}
}
void RISCv64RegAlloc::colorGraph() {
std::vector<unsigned> sorted_vregs;
for (auto const& [vreg, neighbors] : interference_graph) {
if (color_map.find(vreg) == color_map.end()) {
sorted_vregs.push_back(vreg);
}
}
// 排序
std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) {
return interference_graph[a].size() > interference_graph[b].size();
});
// 着色
for (unsigned vreg : sorted_vregs) {
std::set<PhysicalReg> used_colors;
for (unsigned neighbor : interference_graph.at(vreg)) {
if (color_map.count(neighbor)) {
used_colors.insert(color_map.at(neighbor));
}
}
bool colored = false;
for (PhysicalReg preg : allocable_int_regs) {
if (used_colors.find(preg) == used_colors.end()) {
color_map[vreg] = preg;
colored = true;
break;
}
}
if (!colored) {
spilled_vregs.insert(vreg);
}
}
}
void RISCv64RegAlloc::rewriteFunction() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.locals_size;
// --- FIX 1: 动态计算溢出槽大小 ---
// 根据溢出虚拟寄存器的真实类型,为其在栈上分配正确大小的空间。
for (unsigned vreg : spilled_vregs) {
// 从反向映射中查找 vreg 对应的 IR Value
assert(vreg_to_value_map.count(vreg) && "Spilled vreg not found in map!");
Value* val = vreg_to_value_map.at(vreg);
// 使用辅助函数获取类型大小
int size = getTypeSizeInBytes(val->getType());
// 保持栈8字节对齐
current_offset += size;
current_offset = (current_offset + 7) & ~7;
frame_info.spill_offsets[vreg] = -current_offset;
}
frame_info.spill_size = current_offset - frame_info.locals_size;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr_ptr.get(), use, def);
// --- FIX 2: 为溢出的 'use' 操作数插入正确的加载指令 ---
for (unsigned vreg : use) {
if (spilled_vregs.count(vreg)) {
// 同样地,根据 vreg 的类型决定使用 lw 还是 ld
assert(vreg_to_value_map.count(vreg));
Value* val = vreg_to_value_map.at(vreg);
RVOpcodes load_op = val->getType()->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
int offset = frame_info.spill_offsets.at(vreg);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(load));
}
}
new_instructions.push_back(std::move(instr_ptr));
// --- FIX 3: 为溢出的 'def' 操作数插入正确的存储指令 ---
for (unsigned vreg : def) {
if (spilled_vregs.count(vreg)) {
// 根据 vreg 的类型决定使用 sw 还是 sd
assert(vreg_to_value_map.count(vreg));
Value* val = vreg_to_value_map.at(vreg);
RVOpcodes store_op = val->getType()->isPointer() ? RVOpcodes::SD : RVOpcodes::SW;
int offset = frame_info.spill_offsets.at(vreg);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(store));
}
}
}
mbb->getInstructions() = std::move(new_instructions);
}
// 最后的虚拟寄存器到物理寄存器的替换过程保持不变
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) {
// 情况一:操作数本身就是一个寄存器 (例如 add rd, rs1, rs2 中的所有操作数)
if(op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (color_map.count(vreg)) {
PhysicalReg preg = color_map.at(vreg);
reg_op->setPReg(preg);
// 检查这个物理寄存器是否是 s0-s11
if (preg >= PhysicalReg::S0 && preg <= PhysicalReg::S11) {
used_callee_saved_regs.insert(preg);
}
} else if (spilled_vregs.count(vreg)) {
// 如果vreg被溢出替换为专用的溢出物理寄存器t6
reg_op->setPReg(PhysicalReg::T6);
}
}
}
// 情况二:操作数是一个内存地址 (例如 lw rd, offset(rs1) 中的 offset(rs1))
else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op_ptr.get());
// 获取内存操作数内部的“基址寄存器”
auto base_reg_op = mem_op->getBase();
// 对这个基址寄存器,执行与情况一完全相同的替换逻辑
if(base_reg_op->isVirtual()){
unsigned vreg = base_reg_op->getVRegNum();
if(color_map.count(vreg)) {
// 如果基址vreg被成功着色替换
base_reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
// 如果基址vreg被溢出替换为t6
base_reg_op->setPReg(PhysicalReg::T6);
}
}
}
}
}
}
}
} // namespace sysy