Compare commits

...

13 Commits

Author SHA1 Message Date
rain2133
fc7afdbb35 [midend]修复错误的RelExp类型转换 2025-07-31 13:55:59 +08:00
Lixuanwang
6d60522ce2 Merge branch 'midend' into backend-float 2025-07-31 12:14:38 +08:00
rain2133
82288464c3 [midend]修复常量变量的声明逻辑同变量声明,重构表达式生成逻辑(将中缀表达式转换为后缀表达式判断类型提升后再进行统一类型转换和计算)。运行脚本通过率[117/140]。 2025-07-31 02:47:39 +08:00
rain2133
7e8b90ffd4 [midend]修改全局变量,全局常量类,提供维度访问方法,消除维度信息(记录在Type中),createItoFInst命名修复,增加打印全局常量。 2025-07-31 02:45:40 +08:00
Lixuanwang
8e94f89931 Merge branch 'midend' into backend 2025-07-30 18:27:42 +08:00
Lixuanwang
b388dc4542 Merge branch 'backend-float' into backend 2025-07-30 18:26:06 +08:00
Lixuanwang
48b0aec6c3 [midend][backend]修复了全局常量数组的访问错误 2025-07-30 18:23:56 +08:00
Lixuanwang
1fb5cd398d [backend]修复了多参数传递的问题 2025-07-30 17:58:39 +08:00
Lixuanwang
877a0f5dc2 [backend-float]修复部分问题 2025-07-30 16:00:02 +08:00
a3c4d5a2b8 [Optimize]对PreRA指令调度进行优化 2025-07-30 15:27:23 +08:00
Lixuanwang
39c13c46ec Merge branch 'midend' into backend-float 2025-07-30 15:10:38 +08:00
Lixuanwang
dd38bdc133 [backend]引入浮点数支持,但目前寄存器分配存在问题 2025-07-30 15:07:29 +08:00
860ebcd447 [Optimize]对PostRA指令调度进行容器/算法/缓存优化 2025-07-30 10:28:06 +08:00
18 changed files with 2144 additions and 903 deletions

View File

@@ -1,6 +1,8 @@
#include "CalleeSavedHandler.h"
#include <set>
#include <vector> //
#include <algorithm>
#include <iterator> //
namespace sysy {
@@ -14,23 +16,34 @@ bool CalleeSavedHandler::runOnFunction(Function *F, AnalysisManager& AM) {
void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
// 此 Pass 负责分析、分配栈空间并插入 callee-saved 寄存器的保存/恢复指令。
// 它通过与 FrameInfo 协作,确保为 callee-saved 寄存器分配的空间与局部变量/溢出槽的空间不冲突。
// 这样做可以使生成的 sd/ld 指令能被后续的优化 Pass (如 PostRA-Scheduler) 处理。
StackFrameInfo& frame_info = mfunc->getFrameInfo();
std::set<PhysicalReg> used_callee_saved;
// [修改] 分别记录被使用的整数和浮点被调用者保存寄存器
std::set<PhysicalReg> used_int_callee_saved;
std::set<PhysicalReg> used_fp_callee_saved;
// 1. 扫描所有指令找出被使用的s寄存器 (s1-s11)
// 1. 扫描所有指令找出被使用的s寄存器 (s1-s11) 和 fs寄存器 (fs0-fs11)
for (auto& mbb : mfunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
for (auto& op : instr->getOperands()) {
auto check_and_insert_reg = [&](RegOperand* reg_op) {
if (!reg_op->isVirtual()) {
PhysicalReg preg = reg_op->getPReg();
// [修改] 区分整数和浮点被调用者保存寄存器
// s0 由序言/尾声处理器专门处理,这里不计入
if (preg >= PhysicalReg::S1 && preg <= PhysicalReg::S11) {
used_callee_saved.insert(preg);
used_int_callee_saved.insert(preg);
}
// fs0-fs11 在我们的枚举中对应 f8,f9,f18-f27
else if ((preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) || (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27)) {
used_fp_callee_saved.insert(preg);
}
}
};
if (op->getKind() == MachineOperand::KIND_REG) {
check_and_insert_reg(static_cast<RegOperand*>(op.get()));
} else if (op->getKind() == MachineOperand::KIND_MEM) {
@@ -40,83 +53,109 @@ void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
}
}
if (used_callee_saved.empty()) {
// 如果没有使用任何需要处理的 callee-saved 寄存器,则直接返回
if (used_int_callee_saved.empty() && used_fp_callee_saved.empty()) {
frame_info.callee_saved_size = 0; // 确保大小被初始化
return; // 无需操作
return;
}
// 2. 计算为 callee-saved 寄存器分配的栈空间
// 这里的关键是,偏移的基准点要在局部变量和溢出槽之下。
int callee_saved_size = used_callee_saved.size() * 8;
frame_info.callee_saved_size = callee_saved_size; // 将大小存入 FrameInfo
// 2. 计算为 callee-saved 寄存器分配的栈空间大小
// 每个寄存器在RV64中都占用8字节
int callee_saved_size = (used_int_callee_saved.size() + used_fp_callee_saved.size()) * 8;
frame_info.callee_saved_size = callee_saved_size;
// 3. 计算无冲突的栈偏移
// 栈向下增长,所以偏移是负数。
// ra/s0 占用 -8 和 -16。局部变量和溢出区在它们之下。callee-saved 区在更下方。
// 我们使用相对于 s0 的偏移。s0 将指向栈顶 (sp + total_size)。
int base_offset = -16 - frame_info.locals_size - frame_info.spill_size;
// 为了栈帧布局确定性,对寄存器进行排序
std::vector<PhysicalReg> sorted_regs(used_callee_saved.begin(), used_callee_saved.end());
std::sort(sorted_regs.begin(), sorted_regs.end());
// 4. 在函数序言插入保存指令
// 3. 在函数序言中插入保存指令
MachineBasicBlock* entry_block = mfunc->getBlocks().front().get();
auto& entry_instrs = entry_block->getInstructions();
auto prologue_end = entry_instrs.begin();
// 插入点通常在函数入口标签之后
auto insert_pos = entry_instrs.begin();
if (!entry_instrs.empty() && entry_instrs.front()->getOpcode() == RVOpcodes::LABEL) {
insert_pos = std::next(insert_pos);
}
// 找到序言结束的位置通常是addi s0, sp, size之后但为了让优化器看到我们插在更前面
// 合理的位置是在 IR 指令开始之前,即在任何非序言指令(如第一个标签)之前。
// 为简单起见,我们直接插入到块的开头,后续重排 pass 会处理。
// (更优的实现会寻找一个特定的插入点)
// 为了布局确定性,对寄存器进行排序并按序保存
std::vector<PhysicalReg> sorted_int_regs(used_int_callee_saved.begin(), used_int_callee_saved.end());
std::vector<PhysicalReg> sorted_fp_regs(used_fp_callee_saved.begin(), used_fp_callee_saved.end());
std::sort(sorted_int_regs.begin(), sorted_int_regs.end());
std::sort(sorted_fp_regs.begin(), sorted_fp_regs.end());
std::vector<std::unique_ptr<MachineInstr>> save_instrs;
int current_offset = -16; // ra和s0已占用-8和-16从-24开始分配
int current_offset = base_offset;
for (PhysicalReg reg : sorted_regs) {
// 准备整数保存指令 (sd)
for (PhysicalReg reg : sorted_int_regs) {
current_offset -= 8;
auto sd = std::make_unique<MachineInstr>(RVOpcodes::SD);
sd->addOperand(std::make_unique<RegOperand>(reg));
sd->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), // 基址为帧指针 s0
std::make_unique<ImmOperand>(current_offset)
));
// 从头部插入,但要放在函数标签之后
entry_instrs.insert(entry_instrs.begin() + 1, std::move(sd));
save_instrs.push_back(std::move(sd));
}
// 准备浮点保存指令 (fsd)
for (PhysicalReg reg : sorted_fp_regs) {
current_offset -= 8;
auto fsd = std::make_unique<MachineInstr>(RVOpcodes::FSD); // 使用浮点保存指令
fsd->addOperand(std::make_unique<RegOperand>(reg));
fsd->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
save_instrs.push_back(std::move(fsd));
}
// 5. 在函数结尾ret之前插入恢复指令使用反向遍历来避免迭代器失效
// 一次性插入所有保存指令
if (!save_instrs.empty()) {
entry_instrs.insert(insert_pos,
std::make_move_iterator(save_instrs.begin()),
std::make_move_iterator(save_instrs.end()));
}
// 4. 在函数结尾ret之前插入恢复指令
for (auto& mbb : mfunc->getBlocks()) {
// 使用手动控制的反向循环
for (auto it = mbb->getInstructions().begin(); it != mbb->getInstructions().end(); ++it) {
if ((*it)->getOpcode() == RVOpcodes::RET) {
// 1. 创建一个临时vector来存储所有需要插入的恢复指令
std::vector<std::unique_ptr<MachineInstr>> restore_instrs;
current_offset = -16; // 重置偏移量用于恢复
int current_offset_load = base_offset;
// 以相同的顺序(例如 s1, s2, ...)创建恢复指令
for (PhysicalReg reg : sorted_regs) {
// 准备恢复整数寄存器 (ld) - 以与保存时相同的顺序
for (PhysicalReg reg : sorted_int_regs) {
current_offset -= 8;
auto ld = std::make_unique<MachineInstr>(RVOpcodes::LD);
ld->addOperand(std::make_unique<RegOperand>(reg));
ld->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset_load)
std::make_unique<ImmOperand>(current_offset)
));
restore_instrs.push_back(std::move(ld));
current_offset_load -= 8;
}
// 2. 使用 make_move_iterator 一次性将所有恢复指令插入到 RET 指令之前
// 这可以高效地转移指令的所有权,并且只让迭代器失效一次。
// 准备恢复浮点寄存器 (fld)
for (PhysicalReg reg : sorted_fp_regs) {
current_offset -= 8;
auto fld = std::make_unique<MachineInstr>(RVOpcodes::FLD); // 使用浮点加载指令
fld->addOperand(std::make_unique<RegOperand>(reg));
fld->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
restore_instrs.push_back(std::move(fld));
}
// 一次性插入所有恢复指令
if (!restore_instrs.empty()) {
mbb->getInstructions().insert(it,
std::make_move_iterator(restore_instrs.begin()),
std::make_move_iterator(restore_instrs.end())
);
mbb->getInstructions().insert(it,
std::make_move_iterator(restore_instrs.begin()),
std::make_move_iterator(restore_instrs.end()));
}
// 找到了RET并处理完毕后就可以跳出内层循环继续寻找下一个基本块
break;
// 处理完一个基本块的RET后迭代器已失效需跳出当前块的循环
goto next_block_label;
}
}
next_block_label:;
}
}

View File

@@ -52,6 +52,8 @@ void LegalizeImmediatesPass::runOnMachineFunction(MachineFunction* mfunc) {
case RVOpcodes::ADDI:
case RVOpcodes::ADDIW: {
auto& operands = instr_ptr->getOperands();
// 确保操作数足够多,以防万一
if (operands.size() < 3) break;
auto imm_op = static_cast<ImmOperand*>(operands.back().get());
if (!isLegalImmediate(imm_op->getValue())) {
@@ -73,7 +75,7 @@ void LegalizeImmediatesPass::runOnMachineFunction(MachineFunction* mfunc) {
add->addOperand(std::move(rd_op));
add->addOperand(std::move(rs1_op));
add->addOperand(std::make_unique<RegOperand>(TEMP_REG));
if (DEEPDEBUG) {
std::cerr << " New sequence:\n ";
temp_printer.printInstruction(li.get(), true);
@@ -92,7 +94,8 @@ void LegalizeImmediatesPass::runOnMachineFunction(MachineFunction* mfunc) {
// 处理所有内存加载/存储指令
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD: {
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
case RVOpcodes::FLW: case RVOpcodes::FSW: {
auto& operands = instr_ptr->getOperands();
auto mem_op = static_cast<MemOperand*>(operands.back().get());
auto offset_op = mem_op->getOffset();

View File

@@ -1,4 +1,6 @@
#include "PrologueEpilogueInsertion.h"
#include "RISCv64ISel.h"
#include "RISCv64RegAlloc.h" // 需要访问RegAlloc的结果
namespace sysy {
@@ -6,7 +8,13 @@ char PrologueEpilogueInsertionPass::ID = 0;
void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc) {
StackFrameInfo& frame_info = mfunc->getFrameInfo();
Function* F = mfunc->getFunc();
RISCv64ISel* isel = mfunc->getISel();
// [关键] 获取寄存器分配的结果 (vreg -> preg 的映射)
// RegAlloc Pass 必须已经运行过
auto& vreg_to_preg_map = frame_info.vreg_to_preg_map;
// 完全遵循 AsmPrinter 中的计算逻辑
int total_stack_size = frame_info.locals_size +
frame_info.spill_size +
@@ -24,7 +32,6 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
std::vector<std::unique_ptr<MachineInstr>> prologue_instrs;
// 严格按照 AsmPrinter 的打印顺序来创建和组织指令
// 1. addi sp, sp, -aligned_stack_size
auto alloc_stack = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
alloc_stack->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
@@ -57,10 +64,63 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
set_fp->addOperand(std::make_unique<ImmOperand>(aligned_stack_size));
prologue_instrs.push_back(std::move(set_fp));
// 确定插入点(在函数名标签之后)
// --- [正确逻辑] 在s0设置完毕后使用物理寄存器加载栈参数 ---
if (F && isel) {
// 定义暂存寄存器
const PhysicalReg INT_SCRATCH_REG = PhysicalReg::T5;
const PhysicalReg FP_SCRATCH_REG = PhysicalReg::F7;
int arg_idx = 0;
for (Argument* arg : F->getArguments()) {
if (arg_idx >= 8) {
unsigned vreg = isel->getVReg(arg);
// 确认RegAlloc已经为这个vreg计算了偏移量并且分配了物理寄存器
if (frame_info.alloca_offsets.count(vreg) && vreg_to_preg_map.count(vreg)) {
int offset = frame_info.alloca_offsets.at(vreg);
PhysicalReg dest_preg = vreg_to_preg_map.at(vreg);
Type* arg_type = arg->getType();
// 根据类型执行不同的加载序列
if (arg_type->isFloat()) {
// 1. flw ft7, offset(s0)
auto load_arg = std::make_unique<MachineInstr>(RVOpcodes::FLW);
load_arg->addOperand(std::make_unique<RegOperand>(FP_SCRATCH_REG));
load_arg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
prologue_instrs.push_back(std::move(load_arg));
// 2. fmv.s dest_preg, ft7
auto move_arg = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
move_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
move_arg->addOperand(std::make_unique<RegOperand>(FP_SCRATCH_REG));
prologue_instrs.push_back(std::move(move_arg));
} else {
// 确定是加载32位(lw)还是64位(ld)
RVOpcodes load_op = arg_type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
// 1. lw/ld t5, offset(s0)
auto load_arg = std::make_unique<MachineInstr>(load_op);
load_arg->addOperand(std::make_unique<RegOperand>(INT_SCRATCH_REG));
load_arg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
prologue_instrs.push_back(std::move(load_arg));
// 2. mv dest_preg, t5
auto move_arg = std::make_unique<MachineInstr>(RVOpcodes::MV);
move_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
move_arg->addOperand(std::make_unique<RegOperand>(INT_SCRATCH_REG));
prologue_instrs.push_back(std::move(move_arg));
}
}
}
arg_idx++;
}
}
// 确定插入点
auto insert_pos = entry_instrs.begin();
// [重要] 这里我们不再需要跳过LABEL因为AsmPrinter将不再打印函数名标签
// 第一个基本块的标签就是函数入口
// 一次性将所有序言指令插入
if (!prologue_instrs.empty()) {
@@ -69,14 +129,13 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
std::make_move_iterator(prologue_instrs.end()));
}
// --- 2. 插入尾声 ---
// --- 2. 插入尾声 (此部分逻辑保持不变) ---
for (auto& mbb : mfunc->getBlocks()) {
for (auto it = mbb->getInstructions().begin(); it != mbb->getInstructions().end(); ++it) {
if ((*it)->getOpcode() == RVOpcodes::RET) {
std::vector<std::unique_ptr<MachineInstr>> epilogue_instrs;
// 同样严格按照 AsmPrinter 的打印顺序
// 1. ld ra, (aligned_stack_size - 8)(sp)
// 1. ld ra
auto restore_ra = std::make_unique<MachineInstr>(RVOpcodes::LD);
restore_ra->addOperand(std::make_unique<RegOperand>(PhysicalReg::RA));
restore_ra->addOperand(std::make_unique<MemOperand>(
@@ -85,7 +144,7 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
));
epilogue_instrs.push_back(std::move(restore_ra));
// 2. ld s0, (aligned_stack_size - 16)(sp)
// 2. ld s0
auto restore_fp = std::make_unique<MachineInstr>(RVOpcodes::LD);
restore_fp->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
restore_fp->addOperand(std::make_unique<MemOperand>(
@@ -106,7 +165,6 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
std::make_move_iterator(epilogue_instrs.begin()),
std::make_move_iterator(epilogue_instrs.end()));
}
// 处理完一个基本块中的RET后迭代器已失效需跳出
goto next_block;
}
}

View File

@@ -1,8 +1,8 @@
#include "PostRA_Scheduler.h"
#include <set>
#include <map>
#include <vector>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#define MAX_SCHEDULING_BLOCK_SIZE 10000 // 限制调度块大小,避免过大导致性能问题
namespace sysy {
@@ -10,374 +10,407 @@ namespace sysy {
char PostRA_Scheduler::ID = 0;
// 检查指令是否是加载指令 (LW, LD)
bool isLoadInstr(MachineInstr* instr) {
RVOpcodes opcode = instr->getOpcode();
return opcode == RVOpcodes::LW || opcode == RVOpcodes::LD ||
opcode == RVOpcodes::LH || opcode == RVOpcodes::LB ||
opcode == RVOpcodes::LHU || opcode == RVOpcodes::LBU ||
opcode == RVOpcodes::LWU;
// Returns true when `instr` is a load, i.e. it reads memory into a register.
//
// Besides the integer loads (LB/LBU/LH/LHU/LW/LWU/LD) this also recognises the
// floating-point loads FLW and FLD, which the backend emits as well. Without
// them an FP load is not treated as a memory instruction, so getMemoryAccess()
// yields no access for it and its load flag is never set.
bool isLoadInstr(MachineInstr *instr) {
  switch (instr->getOpcode()) {
  // Integer loads.
  case RVOpcodes::LB:
  case RVOpcodes::LBU:
  case RVOpcodes::LH:
  case RVOpcodes::LHU:
  case RVOpcodes::LW:
  case RVOpcodes::LWU:
  case RVOpcodes::LD:
  // Floating-point loads.
  case RVOpcodes::FLW:
  case RVOpcodes::FLD:
    return true;
  default:
    return false;
  }
}
// 检查指令是否是存储指令 (SW, SD)
bool isStoreInstr(MachineInstr* instr) {
RVOpcodes opcode = instr->getOpcode();
return opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
opcode == RVOpcodes::SH || opcode == RVOpcodes::SB;
// Returns true when `instr` is a store, i.e. it writes a register to memory.
//
// Besides the integer stores (SB/SH/SW/SD) this also recognises the
// floating-point stores FSW and FSD, which the backend emits as well. The
// classification matters to the register analysis in this file:
// getDefinedRegisters() returns an empty set only for instructions accepted
// here, and otherwise assumes the first register operand is a definition,
// which is wrong for the data register of a store.
bool isStoreInstr(MachineInstr *instr) {
  switch (instr->getOpcode()) {
  // Integer stores.
  case RVOpcodes::SB:
  case RVOpcodes::SH:
  case RVOpcodes::SW:
  case RVOpcodes::SD:
  // Floating-point stores.
  case RVOpcodes::FSW:
  case RVOpcodes::FSD:
    return true;
  default:
    return false;
  }
}
// 检查指令是否为控制流指令
bool isControlFlowInstr(MachineInstr* instr) {
RVOpcodes opcode = instr->getOpcode();
return opcode == RVOpcodes::RET || opcode == RVOpcodes::J ||
opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU ||
opcode == RVOpcodes::CALL;
// Returns true for instructions that transfer control: RET, the unconditional
// jump J, the conditional branches BEQ/BNE/BLT/BGE/BLTU/BGEU, and CALL.
bool isControlFlowInstr(MachineInstr *instr) {
  switch (instr->getOpcode()) {
  case RVOpcodes::RET:
  case RVOpcodes::J:
  case RVOpcodes::BEQ:
  case RVOpcodes::BNE:
  case RVOpcodes::BLT:
  case RVOpcodes::BGE:
  case RVOpcodes::BLTU:
  case RVOpcodes::BGEU:
  case RVOpcodes::CALL:
    return true;
  default:
    return false;
  }
}
// 获取指令定义的寄存器 - 修复版本
std::set<PhysicalReg> getDefinedRegisters(MachineInstr* instr) {
std::set<PhysicalReg> defined_regs;
RVOpcodes opcode = instr->getOpcode();
// 特殊处理CALL指令
if (opcode == RVOpcodes::CALL) {
// CALL指令可能定义返回值寄存器
if (!instr->getOperands().empty() &&
instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(instr->getOperands().front().get());
if (!reg_op->isVirtual()) {
defined_regs.insert(reg_op->getPReg());
}
}
return defined_regs;
}
// 存储指令不定义寄存器
if (isStoreInstr(instr)) {
return defined_regs;
}
// 分支指令不定义寄存器
if (opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU ||
opcode == RVOpcodes::J || opcode == RVOpcodes::RET) {
return defined_regs;
}
// 对于其他指令,第一个寄存器操作数通常是定义的
if (!instr->getOperands().empty() &&
// 预计算指令信息的缓存
static std::unordered_map<MachineInstr *, InstrRegInfo> instr_info_cache;
// 获取指令定义的寄存器 - 优化版本
std::unordered_set<PhysicalReg> getDefinedRegisters(MachineInstr *instr) {
std::unordered_set<PhysicalReg> defined_regs;
RVOpcodes opcode = instr->getOpcode();
// 特殊处理CALL指令
if (opcode == RVOpcodes::CALL) {
// CALL指令可能定义返回值寄存器
if (!instr->getOperands().empty() &&
instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(instr->getOperands().front().get());
if (!reg_op->isVirtual()) {
defined_regs.insert(reg_op->getPReg());
}
auto reg_op =
static_cast<RegOperand *>(instr->getOperands().front().get());
if (!reg_op->isVirtual()) {
defined_regs.insert(reg_op->getPReg());
}
}
return defined_regs;
}
// 存储指令不定义寄存器
if (isStoreInstr(instr)) {
return defined_regs;
}
// 分支指令不定义寄存器
if (opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU ||
opcode == RVOpcodes::J || opcode == RVOpcodes::RET) {
return defined_regs;
}
// 对于其他指令,第一个寄存器操作数通常是定义的
if (!instr->getOperands().empty() &&
instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand *>(instr->getOperands().front().get());
if (!reg_op->isVirtual()) {
defined_regs.insert(reg_op->getPReg());
}
}
return defined_regs;
}
// 获取指令使用的寄存器 - 修复版本
std::set<PhysicalReg> getUsedRegisters(MachineInstr* instr) {
std::set<PhysicalReg> used_regs;
RVOpcodes opcode = instr->getOpcode();
// 特殊处理CALL指令
if (opcode == RVOpcodes::CALL) {
bool first_reg_skipped = false;
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
if (!first_reg_skipped) {
first_reg_skipped = true;
continue; // 跳过返回值寄存器
}
auto reg_op = static_cast<RegOperand*>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
}
// 获取指令使用的寄存器 - 优化版本
std::unordered_set<PhysicalReg> getUsedRegisters(MachineInstr *instr) {
std::unordered_set<PhysicalReg> used_regs;
RVOpcodes opcode = instr->getOpcode();
// 特殊处理CALL指令
if (opcode == RVOpcodes::CALL) {
bool first_reg_skipped = false;
for (const auto &op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
if (!first_reg_skipped) {
first_reg_skipped = true;
continue; // 跳过返回值寄存器
}
return used_regs;
}
// 对于存储指令,所有寄存器操作数都是使用的
if (isStoreInstr(instr)) {
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op.get());
if (!mem_op->getBase()->isVirtual()) {
used_regs.insert(mem_op->getBase()->getPReg());
}
}
}
return used_regs;
}
// 对于分支指令,所有寄存器操作数都是使用的
if (opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU) {
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
}
}
return used_regs;
}
// 对于其他指令,除了第一个寄存器操作数(通常是定义),其余都是使用的
bool first_reg = true;
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
if (first_reg) {
first_reg = false;
continue; // 跳过第一个寄存器(定义)
}
auto reg_op = static_cast<RegOperand*>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op.get());
if (!mem_op->getBase()->isVirtual()) {
used_regs.insert(mem_op->getBase()->getPReg());
}
auto reg_op = static_cast<RegOperand *>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
}
}
return used_regs;
}
// 对于存储指令,所有寄存器操作数都是使用的
if (isStoreInstr(instr)) {
for (const auto &op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand *>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand *>(op.get());
if (!mem_op->getBase()->isVirtual()) {
used_regs.insert(mem_op->getBase()->getPReg());
}
}
}
return used_regs;
}
// 对于分支指令,所有寄存器操作数都是使用的
if (opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU) {
for (const auto &op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand *>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
}
}
return used_regs;
}
// 对于其他指令,除了第一个寄存器操作数(通常是定义),其余都是使用的
bool first_reg = true;
for (const auto &op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
if (first_reg) {
first_reg = false;
continue; // 跳过第一个寄存器(定义)
}
auto reg_op = static_cast<RegOperand *>(op.get());
if (!reg_op->isVirtual()) {
used_regs.insert(reg_op->getPReg());
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand *>(op.get());
if (!mem_op->getBase()->isVirtual()) {
used_regs.insert(mem_op->getBase()->getPReg());
}
}
}
return used_regs;
}
// 获取内存访问的基址和偏移
struct MemoryAccess {
PhysicalReg base_reg;
int64_t offset;
bool valid;
MemoryAccess() : valid(false) {}
MemoryAccess(PhysicalReg base, int64_t off) : base_reg(base), offset(off), valid(true) {}
};
MemoryAccess getMemoryAccess(MachineInstr* instr) {
if (!isLoadInstr(instr) && !isStoreInstr(instr)) {
return MemoryAccess();
}
// 查找内存操作数
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op.get());
if (!mem_op->getBase()->isVirtual()) {
return MemoryAccess(mem_op->getBase()->getPReg(), mem_op->getOffset()->getValue());
}
}
}
// Extracts the (base register, offset) pair accessed by a load or store.
//
// Returns an invalid MemoryAccess when `instr` is neither a load nor a store,
// or when it has no memory operand whose base is a physical register.
MemoryAccess getMemoryAccess(MachineInstr *instr) {
  const bool touches_memory = isLoadInstr(instr) || isStoreInstr(instr);
  if (touches_memory) {
    // The first memory operand with a physical base register wins.
    for (const auto &operand : instr->getOperands()) {
      if (operand->getKind() != MachineOperand::KIND_MEM) {
        continue;
      }
      auto *mem = static_cast<MemOperand *>(operand.get());
      if (mem->getBase()->isVirtual()) {
        continue;
      }
      return MemoryAccess(mem->getBase()->getPReg(),
                          mem->getOffset()->getValue());
    }
  }
  return MemoryAccess();
}
// 检查内存依赖 - 加强版本
bool hasMemoryDependency(MachineInstr* instr1, MachineInstr* instr2) {
// 如果都不是内存指令,没有内存依赖
if (!isLoadInstr(instr1) && !isStoreInstr(instr1) &&
!isLoadInstr(instr2) && !isStoreInstr(instr2)) {
return false;
}
MemoryAccess mem1 = getMemoryAccess(instr1);
MemoryAccess mem2 = getMemoryAccess(instr2);
if (!mem1.valid || !mem2.valid) {
// 如果无法确定内存访问模式,保守地认为存在依赖
return true;
}
// 如果访问相同的内存位置
if (mem1.base_reg == mem2.base_reg && mem1.offset == mem2.offset) {
// Store->Load: RAW依赖
// Load->Store: WAR依赖
// Store->Store: WAW依赖
return isStoreInstr(instr1) || isStoreInstr(instr2);
}
// 不同内存位置通常没有依赖,但为了安全起见,
// 如果涉及store指令我们需要更保守
if (isStoreInstr(instr1) && isLoadInstr(instr2)) {
// 保守处理不同store和load之间可能有别名
return false; // 这里可以根据需要调整策略
}
// Returns the cached register/memory summary for `instr`, computing and
// storing it in instr_info_cache on first request.
InstrRegInfo &getInstrInfo(MachineInstr *instr) {
  const auto cached = instr_info_cache.find(instr);
  if (cached == instr_info_cache.end()) {
    // Cache miss: create the entry, then fill in every field.
    InstrRegInfo &fresh = instr_info_cache[instr];
    fresh.defined_regs = getDefinedRegisters(instr);
    fresh.used_regs = getUsedRegisters(instr);
    fresh.is_load = isLoadInstr(instr);
    fresh.is_store = isStoreInstr(instr);
    fresh.is_control_flow = isControlFlowInstr(instr);
    fresh.mem_access = getMemoryAccess(instr);
    return fresh;
  }
  return cached->second;
}
// Memory-dependency check between two instructions, based on their cached
// summaries.
//
// Result:
//   - false when neither instruction is a load or a store;
//   - true  when either instruction's mem_access is not valid (conservative);
//   - for two valid accesses to the same base register and offset: true iff
//     at least one of them is a store (RAW / WAR / WAW);
//   - false for two valid accesses that differ in base register or offset.
//     No alias analysis is done here; a stricter policy for a store followed
//     by a load could be added at this point if needed.
//
// NOTE(review): when only one of the two instructions is a load/store, the
// other one presumably carries an invalid mem_access (getMemoryAccess returns
// an invalid access for non-memory instructions), so the result is `true`.
// scheduleBlock's two heuristics only ever pair a memory instruction with a
// non-memory one, so this appears to veto every swap it tries. Confirm
// whether that is intended before relaxing the first check.
bool hasMemoryDependency(const InstrRegInfo &info1, const InstrRegInfo &info2) {
  const bool touches_mem1 = info1.is_load || info1.is_store;
  const bool touches_mem2 = info2.is_load || info2.is_store;
  if (!touches_mem1 && !touches_mem2) {
    return false;
  }

  const MemoryAccess &mem1 = info1.mem_access;
  const MemoryAccess &mem2 = info2.mem_access;
  if (!mem1.valid || !mem2.valid) {
    // Unknown access pattern: assume a dependency.
    return true;
  }

  if (mem1.base_reg == mem2.base_reg && mem1.offset == mem2.offset) {
    // Same location: two loads may be reordered, anything involving a store
    // may not.
    return info1.is_store || info2.is_store;
  }

  // Different base register or offset: treated as independent.
  return false;
}
// 检查两个指令之间是否存在依赖关系 - 修复版本
bool hasDependency(MachineInstr* instr1, MachineInstr* instr2) {
// 检查RAW依赖instr1定义的寄存器是否被instr2使用
auto defined_regs1 = getDefinedRegisters(instr1);
auto used_regs2 = getUsedRegisters(instr2);
for (const auto& reg : defined_regs1) {
if (used_regs2.find(reg) != used_regs2.end()) {
return true; // RAW依赖 - instr2读取instr1写入的值
}
// 检查两个指令之间是否存在依赖关系 - 优化版本
bool hasDependency(MachineInstr *instr1, MachineInstr *instr2) {
const InstrRegInfo &info1 = getInstrInfo(instr1);
const InstrRegInfo &info2 = getInstrInfo(instr2);
// 检查RAW依赖instr1定义的寄存器是否被instr2使用
for (const auto &reg : info1.defined_regs) {
if (info2.used_regs.find(reg) != info2.used_regs.end()) {
return true; // RAW依赖 - instr2读取instr1写入的值
}
// 检查WAR依赖instr1使用的寄存器是否被instr2定义
auto used_regs1 = getUsedRegisters(instr1);
auto defined_regs2 = getDefinedRegisters(instr2);
for (const auto& reg : used_regs1) {
if (defined_regs2.find(reg) != defined_regs2.end()) {
return true; // WAR依赖 - instr2覆盖instr1需要的值
}
}
// 检查WAR依赖instr1使用的寄存器是否被instr2定义
for (const auto &reg : info1.used_regs) {
if (info2.defined_regs.find(reg) != info2.defined_regs.end()) {
return true; // WAR依赖 - instr2覆盖instr1需要的值
}
// 检查WAW依赖两个指令定义相同寄存器
for (const auto& reg : defined_regs1) {
if (defined_regs2.find(reg) != defined_regs2.end()) {
return true; // WAW依赖 - 两条指令写入同一寄存器
}
}
// 检查WAW依赖两个指令定义相同寄存器
for (const auto &reg : info1.defined_regs) {
if (info2.defined_regs.find(reg) != info2.defined_regs.end()) {
return true; // WAW依赖 - 两条指令写入同一寄存器
}
// 检查内存依赖
if (hasMemoryDependency(instr1, instr2)) {
return true;
}
}
// 检查内存依赖
if (hasMemoryDependency(info1, info2)) {
return true;
}
return false;
}
// Decides whether two instructions may exchange positions (cached-info
// variant). Control-flow instructions are never moved; otherwise the swap is
// allowed only when hasDependency() reports nothing in either direction.
bool canSwapInstructions(MachineInstr *instr1, MachineInstr *instr2) {
  const InstrRegInfo &first = getInstrInfo(instr1);
  const InstrRegInfo &second = getInstrInfo(instr2);

  const bool involves_control_flow =
      first.is_control_flow || second.is_control_flow;
  if (involves_control_flow) {
    return false;
  }

  // Dependencies are checked in both directions.
  if (hasDependency(instr1, instr2)) {
    return false;
  }
  return !hasDependency(instr2, instr1);
}
// Decides whether two instructions may exchange positions. Control-flow
// instructions are never moved; otherwise the swap is allowed only when
// hasDependency() reports nothing in either direction.
bool canSwapInstructions(MachineInstr* instr1, MachineInstr* instr2) {
  const bool involves_control_flow =
      isControlFlowInstr(instr1) || isControlFlowInstr(instr2);
  if (involves_control_flow) {
    return false;
  }

  // Dependencies are checked in both directions.
  if (hasDependency(instr1, instr2)) {
    return false;
  }
  return !hasDependency(instr2, instr1);
}
// 新增:验证调度结果的正确性 - 优化版本
void validateSchedule(const std::vector<MachineInstr *> &instr_list) {
for (int i = 0; i < (int)instr_list.size(); i++) {
for (int j = i + 1; j < (int)instr_list.size(); j++) {
MachineInstr *earlier = instr_list[i];
MachineInstr *later = instr_list[j];
// 新增:验证调度结果的正确性
void validateSchedule(const std::vector<MachineInstr*>& instr_list) {
for (int i = 0; i < (int)instr_list.size(); i++) {
for (int j = i + 1; j < (int)instr_list.size(); j++) {
MachineInstr* earlier = instr_list[i];
MachineInstr* later = instr_list[j];
// 检查是否存在被违反的依赖关系
auto defined_regs = getDefinedRegisters(earlier);
auto used_regs = getUsedRegisters(later);
// 检查RAW依赖
for (const auto& reg : defined_regs) {
if (used_regs.find(reg) != used_regs.end()) {
// 这是正常的依赖关系earlier应该在later之前
continue;
}
}
// 检查内存依赖
if (hasMemoryDependency(earlier, later)) {
MemoryAccess mem1 = getMemoryAccess(earlier);
MemoryAccess mem2 = getMemoryAccess(later);
if (mem1.valid && mem2.valid &&
mem1.base_reg == mem2.base_reg && mem1.offset == mem2.offset) {
if (isStoreInstr(earlier) && isLoadInstr(later)) {
// Store->Load依赖顺序正确
continue;
}
}
}
const InstrRegInfo &info_earlier = getInstrInfo(earlier);
const InstrRegInfo &info_later = getInstrInfo(later);
// 检查是否存在被违反的依赖关系
// 检查RAW依赖
for (const auto &reg : info_earlier.defined_regs) {
if (info_later.used_regs.find(reg) != info_later.used_regs.end()) {
// 这是正常的依赖关系earlier应该在later之前
continue;
}
}
}
}
// 在基本块内对指令进行调度优化 - 完全重写版本
void scheduleBlock(MachineBasicBlock* mbb) {
auto& instructions = mbb->getInstructions();
if (instructions.size() <= 1) return;
if (instructions.size() > MAX_SCHEDULING_BLOCK_SIZE) {
return; // 跳过超大块,防止卡住
}
std::vector<MachineInstr*> instr_list;
for (auto& instr : instructions) {
instr_list.push_back(instr.get());
}
// 使用更严格的调度策略,避免破坏依赖关系
bool changed = true;
int max_iterations = 10; // 限制迭代次数避免死循环
int iteration = 0;
while (changed && iteration < max_iterations) {
changed = false;
iteration++;
for (int i = 0; i < (int)instr_list.size() - 1; i++) {
MachineInstr* instr1 = instr_list[i];
MachineInstr* instr2 = instr_list[i + 1];
// 只进行非常保守的优化
bool should_swap = false;
// 策略1: 将load指令提前减少load-use延迟
if (isLoadInstr(instr2) && !isLoadInstr(instr1) && !isStoreInstr(instr1)) {
should_swap = canSwapInstructions(instr1, instr2);
}
// 策略2: 将非关键store指令延后为其他指令让路
else if (isStoreInstr(instr1) && !isLoadInstr(instr2) && !isStoreInstr(instr2)) {
should_swap = canSwapInstructions(instr1, instr2);
}
if (should_swap) {
std::swap(instr_list[i], instr_list[i + 1]);
changed = true;
// 调试输出
// std::cout << "Swapped instructions at positions " << i << " and " << (i+1) << std::endl;
}
// 检查内存依赖
if (hasMemoryDependency(info_earlier, info_later)) {
const MemoryAccess &mem1 = info_earlier.mem_access;
const MemoryAccess &mem2 = info_later.mem_access;
if (mem1.valid && mem2.valid && mem1.base_reg == mem2.base_reg &&
mem1.offset == mem2.offset) {
if (info_earlier.is_store && info_later.is_load) {
// Store->Load依赖顺序正确
continue;
}
}
}
}
// 验证调度结果的正确性
validateSchedule(instr_list);
// 将调度后的指令顺序写回
std::map<MachineInstr*, std::unique_ptr<MachineInstr>> instr_map;
for (auto& instr : instructions) {
instr_map[instr.get()] = std::move(instr);
}
instructions.clear();
for (auto instr : instr_list) {
instructions.push_back(std::move(instr_map[instr]));
}
}
}
bool PostRA_Scheduler::runOnFunction(Function *F, AnalysisManager& AM) {
// 这个函数在IR级别运行但我们需要在机器指令级别运行
// 所以我们返回false表示没有对IR进行修改
return false;
// Reorders the instructions of one basic block with a bounded, bubble-style
// pass over adjacent pairs.
//
// Blocks with at most one instruction, or with more than
// MAX_SCHEDULING_BLOCK_SIZE instructions, are left untouched. Two heuristics
// propose a swap of neighbours (a, b):
//   1. b is a load and a is neither a load nor a store (move the load up);
//   2. a is a store and b is neither a load nor a store (move the store down).
// A proposed swap is performed only if canSwapInstructions(a, b) agrees.
void scheduleBlock(MachineBasicBlock *mbb) {
  auto &instructions = mbb->getInstructions();
  const auto block_size = instructions.size();
  if (block_size <= 1 || block_size > MAX_SCHEDULING_BLOCK_SIZE) {
    return;
  }

  // Start from an empty cache so entries keyed by pointers from a previously
  // processed block are not reused.
  instr_info_cache.clear();

  // Work on raw pointers; ownership stays in `instructions` until write-back.
  std::vector<MachineInstr *> order;
  order.reserve(block_size);
  for (auto &owned : instructions) {
    order.push_back(owned.get());
  }

  // Warm the cache for every instruction of the block.
  for (MachineInstr *instr : order) {
    getInstrInfo(instr);
  }

  // At most kMaxPasses sweeps; stop early once a sweep changes nothing.
  const int kMaxPasses = 10;
  const int last_pair = static_cast<int>(order.size()) - 1;
  for (int pass = 0; pass < kMaxPasses; ++pass) {
    bool swapped_any = false;

    for (int i = 0; i < last_pair; ++i) {
      MachineInstr *first = order[i];
      MachineInstr *second = order[i + 1];
      const InstrRegInfo &first_info = getInstrInfo(first);
      const InstrRegInfo &second_info = getInstrInfo(second);

      const bool first_is_mem = first_info.is_load || first_info.is_store;
      const bool second_is_mem = second_info.is_load || second_info.is_store;

      // Heuristic 1: hoist a load above a non-memory instruction.
      const bool hoist_load = second_info.is_load && !first_is_mem;
      // Heuristic 2: sink a store below a non-memory instruction.
      const bool sink_store = first_info.is_store && !second_is_mem;

      if ((hoist_load || sink_store) && canSwapInstructions(first, second)) {
        std::swap(order[i], order[i + 1]);
        swapped_any = true;
      }
    }

    if (!swapped_any) {
      break;
    }
  }

  // Sanity check of the resulting order.
  validateSchedule(order);

  // Write the new order back: park every owner in a map keyed by its raw
  // pointer, then re-append the owners in scheduled order.
  std::unordered_map<MachineInstr *, std::unique_ptr<MachineInstr>> owners;
  owners.reserve(instructions.size());
  for (auto &owned : instructions) {
    MachineInstr *key = owned.get();
    owners[key] = std::move(owned);
  }

  instructions.clear();
  instructions.reserve(order.size());
  for (MachineInstr *instr : order) {
    instructions.push_back(std::move(owners[instr]));
  }
}
// IR-level entry point of the pass. It performs no work here and always
// returns false; F and AM are not used.
bool PostRA_Scheduler::runOnFunction(Function *F, AnalysisManager &AM) {
// This hook runs at the IR level, but this pass needs to run at the
// machine-instruction level,
// so return false to indicate that the IR was not modified.
return false;
}
// Machine-level entry point of the pass: schedules every basic block of
// `mfunc` exactly once, then empties the shared instruction-info cache.
void PostRA_Scheduler::runOnMachineFunction(MachineFunction *mfunc) {
  // std::cout << "Running Post-RA Local Scheduler... " << std::endl;

  // Visit each machine basic block. The loop used to appear twice in a row,
  // which ran the scheduler over every block a second time.
  for (auto &mbb : mfunc->getBlocks()) {
    scheduleBlock(mbb.get());
  }

  // Clear the global cache so no instruction pointers are retained after the
  // pass has finished with this function.
  instr_info_cache.clear();
}
} // namespace sysy

View File

@@ -1,8 +1,8 @@
#include "PreRA_Scheduler.h"
#include "RISCv64LLIR.h"
#include <algorithm>
#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#define MAX_SCHEDULING_BLOCK_SIZE 1000 // 严格限制调度块大小
@@ -66,9 +66,44 @@ static bool hasMemoryAccess(MachineInstr *instr) {
return isLoadInstr(instr) || isStoreInstr(instr);
}
// 获取指令定义的虚拟寄存器
static std::set<unsigned> getDefinedVirtualRegisters(MachineInstr *instr) {
std::set<unsigned> defined_regs;
// Memory access location: a base register plus a byte offset, together with a
// flag telling whether the location could be determined at all.
struct MemoryLocation {
  unsigned base_reg = 0;
  int64_t offset = 0;
  bool is_valid = false;

  // An unknown (invalid) location.
  MemoryLocation() = default;

  // A known location at `off` relative to register `base`.
  MemoryLocation(unsigned base, int64_t off)
      : base_reg(base), offset(off), is_valid(true) {}

  // Two locations are equal only when both are valid and name the same
  // base register and offset; an invalid location equals nothing, itself
  // included.
  bool operator==(const MemoryLocation &other) const {
    if (!is_valid || !other.is_valid) {
      return false;
    }
    return base_reg == other.base_reg && offset == other.offset;
  }
};
// Cached analysis results for a single machine instruction.
struct InstrInfo {
  std::unordered_set<unsigned> defined_regs; // virtual registers the instruction defines
  std::unordered_set<unsigned> used_regs;    // virtual registers the instruction uses
  MemoryLocation mem_location;               // accessed memory location, if any
  bool is_load = false;
  bool is_store = false;
  bool is_terminator = false;
  bool is_call = false;
  bool has_side_effect = false;
  bool has_memory_access = false;

  // Every flag starts out false and both register sets start out empty.
  InstrInfo() = default;
};
// Instruction-info cache: maps a machine-instruction pointer to the InstrInfo
// computed for it. Entries are keyed by raw pointer, so the cache must be
// cleared before those pointers can become invalid.
static std::unordered_map<MachineInstr*, InstrInfo> instr_info_cache;
// 获取指令定义的虚拟寄存器 - 优化版本
static std::unordered_set<unsigned> getDefinedVirtualRegisters(MachineInstr *instr) {
std::unordered_set<unsigned> defined_regs;
RVOpcodes opcode = instr->getOpcode();
// CALL指令可能定义返回值寄存器
@@ -101,9 +136,9 @@ static std::set<unsigned> getDefinedVirtualRegisters(MachineInstr *instr) {
return defined_regs;
}
// 获取指令使用的虚拟寄存器
static std::set<unsigned> getUsedVirtualRegisters(MachineInstr *instr) {
std::set<unsigned> used_regs;
// 获取指令使用的虚拟寄存器 - 优化版本
static std::unordered_set<unsigned> getUsedVirtualRegisters(MachineInstr *instr) {
std::unordered_set<unsigned> used_regs;
RVOpcodes opcode = instr->getOpcode();
// CALL指令跳过第一个操作数返回值其余为参数
@@ -164,22 +199,6 @@ static std::set<unsigned> getUsedVirtualRegisters(MachineInstr *instr) {
return used_regs;
}
// Memory access location: a base register plus a byte offset, together with a
// flag telling whether the location could be determined at all.
struct MemoryLocation {
  unsigned base_reg = 0;
  int64_t offset = 0;
  bool is_valid = false;

  // An unknown (invalid) location.
  MemoryLocation() = default;

  // A known location at `off` relative to register `base`.
  MemoryLocation(unsigned base, int64_t off)
      : base_reg(base), offset(off), is_valid(true) {}

  // Two locations are equal only when both are valid and name the same
  // base register and offset; an invalid location equals nothing, itself
  // included.
  bool operator==(const MemoryLocation &other) const {
    if (!is_valid || !other.is_valid) {
      return false;
    }
    return base_reg == other.base_reg && offset == other.offset;
  }
};
// 获取内存访问位置
static MemoryLocation getMemoryLocation(MachineInstr *instr) {
if (!isLoadInstr(instr) && !isStoreInstr(instr)) {
@@ -199,6 +218,27 @@ static MemoryLocation getMemoryLocation(MachineInstr *instr) {
return MemoryLocation();
}
// Returns the cached analysis record for `instr`, computing and storing it the
// first time the instruction is seen. The returned reference points into
// instr_info_cache.
static const InstrInfo& getInstrInfo(MachineInstr *instr) {
  // emplace() leaves an existing entry untouched and reports whether a fresh
  // (default-initialized) one had to be created.
  auto slot = instr_info_cache.emplace(instr, InstrInfo());
  InstrInfo& entry = slot.first->second;
  if (!slot.second) {
    return entry; // already analysed
  }

  // First request for this instruction: fill in every field.
  entry.defined_regs = getDefinedVirtualRegisters(instr);
  entry.used_regs = getUsedVirtualRegisters(instr);
  entry.mem_location = getMemoryLocation(instr);
  entry.is_load = isLoadInstr(instr);
  entry.is_store = isStoreInstr(instr);
  entry.is_terminator = isTerminatorInstr(instr);
  entry.is_call = isCallInstr(instr);
  entry.has_side_effect = hasSideEffect(instr);
  entry.has_memory_access = hasMemoryAccess(instr);
  return entry;
}
// 检查两个内存位置是否可能别名
static bool mayAlias(const MemoryLocation &loc1, const MemoryLocation &loc2) {
if (!loc1.is_valid || !loc2.is_valid) {
@@ -214,30 +254,28 @@ static bool mayAlias(const MemoryLocation &loc1, const MemoryLocation &loc2) {
return loc1.offset == loc2.offset;
}
// 检查两个指令之间是否存在数据依赖
// 检查两个指令之间是否存在数据依赖 - 优化版本
static bool hasDataDependency(MachineInstr *first, MachineInstr *second) {
auto defined_regs_first = getDefinedVirtualRegisters(first);
auto used_regs_first = getUsedVirtualRegisters(first);
auto defined_regs_second = getDefinedVirtualRegisters(second);
auto used_regs_second = getUsedVirtualRegisters(second);
const InstrInfo& info_first = getInstrInfo(first);
const InstrInfo& info_second = getInstrInfo(second);
// RAW依赖: second读取first写入的寄存器
for (const auto &reg : defined_regs_first) {
if (used_regs_second.count(reg)) {
for (const auto &reg : info_first.defined_regs) {
if (info_second.used_regs.find(reg) != info_second.used_regs.end()) {
return true;
}
}
// WAR依赖: second写入first读取的寄存器
for (const auto &reg : used_regs_first) {
if (defined_regs_second.count(reg)) {
for (const auto &reg : info_first.used_regs) {
if (info_second.defined_regs.find(reg) != info_second.defined_regs.end()) {
return true;
}
}
// WAW依赖: 两个指令写入同一寄存器
for (const auto &reg : defined_regs_first) {
if (defined_regs_second.count(reg)) {
for (const auto &reg : info_first.defined_regs) {
if (info_second.defined_regs.find(reg) != info_second.defined_regs.end()) {
return true;
}
}
@@ -245,40 +283,41 @@ static bool hasDataDependency(MachineInstr *first, MachineInstr *second) {
return false;
}
// Checks whether a memory dependency exists between two instructions
// (optimized version that relies on the cached per-instruction info).
//
// Returns true only when both instructions access memory, at least one of them
// is a store, and mayAlias() reports that their locations may overlap.
static bool hasMemoryDependency(MachineInstr *first, MachineInstr *second) {
  const InstrInfo& info_first = getInstrInfo(first);
  const InstrInfo& info_second = getInstrInfo(second);

  // No dependency is possible unless both sides touch memory.
  if (!info_first.has_memory_access || !info_second.has_memory_access) {
    return false;
  }

  // With at least one store involved, the answer depends on aliasing.
  if (info_first.is_store || info_second.is_store) {
    return mayAlias(info_first.mem_location, info_second.mem_location);
  }

  return false; // two loads never depend on each other
}
// 检查两个指令之间是否存在控制依赖
// 检查两个指令之间是否存在控制依赖 - 优化版本
static bool hasControlDependency(MachineInstr *first, MachineInstr *second) {
const InstrInfo& info_first = getInstrInfo(first);
const InstrInfo& info_second = getInstrInfo(second);
// 终结指令与任何其他指令都有控制依赖
if (isTerminatorInstr(first)) {
if (info_first.is_terminator) {
return true; // first是终结指令second不能移动到first之前
}
if (isTerminatorInstr(second)) {
if (info_second.is_terminator) {
return false; // second是终结指令可以保持在后面
}
// CALL指令具有控制副作用但可以参与有限的调度
if (isCallInstr(first) || isCallInstr(second)) {
if (info_first.is_call || info_second.is_call) {
// CALL指令之间保持顺序
if (isCallInstr(first) && isCallInstr(second)) {
if (info_first.is_call && info_second.is_call) {
return true;
}
// 其他情况允许调度(通过数据依赖控制)
@@ -287,7 +326,7 @@ static bool hasControlDependency(MachineInstr *first, MachineInstr *second) {
return false;
}
// 综合检查两个指令是否可以交换
// 综合检查两个指令是否可以交换 - 优化版本
static bool canSwapInstructions(MachineInstr *first, MachineInstr *second) {
// 检查所有类型的依赖
if (hasDataDependency(first, second) || hasDataDependency(second, first)) {
@@ -306,15 +345,17 @@ static bool canSwapInstructions(MachineInstr *first, MachineInstr *second) {
return true;
}
// 找到基本块中的调度边界
// 找到基本块中的调度边界 - 优化版本
static std::vector<size_t>
findSchedulingBoundaries(const std::vector<MachineInstr *> &instrs) {
std::vector<size_t> boundaries;
boundaries.reserve(instrs.size() / 10); // 预估边界数量
boundaries.push_back(0); // 起始边界
for (size_t i = 0; i < instrs.size(); i++) {
const InstrInfo& info = getInstrInfo(instrs[i]);
// 终结指令前后都是边界
if (isTerminatorInstr(instrs[i])) {
if (info.is_terminator) {
if (i > 0)
boundaries.push_back(i);
if (i + 1 < instrs.size())
@@ -333,7 +374,7 @@ findSchedulingBoundaries(const std::vector<MachineInstr *> &instrs) {
return boundaries;
}
// 在单个调度区域内进行指令调度
// 在单个调度区域内进行指令调度 - 优化版本
static void scheduleRegion(std::vector<MachineInstr *> &instrs, size_t start,
size_t end) {
if (end - start <= 1) {
@@ -347,7 +388,8 @@ static void scheduleRegion(std::vector<MachineInstr *> &instrs, size_t start,
// 简单的调度算法:只尝试将加载指令尽可能前移
for (size_t i = start + 1; i < end; i++) {
if (isLoadInstr(instrs[i])) {
const InstrInfo& info = getInstrInfo(instrs[i]);
if (info.is_load) {
// 尝试将加载指令向前移动
for (size_t j = i; j > start; j--) {
// 检查是否可以与前一条指令交换
@@ -369,12 +411,21 @@ static void scheduleBlock(MachineBasicBlock *mbb) {
return;
}
// 清理缓存,避免无效指针
instr_info_cache.clear();
// 构建指令列表
std::vector<MachineInstr *> instr_list;
instr_list.reserve(instructions.size()); // 预分配容量
for (auto &instr : instructions) {
instr_list.push_back(instr.get());
}
// 预计算所有指令信息
for (auto* instr : instr_list) {
getInstrInfo(instr);
}
// 找到调度边界
std::vector<size_t> boundaries = findSchedulingBoundaries(instr_list);
@@ -386,12 +437,14 @@ static void scheduleBlock(MachineBasicBlock *mbb) {
}
// 重建指令序列
std::map<MachineInstr *, std::unique_ptr<MachineInstr>> instr_map;
std::unordered_map<MachineInstr *, std::unique_ptr<MachineInstr>> instr_map;
instr_map.reserve(instructions.size()); // 预分配容量
for (auto &instr : instructions) {
instr_map[instr.get()] = std::move(instr);
}
instructions.clear();
instructions.reserve(instr_list.size()); // 预分配容量
for (auto *instr : instr_list) {
instructions.push_back(std::move(instr_map[instr]));
}
@@ -405,6 +458,9 @@ void PreRA_Scheduler::runOnMachineFunction(MachineFunction *mfunc) {
for (auto &mbb : mfunc->getBlocks()) {
scheduleBlock(mbb.get());
}
// 清理全局缓存
instr_info_cache.clear();
}
} // namespace sysy

View File

@@ -7,9 +7,15 @@ namespace sysy {
// 检查是否为内存加载/存储指令,以处理特殊的打印格式
bool isMemoryOp(RVOpcodes opcode) {
switch (opcode) {
// --- 整数加载/存储 (原有逻辑) ---
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
case RVOpcodes::FLW:
case RVOpcodes::FSW:
// 如果未来支持双精度也在这里添加FLD/FSD
// case RVOpcodes::FLD:
// case RVOpcodes::FSD:
return true;
default:
return false;
@@ -73,7 +79,9 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break;
case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break;
case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break;
case RVOpcodes::SD: *OS << "sd "; break;
case RVOpcodes::SD: *OS << "sd "; break; case RVOpcodes::FLW: *OS << "flw "; break;
case RVOpcodes::FSW: *OS << "fsw "; break; case RVOpcodes::FLD: *OS << "fld "; break;
case RVOpcodes::FSD: *OS << "fsd "; break;
case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break;
case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break;
case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break;
@@ -82,7 +90,20 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break;
case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break;
case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break;
case RVOpcodes::SNEZ: *OS << "snez "; break;
case RVOpcodes::SNEZ: *OS << "snez "; break;
case RVOpcodes::FADD_S: *OS << "fadd.s "; break;
case RVOpcodes::FSUB_S: *OS << "fsub.s "; break;
case RVOpcodes::FMUL_S: *OS << "fmul.s "; break;
case RVOpcodes::FDIV_S: *OS << "fdiv.s "; break;
case RVOpcodes::FNEG_S: *OS << "fneg.s "; break;
case RVOpcodes::FEQ_S: *OS << "feq.s "; break;
case RVOpcodes::FLT_S: *OS << "flt.s "; break;
case RVOpcodes::FLE_S: *OS << "fle.s "; break;
case RVOpcodes::FCVT_S_W: *OS << "fcvt.s.w "; break;
case RVOpcodes::FCVT_W_S: *OS << "fcvt.w.s "; break;
case RVOpcodes::FMV_S: *OS << "fmv.s "; break;
case RVOpcodes::FMV_W_X: *OS << "fmv.w.x "; break;
case RVOpcodes::FMV_X_W: *OS << "fmv.x.w "; break;
case RVOpcodes::CALL: { // [核心修改] 为CALL指令添加特殊处理逻辑
*OS << "call ";
// 遍历所有操作数,只寻找并打印函数名标签
@@ -117,6 +138,12 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
// It should have been eliminated by RegAlloc
if (!debug) throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter");
*OS << "frame_addr "; break;
case RVOpcodes::FRAME_LOAD_F:
if (!debug) throw std::runtime_error("FRAME_LOAD_F not eliminated before AsmPrinter");
*OS << "frame_load_f "; break;
case RVOpcodes::FRAME_STORE_F:
if (!debug) throw std::runtime_error("FRAME_STORE_F not eliminated before AsmPrinter");
*OS << "frame_store_f "; break;
default:
throw std::runtime_error("Unknown opcode in AsmPrinter");
}

View File

@@ -16,7 +16,7 @@ std::string RISCv64CodeGen::code_gen() {
std::string RISCv64CodeGen::module_gen() {
std::stringstream ss;
// --- [新逻辑] 步骤1将全局变量分为.data和.bss两组 ---
// --- 步骤1将全局变量分为.data和.bss两组 ---
std::vector<GlobalValue*> data_globals;
std::vector<GlobalValue*> bss_globals;
@@ -26,7 +26,6 @@ std::string RISCv64CodeGen::module_gen() {
// 判断是否为大型零初始化数组,以便放入.bss段
bool is_large_zero_array = false;
// 规则初始化列表只有一项且该项是值为0的整数且数量大于一个阈值例如16
if (init_values.getValues().size() == 1) {
if (auto const_val = dynamic_cast<ConstantValue*>(init_values.getValues()[0])) {
if (const_val->isInt() && const_val->getInt() == 0 && init_values.getNumbers()[0] > 16) {
@@ -42,33 +41,53 @@ std::string RISCv64CodeGen::module_gen() {
}
}
// --- [新逻辑] 步骤2生成 .bss 段的代码 ---
// --- 步骤2生成 .bss 段的代码 ---
if (!bss_globals.empty()) {
ss << ".bss\n"; // 切换到 .bss 段
ss << ".bss\n";
for (GlobalValue* global : bss_globals) {
// 获取数组总大小(元素个数 * 元素大小)
// 在SysY中我们假设元素都是4字节int或float
unsigned count = global->getInitValues().getNumbers()[0];
unsigned total_size = count * 4;
unsigned total_size = count * 4; // 假设元素都是4字节
ss << " .align 3\n"; // 8字节对齐 (2^3)
ss << " .align 3\n";
ss << ".globl " << global->getName() << "\n";
ss << ".type " << global->getName() << ", @object\n";
ss << ".size " << global->getName() << ", " << total_size << "\n";
ss << global->getName() << ":\n";
// 使用 .space 指令来预留指定大小的零填充空间
ss << " .space " << total_size << "\n";
}
}
// --- [旧逻辑保留] 步骤3生成 .data 段的代码 ---
// --- 步骤3生成 .data 段的代码 ---
if (!data_globals.empty()) {
ss << ".data\n"; // 切换到 .data 段
for (GlobalValue* global : data_globals) {
ss << ".globl " << global->getName() << "\n";
ss << global->getName() << ":\n";
const auto& init_values = global->getInitValues();
// 使用您原有的逻辑来处理显式初始化的值
for (size_t i = 0; i < init_values.getValues().size(); ++i) {
auto val = init_values.getValues()[i];
auto count = init_values.getNumbers()[i];
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
for (unsigned j = 0; j < count; ++j) {
if (constant->isInt()) {
ss << " .word " << constant->getInt() << "\n";
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
ss << " .word " << float_bits << "\n";
}
}
}
}
}
// b. [新增] 再处理全局常量 (ConstantVariable)
for (const auto& const_ptr : module->getConsts()) {
ConstantVariable* cnst = const_ptr.get();
ss << ".globl " << cnst->getName() << "\n";
ss << cnst->getName() << ":\n";
const auto& init_values = cnst->getInitValues();
// 这部分逻辑和处理 GlobalValue 完全相同
for (size_t i = 0; i < init_values.getValues().size(); ++i) {
auto val = init_values.getValues()[i];
auto count = init_values.getNumbers()[i];
@@ -87,7 +106,7 @@ std::string RISCv64CodeGen::module_gen() {
}
}
// --- 处理函数 (.text段) 的逻辑保持不变 ---
// --- 处理函数 (.text段) ---
if (!module->getFunctions().empty()) {
ss << ".text\n";
for (const auto& func_pair : module->getFunctions()) {
@@ -99,7 +118,6 @@ std::string RISCv64CodeGen::module_gen() {
return ss.str();
}
// function_gen 现在是包含具体优化名称的、完整的处理流水线
std::string RISCv64CodeGen::function_gen(Function* func) {
// === 完整的后端处理流水线 ===

View File

@@ -10,7 +10,23 @@ namespace sysy {
// DAG节点定义 (内部实现)
struct RISCv64ISel::DAGNode {
enum NodeKind {ARGUMENT, CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET, GET_ELEMENT_PTR};
enum NodeKind {
ARGUMENT,
CONSTANT, // 整数或地址常量
LOAD,
STORE,
BINARY, // 整数二元运算
CALL,
RETURN,
BRANCH,
ALLOCA_ADDR,
UNARY, // 整数一元运算
MEMSET,
GET_ELEMENT_PTR,
FP_CONSTANT, // 浮点常量
FBINARY, // 浮点二元运算 (如 FADD, FSUB, FCMP)
FUNARY, // 浮点一元运算 (如 FCVT, FNEG)
};
NodeKind kind;
Value* value = nullptr;
std::vector<DAGNode*> operands;
@@ -29,11 +45,20 @@ unsigned RISCv64ISel::getVReg(Value* val) {
if (vreg_counter == 0) {
vreg_counter = 1; // vreg 0 保留
}
vreg_map[val] = vreg_counter++;
unsigned new_vreg = vreg_counter++;
vreg_map[val] = new_vreg;
vreg_to_value_map[new_vreg] = val;
vreg_type_map[new_vreg] = val->getType();
}
return vreg_map.at(val);
}
// Allocates a fresh virtual register that is not bound to any IR value and
// records `type` as its type.
//
// @param type  type to associate with the new virtual register
// @return      the number of the newly allocated virtual register
unsigned RISCv64ISel::getNewVReg(Type* type) {
    // vreg 0 is reserved. getVReg() skips it before allocating; do the same
    // here so a temporary requested before any getVReg() call cannot receive
    // the reserved number.
    if (vreg_counter == 0) {
        vreg_counter = 1;
    }
    const unsigned new_vreg = vreg_counter++;
    vreg_type_map[new_vreg] = type; // remember the type of this new vreg
    return new_vreg;
}
// 主入口函数
std::unique_ptr<MachineFunction> RISCv64ISel::runOnFunction(Function* func) {
F = func;
@@ -161,18 +186,52 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
break;
case DAGNode::FP_CONSTANT: {
// RISC-V没有直接加载浮点立即数的指令
// 标准做法是1. 将浮点数的32位二进制表示加载到一个整数寄存器
// 2. 使用 fmv.w.x 指令将位模式从整数寄存器移动到浮点寄存器
auto const_val = dynamic_cast<ConstantValue*>(node->value);
auto float_vreg = getVReg(const_val);
auto temp_int_vreg = getNewVReg(Type::getIntType()); // 临时整数虚拟寄存器
float f_val = const_val->getFloat();
// 使用 reinterpret_cast 获取浮点数的32位二进制表示
uint32_t float_bits = *reinterpret_cast<uint32_t*>(&f_val);
// 1. li temp_int_vreg, float_bits
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
li->addOperand(std::make_unique<ImmOperand>(float_bits));
CurMBB->addInstruction(std::move(li));
// 2. fmv.w.x float_vreg, temp_int_vreg
auto fmv = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
fmv->addOperand(std::make_unique<RegOperand>(float_vreg));
fmv->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
CurMBB->addInstruction(std::move(fmv));
break;
}
case DAGNode::LOAD: {
auto dest_vreg = getVReg(node->value);
Value* ptr_val = node->operands[0]->value;
// --- 修改点 ---
// 1. 获取加载结果的类型 (即这个LOAD指令自身的类型)
Type* loaded_type = node->value->getType();
// 2. 根据类型选择正确的伪指令或真实指令操作码
RVOpcodes frame_opcode = loaded_type->isPointer() ? RVOpcodes::FRAME_LOAD_D : RVOpcodes::FRAME_LOAD_W;
RVOpcodes real_opcode = loaded_type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
RVOpcodes frame_opcode;
RVOpcodes real_opcode;
if (loaded_type->isPointer()) {
frame_opcode = RVOpcodes::FRAME_LOAD_D;
real_opcode = RVOpcodes::LD;
} else if (loaded_type->isFloat()) {
frame_opcode = RVOpcodes::FRAME_LOAD_F;
real_opcode = RVOpcodes::FLW;
} else { // 默认为整数
frame_opcode = RVOpcodes::FRAME_LOAD_W;
real_opcode = RVOpcodes::LW;
}
if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
// 3. 创建使用新的、区分宽度的伪指令
@@ -183,7 +242,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
} else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
// 对于全局变量,先用 la 加载其地址
auto addr_vreg = getNewVReg();
auto addr_vreg = getNewVReg(Type::getPointerType(global->getType()));
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
@@ -220,20 +279,51 @@ void RISCv64ISel::selectNode(DAGNode* node) {
// 如果要存储的值是一个常量,就在这里生成 `li` 指令加载它
if (auto val_const = dynamic_cast<ConstantValue*>(val_to_store)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(val_const)));
li->addOperand(std::make_unique<ImmOperand>(val_const->getInt()));
CurMBB->addInstruction(std::move(li));
// 区分整数常量和浮点常量
if (val_const->isInt()) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(val_const)));
li->addOperand(std::make_unique<ImmOperand>(val_const->getInt()));
CurMBB->addInstruction(std::move(li));
} else if (val_const->isFloat()) {
// 先将浮点数的位模式加载到整数vreg再用fmv.w.x移到浮点vreg
auto temp_int_vreg = getNewVReg(Type::getIntType());
auto float_vreg = getVReg(val_const);
float f_val = val_const->getFloat();
uint32_t float_bits = *reinterpret_cast<uint32_t*>(&f_val);
// 1. li temp_int_vreg, float_bits
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
li->addOperand(std::make_unique<ImmOperand>(float_bits));
CurMBB->addInstruction(std::move(li));
// 2. fmv.w.x float_vreg, temp_int_vreg
auto fmv = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
fmv->addOperand(std::make_unique<RegOperand>(float_vreg));
fmv->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
CurMBB->addInstruction(std::move(fmv));
}
}
auto val_vreg = getVReg(val_to_store);
// --- 修改点 ---
// 1. 获取被存储的值的类型
Type* stored_type = val_to_store->getType();
// 2. 根据类型选择正确的伪指令或真实指令操作码
RVOpcodes frame_opcode = stored_type->isPointer() ? RVOpcodes::FRAME_STORE_D : RVOpcodes::FRAME_STORE_W;
RVOpcodes real_opcode = stored_type->isPointer() ? RVOpcodes::SD : RVOpcodes::SW;
RVOpcodes frame_opcode;
RVOpcodes real_opcode;
if (stored_type->isPointer()) {
frame_opcode = RVOpcodes::FRAME_STORE_D;
real_opcode = RVOpcodes::SD;
} else if (stored_type->isFloat()) {
frame_opcode = RVOpcodes::FRAME_STORE_F;
real_opcode = RVOpcodes::FSW;
} else { // 默认为整数
frame_opcode = RVOpcodes::FRAME_STORE_W;
real_opcode = RVOpcodes::SW;
}
if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
// 3. 创建使用新的、区分宽度的伪指令
@@ -244,7 +334,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
} else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
// 向全局变量存储
auto addr_vreg = getNewVReg();
auto addr_vreg = getNewVReg(Type::getIntType());
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
@@ -304,7 +394,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
}
// 2. [修改] 根据基地址的类型,生成不同的指令来获取基地址
auto base_addr_vreg = getNewVReg(); // 创建一个新的临时vreg来存放基地址
auto base_addr_vreg = getNewVReg(Type::getIntType()); // 创建一个新的临时vreg来存放基地址
// 情况一:基地址是局部栈变量
if (auto alloca_base = dynamic_cast<AllocaInst*>(base)) {
@@ -497,6 +587,109 @@ void RISCv64ISel::selectNode(DAGNode* node) {
break;
}
case DAGNode::FBINARY: {
auto bin = dynamic_cast<BinaryInst*>(node->value);
auto dest_vreg = getVReg(bin);
auto lhs_vreg = getVReg(bin->getLhs());
auto rhs_vreg = getVReg(bin->getRhs());
switch (bin->getKind()) {
case Instruction::kFAdd: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FADD_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFSub: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FSUB_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFMul: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FMUL_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFDiv: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FDIV_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
// --- 浮点比较指令 ---
// 注意:比较结果(0或1)写入的是一个通用整数寄存器(dest_vreg)
case Instruction::kFCmpEQ: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FEQ_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFCmpLT: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FLT_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFCmpLE: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FLE_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
// --- 通过交换操作数或组合指令实现其余比较 ---
case Instruction::kFCmpGT: { // a > b 等价于 b < a
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FLT_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg)); // 操作数交换
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFCmpGE: { // a >= b 等价于 b <= a
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FLE_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg)); // 操作数交换
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFCmpNE: { // a != b 等价于 !(a == b)
// 1. 先用 feq.s 比较,结果存入 dest_vreg
auto feq = std::make_unique<MachineInstr>(RVOpcodes::FEQ_S);
feq->addOperand(std::make_unique<RegOperand>(dest_vreg));
feq->addOperand(std::make_unique<RegOperand>(lhs_vreg));
feq->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(feq));
// 2. 再用 seqz 对结果取反 (如果相等(1)则变0如果不等(0)则变1)
auto seqz = std::make_unique<MachineInstr>(RVOpcodes::SEQZ);
seqz->addOperand(std::make_unique<RegOperand>(dest_vreg));
seqz->addOperand(std::make_unique<RegOperand>(dest_vreg));
CurMBB->addInstruction(std::move(seqz));
break;
}
default:
throw std::runtime_error("Unsupported float binary instruction in ISel");
}
break;
}
case DAGNode::UNARY: {
auto unary = dynamic_cast<UnaryInst*>(node->value);
auto dest_vreg = getVReg(unary);
@@ -524,109 +717,245 @@ void RISCv64ISel::selectNode(DAGNode* node) {
break;
}
case DAGNode::FUNARY: {
auto unary = dynamic_cast<UnaryInst*>(node->value);
auto dest_vreg = getVReg(unary);
auto src_vreg = getVReg(unary->getOperand());
switch (unary->getKind()) {
case Instruction::kItoF: { // 整数 to 浮点
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FCVT_S_W);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); // 目标是浮点vreg
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); // 源是整数vreg
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFtoI: { // 浮点 to 整数
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FCVT_W_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); // 目标是整数vreg
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); // 源是浮点vreg
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kFNeg: { // 浮点取负
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FNEG_S);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
// --- 处理位传送指令 ---
case Instruction::kBitItoF: { // 整数位模式 -> 浮点寄存器
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); // 目标是浮点vreg
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); // 源是整数vreg
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kBitFtoI: { // 浮点位模式 -> 整数寄存器
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FMV_X_W);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg)); // 目标是整数vreg
instr->addOperand(std::make_unique<RegOperand>(src_vreg)); // 源是浮点vreg
CurMBB->addInstruction(std::move(instr));
break;
}
default:
throw std::runtime_error("Unsupported float unary instruction in ISel");
}
break;
}
case DAGNode::CALL: {
auto call = dynamic_cast<CallInst*>(node->value);
// 处理函数参数放入a0-a7物理寄存器
size_t num_operands = node->operands.size();
size_t reg_arg_count = std::min(num_operands, (size_t)8);
for (size_t i = 0; i < reg_arg_count; ++i) {
DAGNode* arg_node = node->operands[i];
auto arg_preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + i);
if (arg_node->kind == DAGNode::CONSTANT) {
if (auto const_val = dynamic_cast<ConstantValue*>(arg_node->value)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(arg_preg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
// --- 步骤 1: 分配寄存器参数和栈参数 ---
// 根据RISC-V调用约定前8个整数/指针参数通过a0-a7传递
// 前8个浮点参数通过fa0-fa7传递 (物理寄存器 f10-f17)。其余参数通过栈传递。
int int_reg_idx = 0; // a0-a7 的索引
int fp_reg_idx = 0; // fa0-fa7 的索引
// 用于存储需要通过栈传递的参数
std::vector<DAGNode*> stack_args;
for (size_t i = 0; i < num_operands; ++i) {
DAGNode* arg_node = node->operands[i];
Value* arg_val = arg_node->value;
Type* arg_type = arg_val->getType();
// 判断参数是浮点类型还是整型/指针类型
if (arg_type->isFloat()) {
if (fp_reg_idx < 8) {
// --- 处理浮点寄存器参数 (fa0-fa7, 对应物理寄存器 F10-F17) ---
auto arg_preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + fp_reg_idx);
fp_reg_idx++;
if (auto const_val = dynamic_cast<ConstantValue*>(arg_val)) {
// 如果是浮点常量,需要先物化
// 1. 获取其32位二进制表示
float f_val = const_val->getFloat();
uint32_t float_bits = *reinterpret_cast<uint32_t*>(&f_val);
// 2. 将位模式加载到一个临时整数寄存器 (使用t0)
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(PhysicalReg::T0));
li->addOperand(std::make_unique<ImmOperand>(float_bits));
CurMBB->addInstruction(std::move(li));
// 3. 使用fmv.w.x将位模式从整数寄存器移动到目标浮点参数寄存器
auto fmv_wx = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
fmv_wx->addOperand(std::make_unique<RegOperand>(arg_preg));
fmv_wx->addOperand(std::make_unique<RegOperand>(PhysicalReg::T0));
CurMBB->addInstruction(std::move(fmv_wx));
} else {
// 如果已经是虚拟寄存器,直接用 fmv.s 移动
auto src_vreg = getVReg(arg_val);
auto fmv_s = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
fmv_s->addOperand(std::make_unique<RegOperand>(arg_preg));
fmv_s->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(fmv_s));
}
} else {
// 浮点寄存器已用完,放到栈上传递
stack_args.push_back(arg_node);
}
} else { // 整数或指针参数
if (int_reg_idx < 8) {
// --- 处理整数/指针寄存器参数 (a0-a7) ---
auto arg_preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_reg_idx);
int_reg_idx++;
if (arg_node->kind == DAGNode::CONSTANT) {
if (auto const_val = dynamic_cast<ConstantValue*>(arg_val)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(arg_preg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
}
} else {
auto src_vreg = getVReg(arg_val);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(arg_preg));
mv->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(mv));
}
} else {
// 整数寄存器已用完,放到栈上传递
stack_args.push_back(arg_node);
}
} else {
auto src_vreg = getVReg(arg_node->value);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(arg_preg));
mv->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(mv));
}
}
if (num_operands > 8) {
size_t stack_arg_count = num_operands - 8;
int stack_space = stack_arg_count * 8; // RV64中每个参数槽位8字节
// 2a. 在栈上分配空间
auto alloc_instr = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
alloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_instr->addOperand(std::make_unique<ImmOperand>(-stack_space));
CurMBB->addInstruction(std::move(alloc_instr));
// --- 步骤 2: 处理所有栈参数 ---
int stack_space = 0;
if (!stack_args.empty()) {
// 计算栈参数所需的总空间RV64中每个槽位为8字节
stack_space = stack_args.size() * 8;
// 根据ABI为call分配的栈空间需要16字节对齐
if (stack_space % 16 != 0) {
stack_space += 16 - (stack_space % 16);
}
// 在栈上分配空间
if (stack_space > 0) {
auto alloc_instr = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
alloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
alloc_instr->addOperand(std::make_unique<ImmOperand>(-stack_space));
CurMBB->addInstruction(std::move(alloc_instr));
}
// 将每个参数存储到栈上对应的位置
for (size_t i = 0; i < stack_args.size(); ++i) {
DAGNode* arg_node = stack_args[i];
Value* arg_val = arg_node->value;
Type* arg_type = arg_val->getType();
int offset = i * 8;
// 2b. 存储每个栈参数
for (size_t i = 8; i < num_operands; ++i) {
DAGNode* arg_node = node->operands[i];
unsigned src_vreg;
// 准备源寄存器
if (arg_node->kind == DAGNode::CONSTANT) {
// 如果是常量,先加载到临时寄存器
src_vreg = getNewVReg();
auto const_val = dynamic_cast<ConstantValue*>(arg_node->value);
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(src_vreg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
// 如果是常量先加载到临时vreg
if (auto const_val = dynamic_cast<ConstantValue*>(arg_val)) {
src_vreg = getNewVReg(arg_type);
if(arg_type->isFloat()) {
auto temp_int_vreg = getNewVReg(Type::getIntType());
float f_val = const_val->getFloat();
uint32_t float_bits = *reinterpret_cast<uint32_t*>(&f_val);
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
li->addOperand(std::make_unique<ImmOperand>(float_bits));
CurMBB->addInstruction(std::move(li));
auto fmv_wx = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
fmv_wx->addOperand(std::make_unique<RegOperand>(src_vreg));
fmv_wx->addOperand(std::make_unique<RegOperand>(temp_int_vreg));
CurMBB->addInstruction(std::move(fmv_wx));
} else {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(src_vreg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
}
} else {
src_vreg = getVReg(arg_node->value);
src_vreg = getVReg(arg_val);
}
// 计算在栈上的偏移量
int offset = (i - 8) * 8;
// 生成 sd 指令
auto sd_instr = std::make_unique<MachineInstr>(RVOpcodes::SD);
sd_instr->addOperand(std::make_unique<RegOperand>(src_vreg));
sd_instr->addOperand(std::make_unique<MemOperand>(
// 根据类型选择 fsw (浮点) 或 sd (整型/指针) 存储指令
std::unique_ptr<MachineInstr> store_instr;
if (arg_type->isFloat()) {
store_instr = std::make_unique<MachineInstr>(RVOpcodes::FSW);
} else {
store_instr = std::make_unique<MachineInstr>(RVOpcodes::SD);
}
store_instr->addOperand(std::make_unique<RegOperand>(src_vreg));
store_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::SP),
std::make_unique<ImmOperand>(offset)
));
CurMBB->addInstruction(std::move(sd_instr));
CurMBB->addInstruction(std::move(store_instr));
}
}
// --- 步骤 3: 生成CALL指令 ---
auto call_instr = std::make_unique<MachineInstr>(RVOpcodes::CALL);
// [协议] 如果函数有返回值,将它的目标虚拟寄存器作为第一个操作数
if (!call->getType()->isVoid()) {
unsigned dest_vreg = getVReg(call);
call_instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
}
// 将函数名标签作为后续操作数
call_instr->addOperand(std::make_unique<LabelOperand>(call->getCallee()->getName()));
// 将所有参数的虚拟寄存器也作为后续操作数供getInstrUseDef分析
for (size_t i = 0; i < num_operands; ++i) {
if (node->operands[i]->kind != DAGNode::CONSTANT) { // 常量参数已直接加载无需作为use
if (node->operands[i]->kind != DAGNode::CONSTANT && node->operands[i]->kind != DAGNode::FP_CONSTANT) {
call_instr->addOperand(std::make_unique<RegOperand>(getVReg(node->operands[i]->value)));
}
}
CurMBB->addInstruction(std::move(call_instr));
if (num_operands > 8) {
size_t stack_arg_count = num_operands - 8;
int stack_space = stack_arg_count * 8;
// --- 步骤 4: 处理返回值 ---
if (!call->getType()->isVoid()) {
unsigned dest_vreg = getVReg(call);
if (call->getType()->isFloat()) {
// 浮点返回值在 fa0 (物理寄存器 F10)
auto fmv_s = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
fmv_s->addOperand(std::make_unique<RegOperand>(dest_vreg));
fmv_s->addOperand(std::make_unique<RegOperand>(PhysicalReg::F10)); // fa0
CurMBB->addInstruction(std::move(fmv_s));
} else {
// 整数/指针返回值在 a0
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(dest_vreg));
mv->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
CurMBB->addInstruction(std::move(mv));
}
}
// --- 步骤 5: 回收为栈参数分配的空间 ---
if (stack_space > 0) {
auto dealloc_instr = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
dealloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
dealloc_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::SP));
dealloc_instr->addOperand(std::make_unique<ImmOperand>(stack_space));
CurMBB->addInstruction(std::move(dealloc_instr));
}
// 处理返回值从a0移动到目标虚拟寄存器
// if (!call->getType()->isVoid()) {
// auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
// mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(call)));
// mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
// CurMBB->addInstruction(std::move(mv_instr));
// }
break;
}
@@ -634,17 +963,47 @@ void RISCv64ISel::selectNode(DAGNode* node) {
auto ret_inst_ir = dynamic_cast<ReturnInst*>(node->value);
if (ret_inst_ir && ret_inst_ir->hasReturnValue()) {
Value* ret_val = ret_inst_ir->getReturnValue();
// [V2优点] 在RETURN节点内加载常量返回值
if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
auto li_instr = std::make_unique<MachineInstr>(RVOpcodes::LI);
li_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
li_instr->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li_instr));
Type* ret_type = ret_val->getType();
if (ret_type->isFloat()) {
// --- 处理浮点返回值 ---
// 返回值需要被放入 fa0 (物理寄存器 F10)
if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
// 如果是浮点常量需要先物化到fa0
float f_val = const_val->getFloat();
uint32_t float_bits = *reinterpret_cast<uint32_t*>(&f_val);
// 1. 加载位模式到临时整数寄存器 (t0)
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(PhysicalReg::T0));
li->addOperand(std::make_unique<ImmOperand>(float_bits));
CurMBB->addInstruction(std::move(li));
// 2. 将位模式从 t0 移动到 fa0
auto fmv_wx = std::make_unique<MachineInstr>(RVOpcodes::FMV_W_X);
fmv_wx->addOperand(std::make_unique<RegOperand>(PhysicalReg::F10)); // fa0
fmv_wx->addOperand(std::make_unique<RegOperand>(PhysicalReg::T0));
CurMBB->addInstruction(std::move(fmv_wx));
} else {
// 如果是vreg直接用 fmv.s 移动到 fa0
auto fmv_s = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
fmv_s->addOperand(std::make_unique<RegOperand>(PhysicalReg::F10)); // fa0
fmv_s->addOperand(std::make_unique<RegOperand>(getVReg(ret_val)));
CurMBB->addInstruction(std::move(fmv_s));
}
} else {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(ret_val)));
CurMBB->addInstruction(std::move(mv_instr));
// --- 处理整数/指针返回值 ---
// 返回值需要被放入 a0
// [V2优点] 在RETURN节点内加载常量返回值
if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
auto li_instr = std::make_unique<MachineInstr>(RVOpcodes::LI);
li_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
li_instr->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li_instr));
} else {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(ret_val)));
CurMBB->addInstruction(std::move(mv_instr));
}
}
}
// [V1设计保留] 函数尾声epilogue不由RETURN节点生成
@@ -862,6 +1221,11 @@ void RISCv64ISel::selectNode(DAGNode* node) {
la_instr->addOperand(std::make_unique<RegOperand>(current_addr_vreg));
la_instr->addOperand(std::make_unique<LabelOperand>(global_base->getName()));
CurMBB->addInstruction(std::move(la_instr));
} else if (auto const_global_base = dynamic_cast<ConstantVariable*>(base_ptr_node->value)) {
auto la_instr = std::make_unique<MachineInstr>(RVOpcodes::LA);
la_instr->addOperand(std::make_unique<RegOperand>(current_addr_vreg));
la_instr->addOperand(std::make_unique<LabelOperand>(const_global_base->getName()));
CurMBB->addInstruction(std::move(la_instr));
} else {
auto base_vreg = getVReg(base_ptr_node->value);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
@@ -870,7 +1234,7 @@ void RISCv64ISel::selectNode(DAGNode* node) {
CurMBB->addInstruction(std::move(mv));
}
// --- Step 2: [最终权威版] 遵循LLVM GEP语义迭代计算地址 ---
// --- Step 2: 遵循LLVM GEP语义迭代计算地址 ---
// 初始被索引的类型,是基指针指向的那个类型 (例如, [2 x i32])
Type* current_type = gep->getBasePointer()->getType()->as<PointerType>()->getBaseType();
@@ -979,15 +1343,17 @@ RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(
// 规则1如果这个Value已经有对应的节点直接返回
if (value_to_node.count(val_ir)) {
return value_to_node.at(val_ir);
return value_to_node[val_ir];
}
// 规则2识别各种类型的叶子节点并创建相应的DAG节点
if (dynamic_cast<ConstantValue*>(val_ir)) {
return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage);
if (auto const_val = dynamic_cast<ConstantValue*>(val_ir)) {
if (const_val->isInt()) {
return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage);
} else {
// 为浮点常量创建新的FP_CONSTANT节点
return create_node(DAGNode::FP_CONSTANT, val_ir, value_to_node, nodes_storage);
}
}
if (dynamic_cast<GlobalValue*>(val_ir)) {
// 全局变量/常量数组被视为一个常量地址
return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage);
}
if (dynamic_cast<AllocaInst*>(val_ir)) {
@@ -1059,6 +1425,17 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage));
} else if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
if(value_to_node.count(bin)) continue;
if (bin->getKind() == Instruction::kFSub) {
if (auto const_lhs = dynamic_cast<ConstantValue*>(bin->getLhs())) {
// 使用isZero()来判断浮点数0.0,比直接比较更健壮
if (const_lhs->isZero()) {
// 这是一个浮点取负操作,创建 FUNARY 节点
auto funary_node = create_node(DAGNode::FUNARY, bin, value_to_node, nodes_storage);
funary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
continue; // 处理完毕,跳到下一条指令
}
}
}
if (bin->getKind() == BinaryInst::kSub) {
if (auto const_lhs = dynamic_cast<ConstantValue*>(bin->getLhs())) {
if (const_lhs->getInt() == 0) {
@@ -1068,13 +1445,24 @@ std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicB
}
}
}
auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage);
bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
if (bin->getKind() >= Instruction::kFAdd) { // 假设浮点指令枚举值更大
auto fbin_node = create_node(DAGNode::FBINARY, bin, value_to_node, nodes_storage);
fbin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
fbin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
} else {
auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage);
bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
}
} else if (auto un = dynamic_cast<UnaryInst*>(inst)) {
if(value_to_node.count(un)) continue;
auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage));
if (un->getKind() >= Instruction::kFNeg) {
auto funary_node = create_node(DAGNode::FUNARY, un, value_to_node, nodes_storage);
funary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage));
} else {
auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage));
}
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
if(value_to_node.count(call)) continue;
auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage);

View File

@@ -10,9 +10,10 @@
namespace sysy {
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
// 1. 初始化可分配的整数寄存器池
allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, /*PhysicalReg::T5,*/PhysicalReg::T6,
PhysicalReg::T4, /*PhysicalReg::T5,*/ PhysicalReg::T6, // T5是大立即数传送寄存器
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
@@ -20,26 +21,39 @@ RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
};
// 创建一个包含所有通用整数寄存器的临时列表
const std::vector<PhysicalReg> all_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
// 2. 初始化可分配的浮点寄存器池
allocable_fp_regs = {
// 浮点临时寄存器 ft0-ft11
PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3,
PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7,
PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31,
// 浮点参数/返回值寄存器 fa0-fa7
PhysicalReg::F10, PhysicalReg::F11, PhysicalReg::F12, PhysicalReg::F13,
PhysicalReg::F14, PhysicalReg::F15, PhysicalReg::F16, PhysicalReg::F17,
// 浮点保存寄存器 fs0-fs11
PhysicalReg::F8, PhysicalReg::F9,
PhysicalReg::F18, PhysicalReg::F19, PhysicalReg::F20, PhysicalReg::F21,
PhysicalReg::F22, PhysicalReg::F23, PhysicalReg::F24, PhysicalReg::F25,
PhysicalReg::F26, PhysicalReg::F27
};
// 映射物理寄存器到特殊的虚拟寄存器ID用于干扰图中的物理寄存器节点
// 确保这些特殊ID不会与vreg_counter生成的常规虚拟寄存器ID冲突
for (PhysicalReg preg : all_int_regs) {
preg_to_vreg_id_map[preg] = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID) + static_cast<unsigned>(preg);
// 3. 映射所有物理寄存器包括整数、浮点和特殊寄存器到特殊的虚拟寄存器ID
// 这是为了让活跃性分析和干扰图构建能够统一处理所有类型的寄存器
for (int i = 0; i < static_cast<int>(PhysicalReg::PHYS_REG_START_ID); ++i) {
auto preg = static_cast<PhysicalReg>(i);
preg_to_vreg_id_map[preg] = static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID) + i;
}
}
// 寄存器分配的主入口点
void RISCv64RegAlloc::run() {
// --- 在所有流程开始前构建完整的vreg到Value的反向映射 ---
const auto& vreg_map_from_isel = MFunc->getISel()->getVRegMap();
for (const auto& pair : vreg_map_from_isel) {
Value* val = pair.first;
unsigned vreg = pair.second;
vreg_to_value_map[vreg] = val;
}
// 阶段 1: 处理函数调用约定(参数寄存器预着色)
handleCallingConvention();
// 阶段 2: 消除帧索引(为局部变量和栈参数分配栈偏移)
@@ -68,7 +82,10 @@ void RISCv64RegAlloc::run() {
// 阶段 5: 图着色算法分配物理寄存器
colorGraph();
// 阶段 6: 重写函数(插入溢出/填充代码,替换虚拟寄存器为物理寄存器)
rewriteFunction();
rewriteFunction();
// 将最终的寄存器分配结果保存到MachineFunction的帧信息中供后续Pass使用
MFunc->getFrameInfo().vreg_to_preg_map = this->color_map;
}
/**
@@ -82,35 +99,38 @@ void RISCv64RegAlloc::handleCallingConvention() {
RISCv64ISel* isel = MFunc->getISel();
// --- 部分1处理函数传入参数的预着色 ---
// 获取函数的Argument对象列表
if (F) {
auto& args = F->getArguments();
// RISC-V RV64G调用约定前8个整型/指针参数通过 a0-a7 传递
int arg_idx = 0;
// 遍历 Argument* 列表
for (Argument* arg : args) {
if (arg_idx >= 8) {
break;
}
// 获取该 Argument 对象对应的虚拟寄存器ID
// 通过 MachineFunction -> RISCv64ISel -> vreg_map 来获取
const auto& vreg_map_from_isel = MFunc->getISel()->getVRegMap();
assert(vreg_map_from_isel.count(arg) && "Argument not found in ISel's vreg_map!");
// 1. 获取该 Argument 对象对应的虚拟寄存器
unsigned vreg = isel->getVReg(arg);
// 2. 根据参数索引,确定对应的物理寄存器 (a0, a1, ...)
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + arg_idx);
// 3. 在 color_map 中,将 vreg "预着色" 为对应的物理寄存器
color_map[vreg] = preg;
// [修改] 为整数参数和浮点参数分别维护索引
int int_arg_idx = 0;
int float_arg_idx = 0;
arg_idx++;
for (Argument* arg : args) {
// [修改] 根据参数类型决定使用哪个寄存器池和索引
if (arg->getType()->isFloat()) {
// --- 处理浮点参数 ---
if (float_arg_idx >= 8) continue; // fa0-fa7
unsigned vreg = isel->getVReg(arg);
// Floating-point arguments use fa0-fa7 (physical registers F10-F17 in the RISC-V ABI)
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::F10) + float_arg_idx);
color_map[vreg] = preg;
float_arg_idx++;
} else {
// --- 处理整数/指针参数 (原有逻辑) ---
if (int_arg_idx >= 8) continue; // a0-a7
unsigned vreg = isel->getVReg(arg);
auto preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + int_arg_idx);
color_map[vreg] = preg;
int_arg_idx++;
}
}
}
// // --- 部分2[新逻辑] 遍历所有指令,为CALL指令的返回值预着色为 a0 ---
// // 这是为了强制寄存器分配器知道call的结果物理上出现在a0寄存器。
// --- 部分2为CALL指令的返回值预着色 ---
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
if (instr->getOpcode() == RVOpcodes::CALL) {
@@ -121,11 +141,17 @@ void RISCv64RegAlloc::handleCallingConvention() {
auto reg_op = static_cast<RegOperand*>(instr->getOperands().front().get());
if (reg_op->isVirtual()) {
unsigned ret_vreg = reg_op->getVRegNum();
// 强制将这个虚拟寄存器预着色为 a0
color_map[ret_vreg] = PhysicalReg::A0;
if (DEBUG) {
std::cout << "[DEBUG] Pre-coloring vreg" << ret_vreg
<< " to a0 for CALL instruction." << std::endl;
// [修改] 检查返回值的类型,预着色到 a0 或 fa0
assert(MFunc->getISel()->getVRegValueMap().count(ret_vreg) && "Return vreg not found in value map!");
Value* ret_val = MFunc->getISel()->getVRegValueMap().at(ret_vreg);
if (ret_val->getType()->isFloat()) {
// 浮点返回值预着色到 fa0 (F10)
color_map[ret_vreg] = PhysicalReg::F10;
} else {
// 整数/指针返回值预着色到 a0
color_map[ret_vreg] = PhysicalReg::A0;
}
}
}
@@ -218,6 +244,30 @@ void RISCv64RegAlloc::eliminateFrameIndices() {
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(load_instr));
} else if (opcode == RVOpcodes::FRAME_LOAD_F) {
// 展开浮点加载伪指令
RVOpcodes real_load_op = RVOpcodes::FLW; // 对应的真实指令是 flw
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: flw dest_vreg, 0(addr_vreg)
auto load_instr = std::make_unique<MachineInstr>(real_load_op);
load_instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
load_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(load_instr));
} else if (opcode == RVOpcodes::FRAME_STORE_W || opcode == RVOpcodes::FRAME_STORE_D) {
// 确定要生成的真实存储指令是 sw 还是 sd
RVOpcodes real_store_op = (opcode == RVOpcodes::FRAME_STORE_W) ? RVOpcodes::SW : RVOpcodes::SD;
@@ -243,6 +293,30 @@ void RISCv64RegAlloc::eliminateFrameIndices() {
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(store_instr));
} else if (opcode == RVOpcodes::FRAME_STORE_F) {
// 展开浮点存储伪指令
RVOpcodes real_store_op = RVOpcodes::FSW; // 对应的真实指令是 fsw
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
// 展开为: addi addr_vreg, s0, offset
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
// 展开为: fsw src_vreg, 0(addr_vreg)
auto store_instr = std::make_unique<MachineInstr>(real_store_op);
store_instr->addOperand(std::make_unique<RegOperand>(src_vreg));
store_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(store_instr));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_ADDR) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
@@ -277,7 +351,7 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
// 1. 特殊指令的 `is_def` 标志调整
// 这些指令的第一个寄存器操作数是源操作数 (use),而不是目标操作数 (def)。
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::FSW ||
opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::BLTU || opcode == RVOpcodes::BGEU ||
@@ -324,6 +398,27 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
}
return; // CALL 指令处理完毕
}
// 2.1 浮点比较指令添加特殊规则
// 它们的源操作数是浮点寄存器,但目标操作数是整数寄存器
if (opcode == RVOpcodes::FEQ_S || opcode == RVOpcodes::FLT_S || opcode == RVOpcodes::FLE_S) {
auto& operands = instr->getOperands();
// Def: 第一个操作数 (整数vreg)
if (operands[0]->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(operands[0].get());
if(reg_op->isVirtual()) def.insert(reg_op->getVRegNum());
}
// Use: 第二、三个操作数 (浮点vreg)
if (operands[1]->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(operands[1].get());
if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum());
}
if (operands[2]->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(operands[2].get());
if(reg_op->isVirtual()) use.insert(reg_op->getVRegNum());
}
return; // 处理完毕
}
// 3. 对其他所有指令的通用处理逻辑 [已重构和修复]
for (const auto& op : instr->getOperands()) {
@@ -351,7 +446,7 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
}
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
// [保持不变] 内存操作数的处理逻辑看起来是正确的
// 内存操作数的处理逻辑看起来是正确的
auto mem_op = static_cast<MemOperand*>(op.get());
auto base_reg = mem_op->getBase();
if (base_reg->isVirtual()) {
@@ -364,7 +459,7 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
}
// For memory store instructions (SW, SD, FSW), the value being stored (the first operand) is also a `use`
if ((opcode == RVOpcodes::SW || opcode == RVOpcodes::SD) &&
if ((opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::FSW) &&
!instr->getOperands().empty() &&
instr->getOperands().front()->getKind() == MachineOperand::KIND_REG) {
auto src_reg_op = static_cast<RegOperand*>(instr->getOperands().front().get());
@@ -605,28 +700,53 @@ void RISCv64RegAlloc::buildInterferenceGraph() {
}
// CALL 指令会定义(杀死)所有调用者保存的寄存器。
// 因此,所有调用者保存的物理寄存器都与 CALL 指令的 live_out 中的所有变量冲突。
const std::vector<PhysicalReg>& caller_saved_regs = getCallerSavedIntRegs();
for (PhysicalReg cs_reg : caller_saved_regs) {
if (preg_to_vreg_id_map.count(cs_reg)) {
unsigned cs_vreg_id = preg_to_vreg_id_map.at(cs_reg); // 获取物理寄存器对应的特殊vreg ID
// 辅助函数用于判断一个vreg是整数类型还是浮点类型
auto is_fp_vreg = [&](unsigned vreg) {
if (vreg_to_value_map.count(vreg)) {
return vreg_to_value_map.at(vreg)->getType()->isFloat();
}
// 对于ISel创建的、没有直接IR Value对应的临时vreg
// 默认其为整数类型。这是一个合理的兜底策略。
return false;
};
// --- 处理整数寄存器干扰 ---
const std::vector<PhysicalReg>& caller_saved_int_regs = getCallerSavedIntRegs();
for (PhysicalReg cs_reg : caller_saved_int_regs) {
// 确保物理寄存器在映射表中,我们已在构造函数中保证了这一点
unsigned cs_vreg_id = preg_to_vreg_id_map.at(cs_reg);
// 将这个物理寄存器节点与 CALL 指令的 live_out 中的所有虚拟寄存器添加干扰边。
for (unsigned live_vreg_out : live_out) {
if (cs_vreg_id != live_vreg_out) { // 避免自己和自己干扰
// [新增调试逻辑] 打印添加的干扰边及其原因
for (unsigned live_vreg_out : live_out) {
// 只为整数vreg添加与整数preg的干扰
if (!is_fp_vreg(live_vreg_out)) {
if (cs_vreg_id != live_vreg_out) {
if (DEEPDEBUG && interference_graph[cs_vreg_id].find(live_vreg_out) == interference_graph[cs_vreg_id].end()) {
std::cerr << " Edge (CALL) : preg(" << static_cast<int>(cs_reg) << ") <-> %vreg" << live_vreg_out << "\n";
std::cerr << " Edge (CALL, Int): preg(" << static_cast<int>(cs_reg) << ") <-> %vreg" << live_vreg_out << "\n";
}
interference_graph[cs_vreg_id].insert(live_vreg_out);
interference_graph[live_vreg_out].insert(cs_vreg_id);
}
}
} else {
// 如果物理寄存器没有对应的特殊虚拟寄存器ID可能是因为它不是调用者保存的寄存器。
// 这种情况通常不应该发生,但我们可以在这里添加一个警告或错误处理。
if (DEEPDEBUG) {
std::cerr << "Warning: Physical register " << static_cast<int>(cs_reg)
<< " does not have a corresponding special vreg ID.\n";
}
}
// --- 处理浮点寄存器干扰 ---
const std::vector<PhysicalReg>& caller_saved_fp_regs = getCallerSavedFpRegs();
for (PhysicalReg cs_reg : caller_saved_fp_regs) {
unsigned cs_vreg_id = preg_to_vreg_id_map.at(cs_reg);
for (unsigned live_vreg_out : live_out) {
// 只为浮点vreg添加与浮点preg的干扰
if (is_fp_vreg(live_vreg_out)) {
if (cs_vreg_id != live_vreg_out) {
// 添加与整数版本一致的调试代码
if (DEEPDEBUG && interference_graph[cs_vreg_id].find(live_vreg_out) == interference_graph[cs_vreg_id].end()) {
std::cerr << " Edge (CALL, FP): preg(" << static_cast<int>(cs_reg) << ") <-> %vreg" << live_vreg_out << "\n";
}
interference_graph[cs_vreg_id].insert(live_vreg_out);
interference_graph[live_vreg_out].insert(cs_vreg_id);
}
}
}
}
@@ -650,34 +770,70 @@ void RISCv64RegAlloc::colorGraph() {
return interference_graph[a].size() > interference_graph[b].size();
});
// [调试] 辅助函数用于判断一个vreg是整数还是浮点类型并打印详细诊断信息
auto is_fp_vreg = [&](unsigned vreg) {
if (DEEPDEBUG) {
std::cout << " [Debug is_fp_vreg] Checking vreg" << vreg << ": ";
}
if (vreg_to_value_map.count(vreg)) {
Value* val = vreg_to_value_map.at(vreg);
bool is_float = val->getType()->isFloat();
if (DEEPDEBUG) {
std::cout << "Found in map. Value is '" << val->getName()
<< "', Type is " << (is_float ? "FLOAT" : "INT")
<< ". Returning " << (is_float ? "true" : "false") << ".\n";
}
return is_float;
}
if (DEEPDEBUG) {
std::cout << "NOT found in vreg_to_value_map. Defaulting to INT. Returning false.\n";
}
// 对于ISel创建的、没有直接IR Value对应的临时vreg默认其为整数类型。
return false;
};
// 着色
for (unsigned vreg : sorted_vregs) {
std::set<PhysicalReg> used_colors;
for (unsigned neighbor_id : interference_graph.at(vreg)) {
// --- 关键改进 (来自 rec 分支) ---
// 情况 1: 邻居是一个已经被着色的虚拟寄存器
// 收集邻居颜色的逻辑保持不变
if (color_map.count(neighbor_id)) {
used_colors.insert(color_map.at(neighbor_id));
}
// 情况 2: 邻居本身就是一个代表物理寄存器的节点
else if (neighbor_id >= static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID)) {
// 从特殊ID反向解析出是哪个物理寄存器
PhysicalReg neighbor_preg = static_cast<PhysicalReg>(neighbor_id - static_cast<unsigned>(PhysicalReg::PHYS_REG_START_ID));
used_colors.insert(neighbor_preg);
}
}
bool is_float = is_fp_vreg(vreg);
const auto& allocable_regs = is_float ? allocable_fp_regs : allocable_int_regs;
// [调试] 打印着色决策过程
if (DEBUG) {
std::cout << "[DEBUG] Coloring %vreg" << vreg
<< ": Type is " << (is_float ? "FLOAT" : "INT")
<< ", choosing from " << (is_float ? "Float" : "Integer") << " pool.\n";
}
bool colored = false;
for (PhysicalReg preg : allocable_int_regs) {
for (PhysicalReg preg : allocable_regs) {
if (used_colors.find(preg) == used_colors.end()) {
color_map[vreg] = preg;
colored = true;
if (DEBUG) {
RISCv64AsmPrinter p(MFunc); // For regToString
std::cout << " -> Assigned to physical register: " << p.regToString(preg) << "\n";
}
break;
}
}
if (!colored) {
spilled_vregs.insert(vreg);
if (DEBUG) {
std::cout << " -> FAILED to color. Spilling.\n";
}
}
}
}
@@ -686,7 +842,7 @@ void RISCv64RegAlloc::rewriteFunction() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.locals_size;
// --- FIX 1: 动态计算溢出槽大小 ---
// --- 动态计算溢出槽大小 ---
// 根据溢出虚拟寄存器的真实类型,为其在栈上分配正确大小的空间。
for (unsigned vreg : spilled_vregs) {
// 从反向映射中查找 vreg 对应的 IR Value
@@ -704,23 +860,40 @@ void RISCv64RegAlloc::rewriteFunction() {
}
frame_info.spill_size = current_offset - frame_info.locals_size;
// 定义专用的溢出寄存器
const PhysicalReg INT_SPILL_REG = PhysicalReg::T6; // t6
const PhysicalReg FP_SPILL_REG = PhysicalReg::F7; // ft7
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr_ptr.get(), use, def);
// --- FIX 2: 为溢出的 'use' 操作数插入正确的加载指令 ---
// --- 为溢出的 'use' 操作数插入正确的加载指令 ---
for (unsigned vreg : use) {
if (spilled_vregs.count(vreg)) {
// 同样地,根据 vreg 的类型决定使用 lw 还是 ld
assert(vreg_to_value_map.count(vreg));
Value* val = vreg_to_value_map.at(vreg);
RVOpcodes load_op = val->getType()->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
// 根据vreg类型决定加载指令(lw/ld/flw)和目标物理寄存器(t6/ft7)
RVOpcodes load_op;
PhysicalReg target_preg;
if (val->getType()->isFloat()) {
load_op = RVOpcodes::FLW;
target_preg = FP_SPILL_REG;
} else if (val->getType()->isPointer()) {
load_op = RVOpcodes::LD;
target_preg = INT_SPILL_REG;
} else {
load_op = RVOpcodes::LW;
target_preg = INT_SPILL_REG;
}
int offset = frame_info.spill_offsets.at(vreg);
auto load = std::make_unique<MachineInstr>(load_op);
load->addOperand(std::make_unique<RegOperand>(vreg));
load->addOperand(std::make_unique<RegOperand>(target_preg)); // 加载到专用溢出寄存器
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
@@ -731,17 +904,29 @@ void RISCv64RegAlloc::rewriteFunction() {
new_instructions.push_back(std::move(instr_ptr));
// --- FIX 3: 为溢出的 'def' 操作数插入正确的存储指令 ---
// --- 为溢出的 'def' 操作数插入正确的存储指令 ---
for (unsigned vreg : def) {
if (spilled_vregs.count(vreg)) {
// 根据 vreg 的类型决定使用 sw 还是 sd
assert(vreg_to_value_map.count(vreg));
Value* val = vreg_to_value_map.at(vreg);
RVOpcodes store_op = val->getType()->isPointer() ? RVOpcodes::SD : RVOpcodes::SW;
// 根据vreg类型决定存储指令(sw/sd/fsw)和源物理寄存器(t6/ft7)
RVOpcodes store_op;
PhysicalReg src_preg;
if (val->getType()->isFloat()) {
store_op = RVOpcodes::FSW;
src_preg = FP_SPILL_REG;
} else if (val->getType()->isPointer()) {
store_op = RVOpcodes::SD;
src_preg = INT_SPILL_REG;
} else {
store_op = RVOpcodes::SW;
src_preg = INT_SPILL_REG;
}
int offset = frame_info.spill_offsets.at(vreg);
auto store = std::make_unique<MachineInstr>(store_op);
store->addOperand(std::make_unique<RegOperand>(vreg));
store->addOperand(std::make_unique<RegOperand>(src_preg)); // 从专用溢出寄存器存储
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
@@ -757,40 +942,29 @@ void RISCv64RegAlloc::rewriteFunction() {
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) {
// 情况一:操作数本身就是一个寄存器 (例如 add rd, rs1, rs2 中的所有操作数)
if(op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
// 定义一个处理寄存器操作数的 lambda 函数
auto process_reg_op = [&](RegOperand* reg_op) {
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (color_map.count(vreg)) {
PhysicalReg preg = color_map.at(vreg);
reg_op->setPReg(preg);
reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
// 如果vreg被溢出,替换为专用溢出物理寄存器t6
reg_op->setPReg(PhysicalReg::T6);
}
}
}
// 情况二:操作数是一个内存地址 (例如 lw rd, offset(rs1) 中的 offset(rs1))
else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op_ptr.get());
// 获取内存操作数内部的“基址寄存器”
auto base_reg_op = mem_op->getBase();
// 对这个基址寄存器,执行与情况一完全相同的替换逻辑
if(base_reg_op->isVirtual()){
unsigned vreg = base_reg_op->getVRegNum();
if(color_map.count(vreg)) {
// 如果基址vreg被成功着色替换
PhysicalReg preg = color_map.at(vreg);
base_reg_op->setPReg(preg);
} else if (spilled_vregs.count(vreg)) {
// 如果基址vreg被溢出替换为t6
base_reg_op->setPReg(PhysicalReg::T6);
// 根据vreg类型,替换为对应的专用溢出寄存器
assert(vreg_to_value_map.count(vreg));
Value* val = vreg_to_value_map.at(vreg);
if (val->getType()->isFloat()) {
reg_op->setPReg(FP_SPILL_REG);
} else {
reg_op->setPReg(INT_SPILL_REG);
}
}
}
};
if(op_ptr->getKind() == MachineOperand::KIND_REG) {
process_reg_op(static_cast<RegOperand*>(op_ptr.get()));
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
process_reg_op(static_cast<MemOperand*>(op_ptr.get())->getBase());
}
}
}

View File

@@ -12,6 +12,26 @@ namespace sysy {
* * 主要目标是优化寄存器分配器插入的spill/fill代码(lw/sw)
* 尝试将加载指令提前,以隐藏其访存延迟。
*/
/// A memory location expressed as a base register plus a byte offset.
/// `valid` is false for a default-constructed object and true once a
/// base/offset pair has been supplied through the two-argument constructor.
struct MemoryAccess {
    // Every member has a default initialiser so that a default-constructed
    // (invalid) object can be copied or compared without reading
    // indeterminate values.
    PhysicalReg base_reg{};
    int64_t offset = 0;
    bool valid = false;

    MemoryAccess() = default;
    MemoryAccess(PhysicalReg base, int64_t off) : base_reg(base), offset(off), valid(true) {}
};
/// Per-instruction summary record: the physical registers an instruction
/// defines and uses, three classification flags, and an optional memory access.
/// A freshly constructed record has empty register sets and all flags false.
struct InstrRegInfo {
    std::unordered_set<PhysicalReg> defined_regs;
    std::unordered_set<PhysicalReg> used_regs;
    bool is_load = false;
    bool is_store = false;
    bool is_control_flow = false;
    MemoryAccess mem_access;

    InstrRegInfo() {}
};
class PostRA_Scheduler : public Pass {
public:
static char ID;

View File

@@ -18,12 +18,12 @@ public:
void printInstruction(MachineInstr* instr, bool debug = false);
// 辅助函数
void setStream(std::ostream& os) { OS = &os; }
// 辅助函数
std::string regToString(PhysicalReg reg);
private:
// 打印各个部分
void printBasicBlock(MachineBasicBlock* mbb, bool debug = false);
// 辅助函数
std::string regToString(PhysicalReg reg);
void printOperand(MachineOperand* op);
MachineFunction* MFunc;

View File

@@ -17,8 +17,11 @@ public:
// Public interface so that later modules (e.g. RegAlloc) can query or create vregs.
unsigned getVReg(Value* val);
// Returns the current value of vreg_counter and then increments it.
unsigned getNewVReg() { return vreg_counter++; }
// Overload taking a Type*; only declared here (definition not visible in this header).
unsigned getNewVReg(Type* type);
// Read-only accessors for the internal mapping tables.
// Value* -> vreg number
const std::map<Value*, unsigned>& getVRegMap() const { return vreg_map; }
// vreg number -> Value*
const std::map<unsigned, Value*>& getVRegValueMap() const { return vreg_to_value_map; }
// vreg number -> Type*
const std::map<unsigned, Type*>& getVRegTypeMap() const { return vreg_type_map; }
private:
// DAG节点定义作为ISel的内部实现细节
@@ -38,6 +41,7 @@ private:
// 用于计算类型大小的辅助函数
unsigned getTypeSizeInBytes(Type* type);
// 打印DAG图以供调试
void print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, const std::string& bb_name);
// 状态
@@ -47,6 +51,8 @@ private:
// 映射关系
std::map<Value*, unsigned> vreg_map;
std::map<unsigned, Value*> vreg_to_value_map;
std::map<unsigned, Type*> vreg_type_map;
std::map<const BasicBlock*, MachineBasicBlock*> bb_map;
unsigned vreg_counter;

View File

@@ -32,7 +32,6 @@ enum class PhysicalReg {
A0, A1, A2, A3, A4, A5, A6, A7,
// --- 浮点寄存器 ---
// (保持您原有的 F0-F31 命名)
F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11,
F12, F13, F14, F15, F16, F17, F18, F19, F20, F21,
F22, F23, F24, F25, F26, F27, F28, F29, F30, F31,
@@ -64,16 +63,97 @@ enum class RVOpcodes {
CALL,
// 特殊标记,非指令
LABEL,
// 浮点指令 (RISC-V 'F' 扩展)
// 浮点加载与存储
FLW, // flw rd, offset(rs1)
FSW, // fsw rs2, offset(rs1)
FLD, // fld rd, offset(rs1)
FSD, // fsd rs2, offset(rs1)
// 浮点算术运算 (单精度)
FADD_S, // fadd.s rd, rs1, rs2
FSUB_S, // fsub.s rd, rs1, rs2
FMUL_S, // fmul.s rd, rs1, rs2
FDIV_S, // fdiv.s rd, rs1, rs2
// 浮点比较 (单精度)
FEQ_S, // feq.s rd, rs1, rs2 (结果写入整数寄存器rd)
FLT_S, // flt.s rd, rs1, rs2 (less than)
FLE_S, // fle.s rd, rs1, rs2 (less than or equal)
// 浮点转换
FCVT_S_W, // fcvt.s.w rd, rs1 (有符号整数 -> 单精度浮点)
FCVT_W_S, // fcvt.w.s rd, rs1 (单精度浮点 -> 有符号整数)
// 浮点传送/移动
FMV_S, // fmv.s rd, rs1 (浮点寄存器之间)
FMV_W_X, // fmv.w.x rd, rs1 (整数寄存器位模式 -> 浮点寄存器)
FMV_X_W, // fmv.x.w rd, rs1 (浮点寄存器位模式 -> 整数寄存器)
FNEG_S, // fneg.s rd, rs (浮点取负)
// 新增伪指令,用于解耦栈帧处理
FRAME_LOAD_W, // 从栈帧加载 32位 Word (对应 lw)
FRAME_LOAD_D, // 从栈帧加载 64位 Doubleword (对应 ld)
FRAME_STORE_W, // 保存 32位 Word 到栈帧 (对应 sw)
FRAME_STORE_D, // 保存 64位 Doubleword 到栈帧 (对应 sd)
FRAME_LOAD_F, // 从栈帧加载单精度浮点数
FRAME_STORE_F, // 将单精度浮点数存入栈帧
FRAME_ADDR, // 获取栈帧变量的地址
};
// 定义一个全局辅助函数或常量,提供调用者保存寄存器列表
const std::vector<PhysicalReg>& getCallerSavedIntRegs();
// Returns true when `reg` is an integer (general-purpose) register.
// PhysicalReg declares A0-A7 as the last integer registers, immediately
// before F0, so the integer range is [ZERO, F0). Using T6 as the upper bound
// would misclassify every integer register declared after T6 (at least A0-A7).
inline bool isGPR(PhysicalReg reg) {
    return reg >= PhysicalReg::ZERO && reg < PhysicalReg::F0;
}
// Returns true when `reg` lies in the F0..F31 range of PhysicalReg,
// i.e. when it is a floating-point register (FPR).
inline bool isFPR(PhysicalReg reg) {
    if (reg < PhysicalReg::F0) {
        return false;
    }
    return !(reg > PhysicalReg::F31);
}
// 获取所有调用者保存的整数寄存器 (t0-t6, a0-a7)
inline const std::vector<PhysicalReg>& getCallerSavedIntRegs() {
static const std::vector<PhysicalReg> regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7
};
return regs;
}
// 获取所有被调用者保存的整数寄存器 (s0-s11)
inline const std::vector<PhysicalReg>& getCalleeSavedIntRegs() {
static const std::vector<PhysicalReg> regs = {
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11
};
return regs;
}
// Returns the caller-saved floating-point registers.
// Per the RISC-V calling convention these are
//   ft0-ft7  = f0-f7,
//   fa0-fa7  = f10-f17,
//   ft8-ft11 = f28-f31.
// f8/f9 (fs0/fs1) are NOT caller-saved and must not appear here.
// NOTE(review): assumes PhysicalReg::Fn denotes architectural register fn.
inline const std::vector<PhysicalReg>& getCallerSavedFpRegs() {
    static const std::vector<PhysicalReg> regs = {
        // ft0-ft7
        PhysicalReg::F0, PhysicalReg::F1, PhysicalReg::F2, PhysicalReg::F3,
        PhysicalReg::F4, PhysicalReg::F5, PhysicalReg::F6, PhysicalReg::F7,
        // fa0-fa7
        PhysicalReg::F10, PhysicalReg::F11, PhysicalReg::F12, PhysicalReg::F13,
        PhysicalReg::F14, PhysicalReg::F15, PhysicalReg::F16, PhysicalReg::F17,
        // ft8-ft11
        PhysicalReg::F28, PhysicalReg::F29, PhysicalReg::F30, PhysicalReg::F31
    };
    return regs;
}
// Returns the callee-saved floating-point registers fs0-fs11.
// Per the RISC-V calling convention these are
//   fs0-fs1  = f8-f9,
//   fs2-fs11 = f18-f27.
// f28-f31 (ft8-ft11) are temporaries and must not appear here.
// NOTE(review): assumes PhysicalReg::Fn denotes architectural register fn.
inline const std::vector<PhysicalReg>& getCalleeSavedFpRegs() {
    static const std::vector<PhysicalReg> regs = {
        // fs0-fs1
        PhysicalReg::F8, PhysicalReg::F9,
        // fs2-fs11
        PhysicalReg::F18, PhysicalReg::F19, PhysicalReg::F20, PhysicalReg::F21,
        PhysicalReg::F22, PhysicalReg::F23, PhysicalReg::F24, PhysicalReg::F25,
        PhysicalReg::F26, PhysicalReg::F27
    };
    return regs;
}
class MachineOperand;
class RegOperand;
@@ -199,6 +279,7 @@ struct StackFrameInfo {
std::map<unsigned, int> alloca_offsets; // <AllocaInst的vreg, 栈偏移>
std::map<unsigned, int> spill_offsets; // <溢出vreg, 栈偏移>
std::set<PhysicalReg> used_callee_saved_regs; // 使用的保存寄存器
std::map<unsigned, PhysicalReg> vreg_to_preg_map;
};
// 机器函数
@@ -224,15 +305,6 @@ private:
StackFrameInfo frame_info;
};
inline const std::vector<PhysicalReg>& getCallerSavedIntRegs() {
static const std::vector<PhysicalReg> regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7
};
return regs;
}
} // namespace sysy
#endif // RISCV64_LLIR_H

View File

@@ -56,6 +56,7 @@ private:
// 可用的物理寄存器池
std::vector<PhysicalReg> allocable_int_regs;
std::vector<PhysicalReg> allocable_fp_regs;
// 存储vreg到IR Value*的反向映射
// 这个map将在run()函数开始时被填充并在rewriteFunction()中使用。

View File

@@ -126,7 +126,7 @@ class IRBuilder {
/// Creates a floating-point logical-NOT instruction (kFNot) on `operand`.
/// The int type is passed as the type argument of createUnaryInst.
UnaryInst * createFNotInst(Value *operand, const std::string &name = "") {
  Type *intType = Type::getIntType();
  return createUnaryInst(Instruction::kFNot, intType, operand, name);
}
UnaryInst * createIToFInst(Value *operand, const std::string &name = "") {
UnaryInst * createItoFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name);
} ///< 创建整型转浮点指令
UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") {

View File

@@ -59,6 +59,35 @@ private:
std::unique_ptr<Module> module;
IRBuilder builder;
// One element of an expression sequence: either an operand (Value*) or an
// operator code (int, see BinaryOp).
using ValueOrOperator = std::variant<Value*, int>;
std::vector<ValueOrOperator> BinaryExpStack; ///< Infix (source-order) form of the expressions being built
std::vector<int> BinaryExpLenStack; ///< Number of BinaryExpStack entries belonging to each open expression level
// The members below are working storage for postfix (RPN) evaluation.
std::vector<ValueOrOperator> BinaryRPNStack; ///< Postfix form of the expression being evaluated
std::vector<int> BinaryOpStack; ///< Operator stack used while converting infix to postfix
std::vector<Value *> BinaryValueStack; ///< Operand stack used while evaluating the postfix form
// Operator codes stored (as int) in the expression stacks.
//   ADD, SUB, MUL, DIV, MOD : binary +, -, *, /, %
//   PLUS, NEG               : unary + and unary -
//   NOT                     : logical not
//   LPAREN, RPAREN          : '(' and ')'
enum BinaryOp {
    ADD    = 1,
    SUB    = 2,
    MUL    = 3,
    DIV    = 4,
    MOD    = 5,
    PLUS   = 6,
    NEG    = 7,
    NOT    = 8,
    LPAREN = 9,
    RPAREN = 10,
};
// Binding strength of an operator code: unary operators (3) bind tighter than
// multiplicative ones (2), which bind tighter than additive ones (1).
// Parentheses report 0 so they are never popped by precedence comparison;
// any other code yields -1.
int getOperatorPrecedence(int op) {
    if (op == PLUS || op == NEG || op == NOT) {
        return 3;
    }
    if (op == MUL || op == DIV || op == MOD) {
        return 2;
    }
    if (op == ADD || op == SUB) {
        return 1;
    }
    if (op == LPAREN || op == RPAREN) {
        return 0;
    }
    return -1;
}
public:
SysYIRGenerator() = default;
@@ -97,7 +126,7 @@ public:
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
// std::any visitStmt(SysYParser::StmtContext *ctx) override;
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
// std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
// std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
@@ -131,8 +160,13 @@ public:
std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override;
// std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
bool isRightAssociative(int op);
Value* promoteType(Value* value, Type* targetType);
Value* computeExp(SysYParser::ExpContext *ctx, Type* targetType = nullptr);
Value* computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType = nullptr);
void compute();
public:
// 获取GEP指令的地址
Value* getGEPAddressInst(Value* basePointer, const std::vector<Value*>& indices);
@@ -141,6 +175,7 @@ public:
unsigned countArrayDimensions(Type* type);
}; // class SysYIRGenerator
} // namespace sysy

View File

@@ -16,6 +16,441 @@ using namespace std;
namespace sysy {
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
// // 约定操作符:
// // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT'
// enum BinaryOp {
// ADD = 1,
// SUB = 2,
// MUL = 3,
// DIV = 4,
// MOD = 5,
// PLUS = 6,
// NEG = 7,
// NOT = 8
// };
/**
 * Converts `value` to `targetType` where an int <-> float conversion applies.
 *
 * - A null `value` or a null `targetType` means "nothing to convert": `value`
 *   is returned unchanged.
 * - Integer/floating constants are converted at compile time into a new
 *   constant of the other kind; no instruction is emitted.
 * - Non-constant int -> float emits an ItoF instruction, float -> int emits an
 *   FtoI instruction.
 * - Any other combination (types already match, or neither side is int/float)
 *   returns `value` unchanged.
 */
Value *SysYIRGenerator::promoteType(Value *value, Type *targetType) {
  // Guard against a missing operand as well as a missing target type; the
  // code below dereferences both.
  if (value == nullptr || targetType == nullptr) {
    return value;
  }

  // Constants are folded instead of generating conversion instructions.
  if (auto *constInt = dynamic_cast<ConstantInteger *>(value)) {
    if (targetType->isFloat()) {
      return ConstantFloating::get(static_cast<float>(constInt->getInt()));
    }
    return constInt;
  }
  if (auto *constFloat = dynamic_cast<ConstantFloating *>(value)) {
    if (targetType->isInt()) {
      return ConstantInteger::get(static_cast<int>(constFloat->getFloat()));
    }
    return constFloat;
  }

  Type *sourceType = value->getType();
  if (sourceType->isInt() && targetType->isFloat()) {
    return builder.createItoFInst(value);
  }
  if (sourceType->isFloat() && targetType->isInt()) {
    return builder.createFtoIInst(value);
  }
  // Types already agree (or no int/float conversion applies).
  return value;
}
/// Returns true for the prefix unary operators (PLUS, NEG, NOT), which group
/// right-to-left; every other operator code yields false.
bool SysYIRGenerator::isRightAssociative(int op) {
  switch (op) {
    case BinaryOp::PLUS:
    case BinaryOp::NEG:
    case BinaryOp::NOT:
      return true;
    default:
      return false;
  }
}
void SysYIRGenerator::compute() {
// 先将中缀表达式转换为后缀表达式
BinaryRPNStack.clear();
BinaryOpStack.clear();
int begin = BinaryExpStack.size() - BinaryExpLenStack.back(), end = BinaryExpStack.size();
for (int i = begin; i < end; i++) {
auto item = BinaryExpStack[i];
if (std::holds_alternative<sysy::Value *>(item)) {
// 如果是操作数 (Value*),直接推入后缀表达式栈
BinaryRPNStack.push_back(item); // 直接 push_back item (ValueOrOperator类型)
} else {
// 如果是操作符
int currentOp = std::get<int>(item);
if (currentOp == LPAREN) {
// 左括号直接入栈
BinaryOpStack.push_back(currentOp);
} else if (currentOp == RPAREN) {
// 右括号:将操作符栈中的操作符弹出并添加到后缀表达式栈,直到遇到左括号
while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) {
BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int
BinaryOpStack.pop_back();
}
if (!BinaryOpStack.empty() && BinaryOpStack.back() == LPAREN) {
BinaryOpStack.pop_back(); // 弹出左括号,但不添加到后缀表达式栈
} else {
// 错误:不匹配的右括号
std::cerr << "Error: Mismatched parentheses in expression." << std::endl;
return;
}
} else {
// 普通操作符
while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) {
int stackTopOp = BinaryOpStack.back();
// 如果当前操作符优先级低于栈顶操作符优先级
// 或者 (当前操作符优先级等于栈顶操作符优先级 并且 栈顶操作符是左结合)
if (getOperatorPrecedence(currentOp) < getOperatorPrecedence(stackTopOp) ||
(getOperatorPrecedence(currentOp) == getOperatorPrecedence(stackTopOp) &&
!isRightAssociative(stackTopOp))) {
BinaryRPNStack.push_back(stackTopOp);
BinaryOpStack.pop_back();
} else {
break; // 否则当前操作符入栈
}
}
BinaryOpStack.push_back(currentOp); // 当前操作符入栈
}
}
}
// 遍历结束后,将操作符栈中剩余的所有操作符弹出并添加到后缀表达式栈
while (!BinaryOpStack.empty()) {
if (BinaryOpStack.back() == LPAREN) {
// 错误:不匹配的左括号
std::cerr << "Error: Mismatched parentheses in expression (unclosed parenthesis)." << std::endl;
return;
}
BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int
BinaryOpStack.pop_back();
}
// 弹出BinaryExpStack的表达式
while(begin < end) {
BinaryExpStack.pop_back();
BinaryExpLenStack.back()--;
end--;
}
// 计算后缀表达式
// 每次计算前清空操作数栈
BinaryValueStack.clear();
// 遍历后缀表达式栈
Type *commonType = nullptr;
for(const auto &item : BinaryRPNStack) {
if (std::holds_alternative<Value *>(item)) {
// 如果是操作数 (Value*) 检测他的类型
Value *value = std::get<Value *>(item);
if (commonType == nullptr) {
commonType = value->getType();
}
else if (value->getType() != commonType && value->getType()->isFloat()) {
// 如果当前值的类型与commonType不同且是float类型则提升为float
commonType = Type::getFloatType();
break;
}
} else {
continue;
}
}
for (const auto &item : BinaryRPNStack) {
if (std::holds_alternative<sysy::Value *>(item)) {
// 如果是操作数 (Value*),直接推入操作数栈
BinaryValueStack.push_back(std::get<sysy::Value *>(item));
} else {
// 如果是操作符
int op = std::get<int>(item);
Value *resultValue = nullptr;
Value *lhs = nullptr;
Value *rhs = nullptr;
Value *operand = nullptr;
switch (op) {
case BinaryOp::ADD:
case BinaryOp::SUB:
case BinaryOp::MUL:
case BinaryOp::DIV:
case BinaryOp::MOD: {
// 二元操作符需要两个操作数
if (BinaryValueStack.size() < 2) {
std::cerr << "Error: Not enough operands for binary operation: " << op << std::endl;
return; // 或者抛出异常
}
rhs = BinaryValueStack.back();
BinaryValueStack.pop_back();
lhs = BinaryValueStack.back();
BinaryValueStack.pop_back();
// 类型转换
lhs = promoteType(lhs, commonType);
rhs = promoteType(rhs, commonType);
// 尝试常量折叠
ConstantValue *lhsConst = dynamic_cast<ConstantValue *>(lhs);
ConstantValue *rhsConst = dynamic_cast<ConstantValue *>(rhs);
if (lhsConst && rhsConst) {
// 如果都是常量,直接计算结果
if (commonType == Type::getIntType()) {
int lhsVal = lhsConst->getInt();
int rhsVal = rhsConst->getInt();
switch (op) {
case BinaryOp::ADD: resultValue = ConstantInteger::get(lhsVal + rhsVal); break;
case BinaryOp::SUB: resultValue = ConstantInteger::get(lhsVal - rhsVal); break;
case BinaryOp::MUL: resultValue = ConstantInteger::get(lhsVal * rhsVal); break;
case BinaryOp::DIV:
if (rhsVal == 0) {
std::cerr << "Error: Division by zero." << std::endl;
return;
}
resultValue = sysy::ConstantInteger::get(lhsVal / rhsVal); break;
case BinaryOp::MOD:
if (rhsVal == 0) {
std::cerr << "Error: Modulo by zero." << std::endl;
return;
}
resultValue = sysy::ConstantInteger::get(lhsVal % rhsVal); break;
default:
std::cerr << "Error: Unknown binary operator for constants: " << op << std::endl;
return;
}
} else if (commonType == Type::getFloatType()) {
float lhsVal = lhsConst->getFloat();
float rhsVal = rhsConst->getFloat();
switch (op) {
case BinaryOp::ADD: resultValue = ConstantFloating::get(lhsVal + rhsVal); break;
case BinaryOp::SUB: resultValue = ConstantFloating::get(lhsVal - rhsVal); break;
case BinaryOp::MUL: resultValue = ConstantFloating::get(lhsVal * rhsVal); break;
case BinaryOp::DIV:
if (rhsVal == 0.0f) {
std::cerr << "Error: Division by zero." << std::endl;
return;
}
resultValue = sysy::ConstantFloating::get(lhsVal / rhsVal); break;
case BinaryOp::MOD:
std::cerr << "Error: Modulo operator not supported for float types." << std::endl;
return;
default:
std::cerr << "Error: Unknown binary operator for float constants: " << op << std::endl;
return;
}
} else {
std::cerr << "Error: Unsupported type for binary constant operation." << std::endl;
return;
}
} else {
// 否则创建相应的IR指令
if (commonType == Type::getIntType()) {
switch (op) {
case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break;
case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break;
}
} else if (commonType == Type::getFloatType()) {
switch (op) {
case BinaryOp::ADD: resultValue = builder.createFAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createFSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createFMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createFDivInst(lhs, rhs); break;
case BinaryOp::MOD:
std::cerr << "Error: Modulo operator not supported for float types." << std::endl;
return;
}
} else {
std::cerr << "Error: Unsupported type for binary instruction." << std::endl;
return;
}
}
break;
}
case BinaryOp::PLUS:
case BinaryOp::NEG:
case BinaryOp::NOT: {
// 一元操作符需要一个操作数
if (BinaryValueStack.empty()) {
std::cerr << "Error: Not enough operands for unary operation: " << op << std::endl;
return;
}
operand = BinaryValueStack.back();
BinaryValueStack.pop_back();
operand = promoteType(operand, commonType);
// 尝试常量折叠
ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(operand);
ConstantFloating *constFloat = dynamic_cast<ConstantFloating *>(operand);
if (constInt || constFloat) {
// 如果是常量,直接计算结果
switch (op) {
case BinaryOp::PLUS: resultValue = operand; break;
case BinaryOp::NEG: {
if (constInt) {
resultValue = constInt->getNeg();
} else if (constFloat) {
resultValue = constFloat->getNeg();
} else {
std::cerr << "Error: Negation not supported for constant operand type." << std::endl;
return;
}
break;
}
case BinaryOp::NOT:
if (constInt) {
resultValue = sysy::ConstantInteger::get(constInt->getInt() == 0 ? 1 : 0);
} else if (constFloat) {
resultValue = sysy::ConstantInteger::get(constFloat->getFloat() == 0.0f ? 1 : 0);
} else {
std::cerr << "Error: Logical NOT not supported for constant operand type." << std::endl;
return;
}
break;
default:
std::cerr << "Error: Unknown unary operator for constants: " << op << std::endl;
return;
}
} else {
// 否则创建相应的IR指令
switch (op) {
case BinaryOp::PLUS:
resultValue = operand; // 一元加指令通常直接返回操作数
break;
case BinaryOp::NEG: {
if (commonType == sysy::Type::getIntType()) {
resultValue = builder.createNegInst(operand);
} else if (commonType == sysy::Type::getFloatType()) {
resultValue = builder.createFNegInst(operand);
} else {
std::cerr << "Error: Negation not supported for operand type." << std::endl;
return;
}
break;
}
case BinaryOp::NOT:
// 逻辑非
if (commonType == sysy::Type::getIntType()) {
resultValue = builder.createNotInst(operand);
} else if (commonType == sysy::Type::getFloatType()) {
resultValue = builder.createFNotInst(operand);
} else {
std::cerr << "Error: Logical NOT not supported for operand type." << std::endl;
return;
}
break;
default:
std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl;
return;
}
}
break;
}
default:
std::cerr << "Error: Unknown operator " << op << " encountered in RPN stack." << std::endl;
return;
}
// 将计算结果或指令结果推入操作数栈
if (resultValue) {
BinaryValueStack.push_back(resultValue);
} else {
std::cerr << "Error: Result value is null after processing operator " << op << "!" << std::endl;
return;
}
}
}
// 后缀表达式处理完毕,操作数栈的栈顶就是最终结果
if (BinaryValueStack.empty()) {
std::cerr << "Error: No values left in BinaryValueStack after processing RPN." << std::endl;
return;
}
if (BinaryValueStack.size() > 1) {
std::cerr
<< "Warning: Multiple values left in BinaryValueStack after processing RPN. Expression might be malformed."
<< std::endl;
}
BinaryRPNStack.clear(); // 清空后缀表达式栈
BinaryOpStack.clear(); // 清空操作符栈
return;
}
/**
 * Evaluates an `exp` node as one self-contained expression level.
 *
 * Opens a new level on BinaryExpLenStack, lets visitAddExp() push the infix
 * items, runs compute(), and pops the result from BinaryValueStack.
 *
 * @param ctx         the expression node; must have an addExp child.
 * @param targetType  type the result is converted to via promoteType();
 *                    nullptr keeps the computed type.
 * @return the resulting value, or nullptr when the node is malformed, when
 *         compute() produced no value, or when unprocessed items remain in
 *         this level.
 */
Value* SysYIRGenerator::computeExp(SysYParser::ExpContext *ctx, Type* targetType){
  if (ctx == nullptr || ctx->addExp() == nullptr) {
    assert(false && "ExpContext should have an addExp child!");
    return nullptr;  // reached only when asserts are compiled out
  }

  BinaryExpLenStack.push_back(0); // open a new level
  visitAddExp(ctx->addExp());
  compute();

  // Close the level before inspecting the outcome so the length stack stays
  // balanced on every path.
  const int ExpLen = BinaryExpLenStack.back();
  BinaryExpLenStack.pop_back();

  // compute() may have failed and left no value; never read an empty stack.
  Value* result = nullptr;
  if (!BinaryValueStack.empty()) {
    result = BinaryValueStack.back();
    BinaryValueStack.pop_back();
  }

  if (ExpLen > 0) {
    std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl;
    return nullptr;
  }
  if (result == nullptr) {
    return nullptr;
  }
  return promoteType(result, targetType); // convert to the requested type
}
/**
 * Evaluates an `addExp` node as one self-contained expression level.
 *
 * Opens a new level on BinaryExpLenStack, pushes the ADD/SUB operators found
 * between the mulExp children while visiting each mulExp, runs compute(), and
 * pops the result from BinaryValueStack.
 *
 * @param ctx         the additive-expression node; must have at least one
 *                    mulExp child.
 * @param targetType  type the result is converted to via promoteType();
 *                    nullptr keeps the computed type.
 * @return the resulting value, or nullptr when the node is malformed, when
 *         compute() produced no value, or when unprocessed items remain in
 *         this level.
 */
Value* SysYIRGenerator::computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType){
  if (ctx == nullptr || ctx->mulExp().empty()) {
    assert(false && "AddExpContext should have a mulExp child!");
    return nullptr;  // reached only when asserts are compiled out
  }

  // Fetch the child list once instead of rebuilding it on every iteration.
  const auto mulExps = ctx->mulExp();

  BinaryExpLenStack.push_back(0); // open a new level
  visitMulExp(mulExps[0]);
  for (size_t i = 1; i < mulExps.size(); ++i) {
    // Children alternate operand, operator, operand, ...: the operator that
    // precedes operand i sits at index 2*i-1.
    auto *opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
    assert(opNode != nullptr && "Unexpected operator in AddExp.");
    const int opType = opNode->getSymbol()->getType();
    switch(opType) {
      case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break;
      case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break;
      default: assert(false && "Unexpected operator in AddExp.");
    }
    visitMulExp(mulExps[i]);
  }

  compute();

  // Close the level before inspecting the outcome so the length stack stays
  // balanced on every path.
  const int ExpLen = BinaryExpLenStack.back();
  BinaryExpLenStack.pop_back();

  // compute() may have failed and left no value; never read an empty stack.
  Value* result = nullptr;
  if (!BinaryValueStack.empty()) {
    result = BinaryValueStack.back();
    BinaryValueStack.pop_back();
  }

  if (ExpLen > 0) {
    std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl;
    return nullptr;
  }
  if (result == nullptr) {
    return nullptr;
  }
  return promoteType(result, targetType); // convert to the requested type
}
Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector<Value*>& dims){
Type* currentType = baseType;
// 从最内层维度开始构建 ArrayType
@@ -393,7 +828,8 @@ std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) {
}
std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) {
Value* value = std::any_cast<Value *>(visitExp(ctx->exp()));
// Value* value = std::any_cast<Value *>(visitExp(ctx->exp()));
Value* value = computeExp(ctx->exp());
ArrayValueTree* result = new ArrayValueTree();
result->setValue(value);
return result;
@@ -408,13 +844,17 @@ std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext
return result;
}
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
Value* value = std::any_cast<Value *>(visitConstExp(ctx->constExp()));
ArrayValueTree* result = new ArrayValueTree();
result->setValue(value);
return result;
}
/// Evaluates the additive expression under a constExp node through
/// computeAddExp() and returns the resulting Value* wrapped in std::any.
std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext *ctx){
  Value* evaluated = computeAddExp(ctx->addExp());
  return evaluated;
}
std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) {
std::vector<ArrayValueTree *> children;
for (const auto &constInitVal : ctx->constInitVal())
@@ -570,8 +1010,8 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
vector<Value *> indices;
if (lVal->exp().size() > 0) {
// 如果有下标,访问表达式获取下标值
for (const auto &exp : lVal->exp()) {
Value* indexValue = std::any_cast<Value *>(visitExp(exp));
for (auto &exp : lVal->exp()) {
Value* indexValue = std::any_cast<Value *>(computeExp(exp));
indices.push_back(indexValue);
}
}
@@ -610,15 +1050,18 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
LValue = getGEPAddressInst(gepBasePointer, gepIndices);
}
Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
// Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
// 先推断 LValue 的类型
// 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型
// 如果 LValue 是标量,则直接使用其类型
// 注意LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断
Type* LType = builder.getIndexedType(variable->getType(), indices);
Value* RValue = computeExp(ctx->exp(), LType); // 右值计算
Type* RType = RValue->getType();
// TODO:computeExp处理了类型转换可以考虑删除判断逻辑
if (LType != RType) {
ConstantValue *constValue = dynamic_cast<ConstantValue *>(RValue);
if (constValue != nullptr) {
@@ -642,7 +1085,7 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
}
} else {
if (LType == Type::getFloatType()) {
RValue = builder.createIToFInst(RValue);
RValue = builder.createItoFInst(RValue);
} else { // 假设如果不是浮点型,就是整型
RValue = builder.createFtoIInst(RValue);
}
@@ -655,6 +1098,14 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
}
/// Handles an expression statement: when the statement carries an expression
/// it is evaluated through computeExp() and the value is discarded.
/// Always returns an empty std::any.
std::any SysYIRGenerator::visitExpStmt(SysYParser::ExpStmtContext *ctx) {
  auto *expression = ctx->exp();
  if (expression == nullptr) {
    return std::any{};  // no expression attached to the statement
  }
  computeExp(expression);
  return std::any{};
}
std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
// labels string stream
@@ -822,11 +1273,11 @@ std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx
std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
Value* returnValue = nullptr;
if (ctx->exp() != nullptr) {
returnValue = std::any_cast<Value *>(visitExp(ctx->exp()));
}
Type* funcType = builder.getBasicBlock()->getParent()->getReturnType();
if (ctx->exp() != nullptr) {
returnValue = computeExp(ctx->exp(), funcType);
}
// TODOL 考虑删除类型转换判断逻辑
if (returnValue != nullptr && funcType!= returnValue->getType()) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(returnValue);
if (constValue != nullptr) {
@@ -849,7 +1300,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
}
} else {
if (funcType == Type::getFloatType()) {
returnValue = builder.createIToFInst(returnValue);
returnValue = builder.createItoFInst(returnValue);
} else {
returnValue = builder.createFtoIInst(returnValue);
}
@@ -891,7 +1342,8 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
std::vector<Value *> dims;
for (const auto &exp : ctx->exp()) {
dims.push_back(std::any_cast<Value *>(visitExp(exp)));
Value* expValue = std::any_cast<Value *>(computeExp(exp));
dims.push_back(expValue);
}
// 1. 获取变量的声明维度数量
@@ -995,16 +1447,23 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
}
std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
if (ctx->exp() != nullptr)
return visitExp(ctx->exp());
if (ctx->lValue() != nullptr)
return visitLValue(ctx->lValue());
if (ctx->number() != nullptr)
return visitNumber(ctx->number());
if (ctx->exp() != nullptr) {
BinaryExpStack.push_back(BinaryOp::LPAREN);BinaryExpLenStack.back()++;
visitExp(ctx->exp());
BinaryExpStack.push_back(BinaryOp::RPAREN);BinaryExpLenStack.back()++;
}
if (ctx->lValue() != nullptr) {
// 如果是 lValue将value压入栈中
BinaryExpStack.push_back(std::any_cast<Value *>(visitLValue(ctx->lValue())));BinaryExpLenStack.back()++;
}
if (ctx->number() != nullptr) {
BinaryExpStack.push_back(std::any_cast<Value *>(visitNumber(ctx->number())));BinaryExpLenStack.back()++;
}
if (ctx->string() != nullptr) {
cout << "String literal not supported in SysYIRGenerator." << endl;
}
return visitNumber(ctx->number());
return std::any();
}
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
@@ -1074,7 +1533,7 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) {
args[i] = builder.createFtoIInst(args[i]);
} else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) {
args[i] = builder.createIToFInst(args[i]);
args[i] = builder.createItoFInst(args[i]);
}
// 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO不清楚有没有这种样例
// 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型,
@@ -1099,235 +1558,78 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
}
/// Records a unary expression into the pending expression sequence instead
/// of evaluating it on the spot.
///
/// Three shapes are handled:
///  - primary expression: delegated to visitPrimaryExp (its return value is
///    ignored here; NOTE(review): it is presumably responsible for appending
///    its own operand to BinaryExpStack -- confirm against its definition);
///  - function call: the call is generated right away and the resulting
///    Value* is appended to BinaryExpStack;
///  - unary operator followed by an operand: the operator is appended first,
///    then the operand is visited recursively, so the operator precedes its
///    operand in the sequence.
///
/// Every element appended here also bumps BinaryExpLenStack.back(), so the
/// caller must have pushed a length counter before visiting.
///
/// @param ctx parse-tree node of the unary expression.
/// @return always an empty std::any; the outcome lives in BinaryExpStack.
std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
  if (ctx->primaryExp() != nullptr) {
    visitPrimaryExp(ctx->primaryExp());
  } else if (ctx->call() != nullptr) {
    BinaryExpStack.push_back(std::any_cast<Value *>(visitCall(ctx->call())));
    BinaryExpLenStack.back()++;
  } else if (ctx->unaryOp() != nullptr) {
    // On a unary operator, push it onto BinaryExpStack.
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->unaryOp()->children[0]);
    // The first child of unaryOp is expected to be the operator token itself.
    assert(opNode != nullptr && "UnaryOp child is not a terminal node.");
    int opType = opNode->getSymbol()->getType();
    switch (opType) {
      case SysYParser::ADD:
        BinaryExpStack.push_back(BinaryOp::PLUS);
        BinaryExpLenStack.back()++;
        break;
      case SysYParser::SUB:
        BinaryExpStack.push_back(BinaryOp::NEG);
        BinaryExpLenStack.back()++;
        break;
      case SysYParser::NOT:
        BinaryExpStack.push_back(BinaryOp::NOT);
        BinaryExpLenStack.back()++;
        break;
      default:
        assert(false && "Unexpected operator in UnaryExp.");
    }
    // The operand follows its operator in the sequence.
    visitUnaryExp(ctx->unaryExp());
  }
  return std::any();
}
/// Evaluates the actual arguments of a call, in source order.
///
/// Each argument expression is evaluated exactly once through computeExp and
/// the resulting Value* is collected.
///
/// @param ctx parse-tree node holding the argument expressions.
/// @return a std::any wrapping a std::vector<Value *> with one entry per
///         argument.
std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
  const auto argExps = ctx->exp();
  std::vector<Value *> params;
  params.reserve(argExps.size());
  for (const auto &exp : argExps) {
    params.push_back(std::any_cast<Value *>(computeExp(exp)));
  }
  return params;
}
/// Records a multiplicative expression (operands joined by `*`, `/`, `%`)
/// into the pending expression sequence instead of evaluating it here.
///
/// The first operand is visited, then for every following operand the
/// operator token is appended to BinaryExpStack (bumping
/// BinaryExpLenStack.back()) and the operand is visited. Elements therefore
/// land in source order: operand, operator, operand, ...
/// No type promotion or constant folding happens in this function.
///
/// @param ctx parse-tree node of the multiplicative expression.
/// @return always an empty std::any; the outcome lives in BinaryExpStack.
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
  visitUnaryExp(ctx->unaryExp(0));

  // Children alternate operand / operator, so operator i sits at 2*i - 1.
  const std::size_t operandCount = ctx->unaryExp().size();
  for (std::size_t i = 1; i < operandCount; i++) {
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
    assert(opNode != nullptr && "MulExp operator child is not a terminal node.");
    int opType = opNode->getSymbol()->getType();
    switch (opType) {
      case SysYParser::MUL:
        BinaryExpStack.push_back(BinaryOp::MUL);
        BinaryExpLenStack.back()++;
        break;
      case SysYParser::DIV:
        BinaryExpStack.push_back(BinaryOp::DIV);
        BinaryExpLenStack.back()++;
        break;
      case SysYParser::MOD:
        BinaryExpStack.push_back(BinaryOp::MOD);
        BinaryExpLenStack.back()++;
        break;
      default:
        assert(false && "Unexpected operator in MulExp.");
    }
    visitUnaryExp(ctx->unaryExp(i));
  }
  return std::any();
}
/// Records an additive expression (operands joined by `+` or `-`) into the
/// pending expression sequence instead of evaluating it here.
///
/// The first multiplicative operand is visited, then for every following
/// operand the operator token is appended to BinaryExpStack (bumping
/// BinaryExpLenStack.back()) and the operand is visited. Elements therefore
/// land in source order: operand, operator, operand, ...
/// No type promotion or constant folding happens in this function.
///
/// @param ctx parse-tree node of the additive expression.
/// @return always an empty std::any; the outcome lives in BinaryExpStack.
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
  visitMulExp(ctx->mulExp(0));

  // Children alternate operand / operator, so operator i sits at 2*i - 1.
  const std::size_t operandCount = ctx->mulExp().size();
  for (std::size_t i = 1; i < operandCount; i++) {
    auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
    assert(opNode != nullptr && "AddExp operator child is not a terminal node.");
    int opType = opNode->getSymbol()->getType();
    switch (opType) {
      case SysYParser::ADD:
        BinaryExpStack.push_back(BinaryOp::ADD);
        BinaryExpLenStack.back()++;
        break;
      case SysYParser::SUB:
        BinaryExpStack.push_back(BinaryOp::SUB);
        BinaryExpLenStack.back()++;
        break;
      default:
        assert(false && "Unexpected operator in AddExp.");
    }
    visitMulExp(ctx->mulExp(i));
  }
  return std::any();
}
std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
Value* result = std::any_cast<Value *>(visitAddExp(ctx->addExp(0)));
Value* result = computeAddExp(ctx->addExp(0));
for (int i = 1; i < ctx->addExp().size(); i++) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
int opType = opNode->getSymbol()->getType();
Value* operand = std::any_cast<Value *>(visitAddExp(ctx->addExp(i)));
Value* operand = computeAddExp(ctx->addExp(i));
Type* resultType = result->getType();
Type* operandType = operand->getType();
@@ -1366,7 +1668,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
}
}
else
result = builder.createIToFInst(result);
result = builder.createItoFInst(result);
}
if (operandType != floatType) {
@@ -1380,7 +1682,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
}
}
else
operand = builder.createIToFInst(operand);
operand = builder.createItoFInst(operand);
}
@@ -1407,6 +1709,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
// TODO其实已经保证了result是一个int类型的值可以删除冗余判断逻辑
Value * result = std::any_cast<Value *>(visitRelExp(ctx->relExp(0)));
for (int i = 1; i < ctx->relExp().size(); i++) {
@@ -1445,7 +1748,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
}
}
else
result = builder.createIToFInst(result);
result = builder.createItoFInst(result);
}
if (operandType != floatType) {
if (constOperand != nullptr) {
@@ -1458,7 +1761,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
}
}
else
operand = builder.createIToFInst(operand);
operand = builder.createItoFInst(operand);
}
if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand);
@@ -1567,7 +1870,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root,
assert(false && "Unknown constant type for float conversion.");
}
else
result.push_back(builder->createIToFInst(value));
result.push_back(builder->createItoFInst(value));
} else {
ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);

View File

@@ -1,7 +1,10 @@
#include "SysYIRPrinter.h"
#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include "IR.h" // 确保IR.h包含了ArrayType、GetElementPtrInst等的定义
@@ -61,16 +64,21 @@ std::string SysYPrinter::getValueName(Value *value) {
} else if (auto constInt = dynamic_cast<ConstantInteger*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constInt->getInt());
} else if (auto constFloat = dynamic_cast<ConstantFloating*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constFloat->getFloat());
std::ostringstream oss;
oss << std::scientific << std::setprecision(std::numeric_limits<float>::max_digits10) << constFloat->getFloat();
return oss.str();
} else if (auto constUndef = dynamic_cast<UndefinedValue*>(value)) { // 如果有Undef类型
return "undef";
} else if (auto constVal = dynamic_cast<ConstantValue*>(value)) { // fallback for generic ConstantValue
// 这里的逻辑可能需要根据你ConstantValue的实际设计调整
// 确保它能处理所有可能的ConstantValue
if (constVal->getType()->isFloat()) {
return std::to_string(constVal->getFloat());
if (auto constInt = dynamic_cast<ConstantInteger*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constInt->getInt());
} else if (auto constFloat = dynamic_cast<ConstantFloating*>(value)) { // 优先匹配具体的常量类型
std::ostringstream oss;
oss << std::scientific << std::setprecision(std::numeric_limits<float>::max_digits10) << constFloat->getFloat();
return oss.str();
}
return std::to_string(constVal->getInt());
} else if (auto constVar = dynamic_cast<ConstantVariable*>(value)) {
return constVar->getName(); // 假设ConstantVariable有自己的名字或通过getByIndices获取值
} else if (auto argVar = dynamic_cast<Argument*>(value)) {