[backend]继续增强寄存器分配器健壮性

This commit is contained in:
Lixuanwang
2025-08-05 13:50:55 +08:00
parent f2477c4af3
commit 32684d8255
8 changed files with 707 additions and 214 deletions

View File

@@ -2,6 +2,7 @@
#include "RISCv64LLIR.h"
#include "RISCv64ISel.h"
#include "RISCv64Info.h"
#include "RISCv64AsmPrinter.h"
#include <iostream>
#include <algorithm>
#include <set>
@@ -16,7 +17,9 @@ extern int DEEPERDEBUG;
namespace sysy {
// --- 调试辅助函数 ---
std::string pregToString(PhysicalReg preg) {
// These helpers are self-contained and only used for logging.
static std::string pregToString(PhysicalReg preg) {
// This map is a copy from AsmPrinter to avoid dependency issues.
static const std::map<PhysicalReg, std::string> preg_names = {
{PhysicalReg::ZERO, "zero"}, {PhysicalReg::RA, "ra"}, {PhysicalReg::SP, "sp"}, {PhysicalReg::GP, "gp"}, {PhysicalReg::TP, "tp"},
{PhysicalReg::T0, "t0"}, {PhysicalReg::T1, "t1"}, {PhysicalReg::T2, "t2"}, {PhysicalReg::T3, "t3"}, {PhysicalReg::T4, "t4"}, {PhysicalReg::T5, "t5"}, {PhysicalReg::T6, "t6"},
@@ -33,36 +36,34 @@ std::string pregToString(PhysicalReg preg) {
}
template<typename T>
std::string setToString(const std::set<T>& s, std::function<std::string(T)> formatter) {
static std::string setToString(const std::set<T>& s, std::function<std::string(T)> formatter) {
std::stringstream ss;
ss << "{ ";
for (auto it = s.begin(); it != s.end(); ++it) {
ss << formatter(*it) << (std::next(it) == s.end() ? "" : ", ");
bool first = true;
for (const auto& item : s) {
if (!first) ss << ", ";
ss << formatter(item);
first = false;
}
ss << " }";
return ss.str();
}
std::string vregSetToString(const std::set<unsigned>& s) {
static std::string vregSetToString(const std::set<unsigned>& s) {
return setToString<unsigned>(s, [](unsigned v){ return "%v" + std::to_string(v); });
}
std::string pregSetToString(const std::set<PhysicalReg>& s) {
static std::string pregSetToString(const std::set<PhysicalReg>& s) {
return setToString<PhysicalReg>(s, pregToString);
}
std::string opcodeToString(RVOpcodes opcode) {
static const std::map<RVOpcodes, std::string> opcode_names = {
{RVOpcodes::ADD, "add"}, {RVOpcodes::ADDI, "addi"}, {RVOpcodes::ADDW, "addw"}, {RVOpcodes::ADDIW, "addiw"},
{RVOpcodes::SUB, "sub"}, {RVOpcodes::SUBW, "subw"}, {RVOpcodes::MUL, "mul"}, {RVOpcodes::MULW, "mulw"},
{RVOpcodes::DIV, "div"}, {RVOpcodes::DIVW, "divw"}, {RVOpcodes::REM, "rem"}, {RVOpcodes::REMW, "remw"},
{RVOpcodes::SLT, "slt"}, {RVOpcodes::LW, "lw"}, {RVOpcodes::LD, "ld"}, {RVOpcodes::SW, "sw"}, {RVOpcodes::SD, "sd"},
{RVOpcodes::FLW, "flw"}, {RVOpcodes::FSW, "fsw"}, {RVOpcodes::CALL, "call"}, {RVOpcodes::RET, "ret"},
{RVOpcodes::MV, "mv"}, {RVOpcodes::LI, "li"}, {RVOpcodes::LA, "la"}, {RVOpcodes::BNE, "bne"},
{RVOpcodes::J, "j"}, {RVOpcodes::LABEL, "label"}
};
if (opcode_names.count(opcode)) return opcode_names.at(opcode);
return "Op(" + std::to_string(static_cast<int>(opcode)) + ")";
// Helper function to check if a register is callee-saved.
// Defined locally to avoid scope issues.
static bool isCalleeSaved(PhysicalReg preg) {
if (preg >= PhysicalReg::S0 && preg <= PhysicalReg::S11) return true;
if (preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) return true;
if (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27) return true;
return false;
}
@@ -72,8 +73,7 @@ RISCv64LinearScan::RISCv64LinearScan(MachineFunction* mfunc)
vreg_type_map(ISel->getVRegTypeMap()) {
allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T6,
// PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T6,
PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
};
@@ -104,24 +104,42 @@ RISCv64LinearScan::RISCv64LinearScan(MachineFunction* mfunc)
}
}
void RISCv64LinearScan::run() {
bool RISCv64LinearScan::run() {
if (DEBUG) std::cerr << "===== [LSRA] Running for function: " << MFunc->getName() << " =====\n";
bool changed = true;
int iteration = 1;
while(changed) {
const int MAX_ITERATIONS = 3;
for (int iteration = 1; ; ++iteration) {
if (DEBUG && iteration > 1) {
std::cerr << "\n----- [LSRA] Re-running iteration " << iteration << " -----\n";
}
linearizeBlocks();
computeLiveIntervals();
changed = linearScan();
bool needs_spill = linearScan();
if (changed) {
if (DEBUG) std::cerr << "[LSRA] Spilling detected, will rewrite program.\n";
rewriteProgram();
// 如果当前这轮线性扫描不需要溢出,说明分配成功,直接跳出循环。
if (!needs_spill) {
break;
}
iteration++;
// --- 检查是否需要启动或已经失败于保底策略 ---
if (iteration > MAX_ITERATIONS) {
// 如果我们已经在保底模式下运行过,但这一轮 linearScan 仍然返回 true
// 这说明发生了无法解决的错误,此时才真正失败。
if (conservative_spill_mode) {
std::cerr << "\n!!!!!! [LSRA-FATAL] Allocation failed to converge even in Conservative Spill Mode. Triggering final fallback. !!!!!!\n\n";
return false; // 返回失败而不是exit
}
// 这是第一次达到最大迭代次数,触发保底策略。
std::cerr << "\n!!!!!! [LSRA-WARN] Convergence failed after " << MAX_ITERATIONS
<< " iterations. Entering Conservative Spill Mode for the next attempt. !!!!!!\n\n";
conservative_spill_mode = true; // 开启保守溢出模式,将在下一次循环生效
}
// 只要需要溢出,就重写程序
if (DEBUG) std::cerr << "[LSRA] Spilling detected, will rewrite program.\n";
rewriteProgram();
}
if (DEBUG) std::cerr << "[LSRA] Applying final allocation.\n";
@@ -130,6 +148,7 @@ void RISCv64LinearScan::run() {
collectUsedCalleeSavedRegs();
if (DEBUG) std::cerr << "===== [LSRA] Finished for function: " << MFunc->getName() << " =====\n\n";
return true; // 分配成功
}
void RISCv64LinearScan::linearizeBlocks() {
@@ -139,27 +158,22 @@ void RISCv64LinearScan::linearizeBlocks() {
}
}
// 步骤 1.2: 计算活跃区间 (最终修复版)
void RISCv64LinearScan::computeLiveIntervals() {
if (DEBUG) std::cerr << "[LSRA-Live] Starting live interval computation.\n";
instr_numbering.clear();
live_intervals.clear();
unhandled.clear();
// a. 对所有指令进行线性编号
int num = 0;
std::set<int> call_locations;
for (auto* mbb : linear_order_blocks) {
for (auto& instr : mbb->getInstructions()) {
instr_numbering[instr.get()] = num;
if (instr->getOpcode() == RVOpcodes::CALL) {
call_locations.insert(num);
}
if (instr->getOpcode() == RVOpcodes::CALL) call_locations.insert(num);
num += 2;
}
}
// b. 活跃变量分析(数据流分析部分,这部分是正确的,保持不变)
if (DEEPDEBUG) std::cerr << " [Live] Starting live variable dataflow analysis...\n";
std::map<const MachineBasicBlock*, std::set<unsigned>> live_in, live_out;
bool changed = true;
@@ -169,16 +183,12 @@ void RISCv64LinearScan::computeLiveIntervals() {
df_iter++;
std::vector<MachineBasicBlock*> reversed_blocks = linear_order_blocks;
std::reverse(reversed_blocks.begin(), reversed_blocks.end());
for(auto* mbb : reversed_blocks) {
std::set<unsigned> old_live_in = live_in[mbb];
std::set<unsigned> current_live_out;
for (auto* succ : mbb->successors) {
current_live_out.insert(live_in[succ].begin(), live_in[succ].end());
}
for (auto* succ : mbb->successors) current_live_out.insert(live_in[succ].begin(), live_in[succ].end());
std::set<unsigned> use, def;
std::set<unsigned> temp_live = current_live_out;
auto& instrs = mbb->getInstructions();
for (auto it = instrs.rbegin(); it != instrs.rend(); ++it) {
use.clear(); def.clear();
@@ -186,7 +196,6 @@ void RISCv64LinearScan::computeLiveIntervals() {
for (unsigned vreg : def) temp_live.erase(vreg);
for (unsigned vreg : use) temp_live.insert(vreg);
}
if (live_in[mbb] != temp_live || live_out[mbb] != current_live_out) {
changed = true;
live_in[mbb] = temp_live;
@@ -195,36 +204,35 @@ void RISCv64LinearScan::computeLiveIntervals() {
}
}
if (DEEPDEBUG) std::cerr << " [Live] Dataflow analysis converged after " << df_iter << " iterations.\n";
if (DEEPERDEBUG) {
std::cerr << " [Live-Debug] Live-in sets:\n";
for (auto* mbb : linear_order_blocks) std::cerr << " " << mbb->getName() << ": " << vregSetToString(live_in[mbb]) << "\n";
std::cerr << " [Live-Debug] Live-out sets:\n";
for (auto* mbb : linear_order_blocks) std::cerr << " " << mbb->getName() << ": " << vregSetToString(live_out[mbb]) << "\n";
}
// c. 根据指令遍历和活跃出口信息,精确构建区间
if (DEEPDEBUG) std::cerr << " [Live] Building precise intervals...\n";
std::map<unsigned, int> first_def, last_use;
// 步骤 c.1: 遍历所有指令找到每个vreg的首次定义和最后一次使用
for (auto* mbb : linear_order_blocks) {
for (auto& instr_ptr : mbb->getInstructions()) {
int instr_num = instr_numbering.at(instr_ptr.get());
std::set<unsigned> use, def;
getInstrUseDef(instr_ptr.get(), use, def);
for (unsigned vreg : def) {
if (first_def.find(vreg) == first_def.end()) {
first_def[vreg] = instr_num;
}
}
for (unsigned vreg : use) {
last_use[vreg] = instr_num;
}
for (unsigned vreg : def) if (first_def.find(vreg) == first_def.end()) first_def[vreg] = instr_num;
for (unsigned vreg : use) last_use[vreg] = instr_num;
}
}
// 步骤 c.2: 创建 LiveInterval 对象,并使用 live_out 信息修正区间的结束点
if (DEEPERDEBUG) {
std::cerr << " [Live-Debug] First def points:\n";
for (auto const& [vreg, pos] : first_def) std::cerr << " %v" << vreg << ": " << pos << "\n";
std::cerr << " [Live-Debug] Last use points:\n";
for (auto const& [vreg, pos] : last_use) std::cerr << " %v" << vreg << ": " << pos << "\n";
}
for (auto const& [vreg, start] : first_def) {
live_intervals.emplace(vreg, LiveInterval(vreg));
auto& interval = live_intervals.at(vreg);
interval.start = start;
// 初始结束点是最后一次使用。如果从未被使用,则结束点就是定义点。
interval.end = last_use.count(vreg) ? last_use.at(vreg) : start;
}
@@ -233,27 +241,22 @@ void RISCv64LinearScan::computeLiveIntervals() {
int block_end_num = instr_numbering.at(mbb->getInstructions().back().get());
for (unsigned vreg : live_set) {
if (live_intervals.count(vreg)) {
if (DEEPERDEBUG && live_intervals.at(vreg).end < block_end_num) {
std::cerr << " [Live-Debug] Extending interval for %v" << vreg << " from " << live_intervals.at(vreg).end << " to " << block_end_num << " due to live_out of " << mbb->getName() << "\n";
}
live_intervals.at(vreg).end = std::max(live_intervals.at(vreg).end, block_end_num);
}
}
}
// 步骤 c.3: 检查是否跨越函数调用
for (auto& pair : live_intervals) {
auto& interval = pair.second;
auto it = call_locations.lower_bound(interval.start);
if (it != call_locations.end() && *it < interval.end) {
interval.crosses_call = true;
}
if (it != call_locations.end() && *it < interval.end) interval.crosses_call = true;
}
// d. 排序并准备分配
for (auto& pair : live_intervals) {
unhandled.push_back(&pair.second);
}
std::sort(unhandled.begin(), unhandled.end(), [](const LiveInterval* a, const LiveInterval* b){
return a->start < b->start;
});
for (auto& pair : live_intervals) unhandled.push_back(&pair.second);
std::sort(unhandled.begin(), unhandled.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->start < b->start; });
if (DEBUG) {
std::cerr << "[LSRA-Live] Finished. Total intervals: " << unhandled.size() << "\n";
@@ -265,17 +268,73 @@ void RISCv64LinearScan::computeLiveIntervals() {
}
}
}
}
bool isCalleeSaved(PhysicalReg preg) {
if (preg >= PhysicalReg::S1 && preg <= PhysicalReg::S11) return true;
if (preg == PhysicalReg::S0) return true;
if (preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) return true;
if (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27) return true;
return false;
// ================== 新增的调试代码 ==================
// 检查活性分析找到的vreg与指令扫描找到的vreg是否一致
if (DEEPERDEBUG) {
// 修正:将 std.set 修改为 std::set
std::set<unsigned> vregs_from_liveness;
for (const auto& pair : live_intervals) {
vregs_from_liveness.insert(pair.first);
}
std::set<unsigned> vregs_from_instr_scan;
for (auto* mbb : linear_order_blocks) {
for (auto& instr_ptr : mbb->getInstructions()) {
std::set<unsigned> use, def;
getInstrUseDef(instr_ptr.get(), use, def);
vregs_from_instr_scan.insert(use.begin(), use.end());
vregs_from_instr_scan.insert(def.begin(), def.end());
}
}
std::cerr << " [Live-Debug] VReg Consistency Check:\n";
std::cerr << " VRegs found by Liveness Analysis: " << vregs_from_liveness.size() << "\n";
std::cerr << " VRegs found by getInstrUseDef Scan: " << vregs_from_instr_scan.size() << "\n";
// 修正:将 std.set 修改为 std::set
std::set<unsigned> diff;
std::set_difference(vregs_from_liveness.begin(), vregs_from_liveness.end(),
vregs_from_instr_scan.begin(), vregs_from_instr_scan.end(),
std::inserter(diff, diff.begin()));
if (!diff.empty()) {
std::cerr << " !!!!!! [Live-Debug] DISCREPANCY DETECTED !!!!!!\n";
std::cerr << " The following vregs were found by liveness but NOT by getInstrUseDef scan:\n";
std::cerr << " " << vregSetToString(diff) << "\n";
} else {
std::cerr << " [Live-Debug] VReg sets are consistent.\n";
}
}
// ======================================================
}
bool RISCv64LinearScan::linearScan() {
// ================== 终极保底策略 (新逻辑) ==================
// 当此标志位为true时我们进入最暴力的溢出模式。
if (conservative_spill_mode) {
if (DEBUG) std::cerr << "[LSRA-Scan-Panic] In Conservative Mode. Spilling all unhandled vregs.\n";
// 1. 清空溢出列表,准备重新计算
spilled_vregs.clear();
// 2. 遍历所有计算出的活性区间
for (auto& pair : live_intervals) {
// 3. 如果一个vreg不是ABI规定的寄存器就必须溢出
if (abi_vreg_map.find(pair.first) == abi_vreg_map.end()) {
spilled_vregs.insert(pair.first);
}
}
// 4. 只要有任何vreg被标记为溢出就返回true以触发最终的rewriteProgram。
// 下一轮迭代时由于所有vreg都已被重写将不再有新的溢出保证收敛。
return !spilled_vregs.empty();
}
// ==========================================================
// ================== 常规线性扫描逻辑 (您已有的代码) ==================
// 只有在非保守模式下才会执行以下代码
if (DEBUG) std::cerr << "[LSRA-Scan] Starting main linear scan algorithm.\n";
active.clear();
spilled_vregs.clear();
@@ -324,9 +383,9 @@ bool RISCv64LinearScan::linearScan() {
PhysicalReg preg = vreg_to_preg_map.at(active_interval->vreg);
if (DEEPDEBUG) std::cerr << " [Scan] Expiring interval %v" << active_interval->vreg << ", freeing " << pregToString(preg) << "\n";
if (isFPVReg(active_interval->vreg)) {
if(isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); else free_caller_fp_regs.insert(preg);
if(isCalleeSaved(preg)) free_callee_fp_regs.insert(preg); else free_caller_fp_regs.insert(preg);
} else {
if(isCalleeSaved(preg)) free_callee_int_regs.insert(preg); else free_caller_int_regs.insert(preg);
if(isCalleeSaved(preg)) free_callee_int_regs.insert(preg); else free_caller_int_regs.insert(preg);
}
} else {
new_active.push_back(active_interval);
@@ -368,152 +427,181 @@ bool RISCv64LinearScan::linearScan() {
}
void RISCv64LinearScan::spillAtInterval(LiveInterval* current) {
// 保持您的原始逻辑
LiveInterval* spill_candidate = nullptr;
if (!active.empty()) {
spill_candidate = active.back();
}
if (DEEPERDEBUG) {
std::cerr << " [Spill-Debug] Spill decision for current=%v" << current->vreg << "[" << current->start << "," << current->end << "]\n";
std::cerr << " [Spill-Debug] Active intervals (sorted by end point):\n";
for (const auto* i : active) {
std::cerr << " %v" << i->vreg << "[" << i->start << "," << i->end << "] in " << pregToString(vreg_to_preg_map[i->vreg]) << "\n";
}
if(spill_candidate) {
std::cerr << " [Spill-Debug] Candidate is %v" << spill_candidate->vreg << ". Its end is " << spill_candidate->end << ", current's end is " << current->end << "\n";
} else {
std::cerr << " [Spill-Debug] No active candidate.\n";
}
}
if (spill_candidate && spill_candidate->end > current->end) {
if (DEEPDEBUG) std::cerr << " [Spill] Current interval ends at " << current->end << ", active interval %v" << spill_candidate->vreg << " ends at " << spill_candidate->end << ". Spilling active.\n";
if (DEEPDEBUG) std::cerr << " [Spill] Decision: Spilling active %v" << spill_candidate->vreg << ".\n";
PhysicalReg preg = vreg_to_preg_map.at(spill_candidate->vreg);
vreg_to_preg_map.erase(spill_candidate->vreg); // 确保移除旧映射
vreg_to_preg_map[current->vreg] = preg;
active.pop_back();
active.push_back(current);
std::sort(active.begin(), active.end(), [](const LiveInterval* a, const LiveInterval* b){ return a->end < b->end; });
spilled_vregs.insert(spill_candidate->vreg);
} else {
if (DEEPDEBUG) std::cerr << " [Spill] Current interval %v" << current->vreg << " ends at " << current->end << " which is later than all active intervals. Spilling current.\n";
if (DEEPDEBUG) std::cerr << " [Spill] Decision: Spilling current %v" << current->vreg << ".\n";
spilled_vregs.insert(current->vreg);
}
}
// 步骤 3: 重写程序,插入溢出代码
void RISCv64LinearScan::rewriteProgram() {
if (DEBUG) {
std::cerr << "[LSRA-Rewrite] Starting program rewrite. Spilled vregs: " << vregSetToString(spilled_vregs) << "\n";
}
StackFrameInfo& frame_info = MFunc->getFrameInfo();
// 遵循后端流水线设计,在本地变量区域之下为溢出变量分配空间。
// locals_end_offset 是由 EliminateFrameIndices Pass 计算好的局部变量区域的下边界。
// 我们在此基础上,减去已有的溢出区大小,作为新溢出槽的起始点。
int spill_current_offset = frame_info.locals_end_offset - frame_info.spill_size;
for (unsigned vreg : spilled_vregs) {
// 只为本轮新发现的溢出变量分配空间
// 保持您的原始逻辑
if (frame_info.spill_offsets.count(vreg)) continue;
Type* type = vreg_type_map.count(vreg) ? vreg_type_map.at(vreg) : Type::getIntType();
int size = isFPVReg(vreg) ? 4 : (type->isPointer() ? 8 : 4);
// 在当前偏移基础上继续向下(地址变得更负)分配空间
spill_current_offset -= size;
// 对齐新的、更小的地址RISC-V 要求8字节对齐
spill_current_offset = (spill_current_offset & ~7);
// 将计算出的、不会冲突的正确偏移量存入 spill_offsets
frame_info.spill_offsets[vreg] = spill_current_offset;
if (DEEPDEBUG) std::cerr << " [Rewrite] Assigned stack offset " << frame_info.spill_offsets.at(vreg) << " to spilled %v" << vreg << "\n";
if (DEEPDEBUG) std::cerr << " [Rewrite] Assigned new stack offset " << frame_info.spill_offsets.at(vreg) << " to spilled %v" << vreg << "\n";
}
// 更新总的溢出区域大小。
// spill_size = -(结束偏移 - 开始偏移)
// 注意locals_end_offset 和 spill_current_offset 都是负数。
frame_info.spill_size = -(spill_current_offset - frame_info.locals_end_offset);
for (auto& mbb : MFunc->getBlocks()) {
auto& instrs = mbb->getInstructions();
std::vector<std::unique_ptr<MachineInstr>> new_instrs;
if (DEEPERDEBUG) std::cerr << " [Rewrite] Processing block " << mbb->getName() << "\n";
for (auto it = instrs.begin(); it != instrs.end(); ++it) {
auto& instr = *it;
std::set<unsigned> use_vregs, def_vregs;
getInstrUseDef(instr.get(), use_vregs, def_vregs);
std::map<unsigned, unsigned> use_remap;
std::map<unsigned, unsigned> def_remap;
// 为每个溢出的 use 操作数创建新vreg并插入 load
for (unsigned old_vreg : use_vregs) {
if (spilled_vregs.count(old_vreg) && use_remap.find(old_vreg) == use_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
use_remap[old_vreg] = new_temp_vreg;
RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))
));
if (DEEPDEBUG) std::cerr << " [Rewrite] Inserting LOAD for use of %v" << old_vreg << " into new %v" << new_temp_vreg << "\n";
new_instrs.push_back(std::move(load));
}
}
// 为每个溢出的 def 操作数创建新vreg
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg) && def_remap.find(old_vreg) == def_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
def_remap[old_vreg] = new_temp_vreg;
}
}
auto opcode = instr->getOpcode();
auto& operands = instr->getOperands();
auto replace_reg_op = [](RegOperand* reg_op, const std::map<unsigned, unsigned>& remap) {
if (reg_op->isVirtual() && remap.count(reg_op->getVRegNum())) {
reg_op->setVRegNum(remap.at(reg_op->getVRegNum()));
}
};
if (op_info.count(opcode)) {
const auto& info = op_info.at(opcode);
for (int idx : info.first) {
if (idx < operands.size() && operands[idx]->getKind() == MachineOperand::KIND_REG) {
replace_reg_op(static_cast<RegOperand*>(operands[idx].get()), def_remap);
}
}
for (int idx : info.second) {
if (idx < operands.size()) {
if (operands[idx]->getKind() == MachineOperand::KIND_REG) {
replace_reg_op(static_cast<RegOperand*>(operands[idx].get()), use_remap);
} else if (operands[idx]->getKind() == MachineOperand::KIND_MEM) {
replace_reg_op(static_cast<MemOperand*>(operands[idx].get())->getBase(), use_remap);
if (conservative_spill_mode) {
// ================== 紧急模式重写逻辑 ==================
// 直接使用物理寄存器 t4 (SPILL_TEMP_REG) 进行加载/存储
// 为调试日志准备一个指令打印机
auto printer = DEEPERDEBUG ? std::make_unique<RISCv64AsmPrinter>(MFunc) : nullptr;
auto original_instr_str_for_log = DEEPERDEBUG ? printer->formatInstr(instr.get()) : "";
bool modified = false;
for (unsigned old_vreg : use_vregs) {
if (spilled_vregs.count(old_vreg)) {
modified = true;
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
// 直接加载到保留的物理寄存器
load->addOperand(std::make_unique<RegOperand>(SPILL_TEMP_REG));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
std::cerr << " [Rewrite-Panic] Inserting LOAD for use of %v" << old_vreg
<< " into " << pregToString(SPILL_TEMP_REG)
<< " before: " << original_instr_str_for_log << "\n";
}
new_instrs.push_back(std::move(load));
// 替换指令中的操作数
instr->replaceVRegWithPReg(old_vreg, SPILL_TEMP_REG);
}
}
} else if (opcode == RVOpcodes::CALL) {
if (!operands.empty() && operands[0]->getKind() == MachineOperand::KIND_REG) {
replace_reg_op(static_cast<RegOperand*>(operands[0].get()), def_remap);
}
for (size_t i = 1; i < operands.size(); ++i) {
if (operands[i]->getKind() == MachineOperand::KIND_REG) {
replace_reg_op(static_cast<RegOperand*>(operands[i].get()), use_remap);
// 在处理 def 之前,先替换定义自身的 vreg
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg)) {
modified = true;
instr->replaceVRegWithPReg(old_vreg, SPILL_TEMP_REG);
}
}
// 将原始指令(可能已被修改)放入新列表
new_instrs.push_back(std::move(instr));
if (DEEPERDEBUG && modified) {
std::cerr << " [Rewrite-Panic] Original: " << original_instr_str_for_log
<< " -> Rewritten: " << printer->formatInstr(new_instrs.back().get()) << "\n";
}
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg)) {
// 指令本身已经被修改为定义到 SPILL_TEMP_REG现在从它存回内存
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(SPILL_TEMP_REG));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
std::cerr << " [Rewrite-Panic] Inserting STORE for def of %v" << old_vreg
<< " from " << pregToString(SPILL_TEMP_REG) << " after original instr.\n";
}
new_instrs.push_back(std::move(store));
}
}
}
new_instrs.push_back(std::move(instr));
// 为每个溢出的 def 操作数,在原指令后插入 store 指令
for(const auto& pair : def_remap) {
unsigned old_vreg = pair.first;
unsigned new_temp_vreg = pair.second;
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))
));
if (DEEPDEBUG) std::cerr << " [Rewrite] Inserting STORE for def of %v" << old_vreg << " from new %v" << new_temp_vreg << "\n";
new_instrs.push_back(std::move(store));
} else {
// ================== 常规模式重写逻辑 (您的原始代码) ==================
std::map<unsigned, unsigned> use_remap, def_remap;
for (unsigned old_vreg : use_vregs) {
if (spilled_vregs.count(old_vreg) && use_remap.find(old_vreg) == use_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
use_remap[old_vreg] = new_temp_vreg;
RVOpcodes load_op = isFPVReg(old_vreg) ? RVOpcodes::FLW : (type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) {
RISCv64AsmPrinter printer(MFunc);
std::cerr << " [Rewrite] Inserting LOAD for use of %v" << old_vreg << " into new %v" << new_temp_vreg << " before: " << printer.formatInstr(instr.get()) << "\n";
}
new_instrs.push_back(std::move(load));
}
}
for (unsigned old_vreg : def_vregs) {
if (spilled_vregs.count(old_vreg) && def_remap.find(old_vreg) == def_remap.end()) {
Type* type = vreg_type_map.at(old_vreg);
unsigned new_temp_vreg = ISel->getNewVReg(type);
def_remap[old_vreg] = new_temp_vreg;
}
}
auto original_instr_str_for_log = DEEPERDEBUG ? RISCv64AsmPrinter(MFunc).formatInstr(instr.get()) : "";
instr->remapVRegs(use_remap, def_remap);
new_instrs.push_back(std::move(instr));
if (DEEPERDEBUG && (!use_remap.empty() || !def_remap.empty())) std::cerr << " [Rewrite] Original: " << original_instr_str_for_log << " -> Rewritten: " << RISCv64AsmPrinter(MFunc).formatInstr(new_instrs.back().get()) << "\n";
for(const auto& pair : def_remap) {
unsigned old_vreg = pair.first;
unsigned new_temp_vreg = pair.second;
Type* type = vreg_type_map.at(old_vreg);
RVOpcodes store_op = isFPVReg(old_vreg) ? RVOpcodes::FSW : (type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(new_temp_vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(frame_info.spill_offsets.at(old_vreg))));
if (DEEPERDEBUG) std::cerr << " [Rewrite] Inserting STORE for def of %v" << old_vreg << " from new %v" << new_temp_vreg << " after original instr.\n";
new_instrs.push_back(std::move(store));
}
}
}
instrs = std::move(new_instrs);
@@ -523,22 +611,18 @@ void RISCv64LinearScan::rewriteProgram() {
void RISCv64LinearScan::applyAllocation() {
if (DEBUG) std::cerr << "[LSRA-Apply] Applying final vreg->preg mapping.\n";
for (auto& mbb : MFunc->getBlocks()) {
if (DEEPDEBUG) std::cerr << " [Apply] In block " << mbb->getName() << ":\n";
for (auto& instr_ptr : mbb->getInstructions()) {
if (DEEPERDEBUG) std::cerr << " [Apply] Instr " << opcodeToString(instr_ptr->getOpcode()) << ": ";
for (auto& op_ptr : instr_ptr->getOperands()) {
if (op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (DEEPERDEBUG) std::cerr << "%v" << vreg << " -> ";
if (vreg_to_preg_map.count(vreg)) {
PhysicalReg preg = vreg_to_preg_map.at(vreg);
reg_op->setPReg(preg);
if (DEEPERDEBUG) std::cerr << pregToString(preg) << ", ";
reg_op->setPReg(vreg_to_preg_map.at(vreg));
} else {
reg_op->setPReg(PhysicalReg::T5);
if (DEEPERDEBUG) std::cerr << "T5 (uncolored), ";
std::cerr << "ERROR: Uncolored virtual register %v" << vreg << " found during applyAllocation! in func " << MFunc->getName() << "\n";
// Forcing an error is better than silent failure.
// reg_op->setPReg(PhysicalReg::T5);
}
}
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
@@ -546,24 +630,20 @@ void RISCv64LinearScan::applyAllocation() {
auto reg_op = mem_op->getBase();
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (DEEPERDEBUG) std::cerr << "mem[%v" << vreg << "] -> ";
if (vreg_to_preg_map.count(vreg)) {
PhysicalReg preg = vreg_to_preg_map.at(vreg);
reg_op->setPReg(preg);
if (DEEPERDEBUG) std::cerr << "mem[" << pregToString(preg) << "], ";
reg_op->setPReg(vreg_to_preg_map.at(vreg));
} else {
reg_op->setPReg(PhysicalReg::T5);
if (DEEPERDEBUG) std::cerr << "mem[T5] (uncolored), ";
std::cerr << "ERROR: Uncolored virtual register %v" << vreg << " in memory operand! in func " << MFunc->getName() << "\n";
// reg_op->setPReg(PhysicalReg::T5);
}
}
}
}
if (DEEPERDEBUG) std::cerr << "\n";
}
}
}
void RISCv64LinearScan::getInstrUseDef(const MachineInstr* instr, std::set<unsigned>& use, std::set<unsigned>& def) {
void getInstrUseDef(const MachineInstr* instr, std::set<unsigned>& use, std::set<unsigned>& def) {
auto opcode = instr->getOpcode();
const auto& operands = instr->getOperands();
@@ -611,4 +691,4 @@ void RISCv64LinearScan::collectUsedCalleeSavedRegs() {
}
}
} // namespace sysy
} // namespace sysy