Merge remote-tracking branch 'origin/backend' into midend

This commit is contained in:
rain2133
2025-08-01 14:06:20 +08:00
188 changed files with 611744 additions and 457 deletions

View File

@@ -1,45 +1,45 @@
#include "CalleeSavedHandler.h"
#include <set>
#include <vector> //
#include <vector>
#include <algorithm>
#include <iterator> //
#include <iterator>
namespace sysy {
char CalleeSavedHandler::ID = 0;
// 辅助函数,用于判断一个物理寄存器是否为浮点寄存器
static bool is_fp_reg(PhysicalReg reg) {
return reg >= PhysicalReg::F0 && reg <= PhysicalReg::F31;
}
bool CalleeSavedHandler::runOnFunction(Function *F, AnalysisManager& AM) {
// This pass works on MachineFunction level, not IR level
return false;
}
void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
// 此 Pass 负责分析、分配栈空间并插入 callee-saved 寄存器的保存/恢复指令。
// 它通过与 FrameInfo 协作,确保为 callee-saved 寄存器分配的空间与局部变量/溢出槽的空间不冲突。
StackFrameInfo& frame_info = mfunc->getFrameInfo();
// [修改] 分别记录被使用的整数和浮点被调用者保存寄存器
std::set<PhysicalReg> used_int_callee_saved;
std::set<PhysicalReg> used_fp_callee_saved;
std::set<PhysicalReg> used_callee_saved;
// 1. 扫描所有指令,找出被使用的s寄存器 (s1-s11) 和 fs寄存器 (fs0-fs11)
// 1. 扫描所有指令,找出被使用的callee-saved寄存器
// 这个Pass在RegAlloc之后运行所以可以访问到物理寄存器
for (auto& mbb : mfunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
for (auto& op : instr->getOperands()) {
auto check_and_insert_reg = [&](RegOperand* reg_op) {
if (!reg_op->isVirtual()) {
if (reg_op && !reg_op->isVirtual()) {
PhysicalReg preg = reg_op->getPReg();
// [修改] 区分整数和浮点被调用者保存寄存器
// s0 由序言/尾声处理器专门处理,这里不计入
// 检查整数 s1-s11
if (preg >= PhysicalReg::S1 && preg <= PhysicalReg::S11) {
used_int_callee_saved.insert(preg);
used_callee_saved.insert(preg);
}
// fs0-fs11 在我们的枚举中对应 f8,f9,f18-f27
// 检查浮点 fs0-fs11 (f8,f9,f18-f27)
else if ((preg >= PhysicalReg::F8 && preg <= PhysicalReg::F9) || (preg >= PhysicalReg::F18 && preg <= PhysicalReg::F27)) {
used_fp_callee_saved.insert(preg);
used_callee_saved.insert(preg);
}
}
};
@@ -53,60 +53,44 @@ void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
}
}
// 如果没有使用任何需要处理的 callee-saved 寄存器,则直接返回
if (used_int_callee_saved.empty() && used_fp_callee_saved.empty()) {
frame_info.callee_saved_size = 0; // 确保大小被初始化
if (used_callee_saved.empty()) {
frame_info.callee_saved_size = 0;
return;
}
// 2. 计算为 callee-saved 寄存器分配的栈空间大小
// 每个寄存器在RV64中都占用8字节
int callee_saved_size = (used_int_callee_saved.size() + used_fp_callee_saved.size()) * 8;
frame_info.callee_saved_size = callee_saved_size;
// 2. 计算并更新 frame_info
frame_info.callee_saved_size = used_callee_saved.size() * 8;
// 为了布局确定性和恢复顺序一致,对寄存器排序
std::vector<PhysicalReg> sorted_regs(used_callee_saved.begin(), used_callee_saved.end());
std::sort(sorted_regs.begin(), sorted_regs.end());
// 3. 在函数序言中插入保存指令
MachineBasicBlock* entry_block = mfunc->getBlocks().front().get();
auto& entry_instrs = entry_block->getInstructions();
// 插入点通常在函数入口标签之后
// 插入点在函数入口标签之后,或者就是最开始
auto insert_pos = entry_instrs.begin();
if (!entry_instrs.empty() && entry_instrs.front()->getOpcode() == RVOpcodes::LABEL) {
insert_pos = std::next(insert_pos);
}
// 为了布局确定性,对寄存器进行排序并按序保存
std::vector<PhysicalReg> sorted_int_regs(used_int_callee_saved.begin(), used_int_callee_saved.end());
std::vector<PhysicalReg> sorted_fp_regs(used_fp_callee_saved.begin(), used_fp_callee_saved.end());
std::sort(sorted_int_regs.begin(), sorted_int_regs.end());
std::sort(sorted_fp_regs.begin(), sorted_fp_regs.end());
std::vector<std::unique_ptr<MachineInstr>> save_instrs;
int current_offset = -16; // ra和s0已占用-8和-16从-24开始分配
// [关键] 从局部变量区域之后开始分配空间
int current_offset = - (16 + frame_info.locals_size);
// 准备整数保存指令 (sd)
for (PhysicalReg reg : sorted_int_regs) {
for (PhysicalReg reg : sorted_regs) {
current_offset -= 8;
auto sd = std::make_unique<MachineInstr>(RVOpcodes::SD);
sd->addOperand(std::make_unique<RegOperand>(reg));
sd->addOperand(std::make_unique<MemOperand>(
RVOpcodes save_op = is_fp_reg(reg) ? RVOpcodes::FSD : RVOpcodes::SD;
auto save_instr = std::make_unique<MachineInstr>(save_op);
save_instr->addOperand(std::make_unique<RegOperand>(reg));
save_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0), // 基址为帧指针 s0
std::make_unique<ImmOperand>(current_offset)
));
save_instrs.push_back(std::move(sd));
}
// 准备浮点保存指令 (fsd)
for (PhysicalReg reg : sorted_fp_regs) {
current_offset -= 8;
auto fsd = std::make_unique<MachineInstr>(RVOpcodes::FSD); // 使用浮点保存指令
fsd->addOperand(std::make_unique<RegOperand>(reg));
fsd->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
save_instrs.push_back(std::move(fsd));
save_instrs.push_back(std::move(save_instr));
}
// 一次性插入所有保存指令
if (!save_instrs.empty()) {
entry_instrs.insert(insert_pos,
std::make_move_iterator(save_instrs.begin()),
@@ -118,40 +102,27 @@ void CalleeSavedHandler::runOnMachineFunction(MachineFunction* mfunc) {
for (auto it = mbb->getInstructions().begin(); it != mbb->getInstructions().end(); ++it) {
if ((*it)->getOpcode() == RVOpcodes::RET) {
std::vector<std::unique_ptr<MachineInstr>> restore_instrs;
current_offset = -16; // 重置偏移量用于恢复
// [关键] 使用与保存时完全相同的逻辑来计算偏移量
current_offset = - (16 + frame_info.locals_size);
// 准备恢复整数寄存器 (ld) - 以与保存时相同的顺序
for (PhysicalReg reg : sorted_int_regs) {
for (PhysicalReg reg : sorted_regs) {
current_offset -= 8;
auto ld = std::make_unique<MachineInstr>(RVOpcodes::LD);
ld->addOperand(std::make_unique<RegOperand>(reg));
ld->addOperand(std::make_unique<MemOperand>(
RVOpcodes restore_op = is_fp_reg(reg) ? RVOpcodes::FLD : RVOpcodes::LD;
auto restore_instr = std::make_unique<MachineInstr>(restore_op);
restore_instr->addOperand(std::make_unique<RegOperand>(reg));
restore_instr->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
restore_instrs.push_back(std::move(ld));
restore_instrs.push_back(std::move(restore_instr));
}
// 准备恢复浮点寄存器 (fld)
for (PhysicalReg reg : sorted_fp_regs) {
current_offset -= 8;
auto fld = std::make_unique<MachineInstr>(RVOpcodes::FLD); // 使用浮点加载指令
fld->addOperand(std::make_unique<RegOperand>(reg));
fld->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(current_offset)
));
restore_instrs.push_back(std::move(fld));
}
// 一次性插入所有恢复指令
if (!restore_instrs.empty()) {
mbb->getInstructions().insert(it,
std::make_move_iterator(restore_instrs.begin()),
std::make_move_iterator(restore_instrs.end()));
}
// 处理完一个基本块的RET后迭代器已失效需跳出当前块的循环
goto next_block_label;
}
}

View File

@@ -1,12 +1,27 @@
#include "PrologueEpilogueInsertion.h"
#include "RISCv64ISel.h"
#include "RISCv64RegAlloc.h" // 需要访问RegAlloc的结果
#include <algorithm>
namespace sysy {
char PrologueEpilogueInsertionPass::ID = 0;
void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc) {
for (auto& mbb : mfunc->getBlocks()) {
auto& instrs = mbb->getInstructions();
// 使用标准的 Erase-Remove Idiom 来删除满足条件的元素
instrs.erase(
std::remove_if(instrs.begin(), instrs.end(),
[](const std::unique_ptr<MachineInstr>& instr) {
return instr->getOpcode() == RVOpcodes::PSEUDO_KEEPALIVE;
}
),
instrs.end()
);
}
StackFrameInfo& frame_info = mfunc->getFrameInfo();
Function* F = mfunc->getFunc();
RISCv64ISel* isel = mfunc->getISel();
@@ -64,54 +79,35 @@ void PrologueEpilogueInsertionPass::runOnMachineFunction(MachineFunction* mfunc)
set_fp->addOperand(std::make_unique<ImmOperand>(aligned_stack_size));
prologue_instrs.push_back(std::move(set_fp));
// --- [正确逻辑] 在s0设置完毕后使用物理寄存器加载栈参数 ---
// --- 在s0设置完毕后使用物理寄存器加载栈参数 ---
if (F && isel) {
// 定义暂存寄存器
const PhysicalReg INT_SCRATCH_REG = PhysicalReg::T5;
const PhysicalReg FP_SCRATCH_REG = PhysicalReg::F7;
int arg_idx = 0;
for (Argument* arg : F->getArguments()) {
if (arg_idx >= 8) {
unsigned vreg = isel->getVReg(arg);
// 确认RegAlloc已经为这个vreg计算了偏移量并且分配了物理寄存器
if (frame_info.alloca_offsets.count(vreg) && vreg_to_preg_map.count(vreg)) {
int offset = frame_info.alloca_offsets.at(vreg);
PhysicalReg dest_preg = vreg_to_preg_map.at(vreg);
Type* arg_type = arg->getType();
// 根据类型执行不同的加载序列
if (arg_type->isFloat()) {
// 1. flw ft7, offset(s0)
auto load_arg = std::make_unique<MachineInstr>(RVOpcodes::FLW);
load_arg->addOperand(std::make_unique<RegOperand>(FP_SCRATCH_REG));
load_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
load_arg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
prologue_instrs.push_back(std::move(load_arg));
// 2. fmv.s dest_preg, ft7
auto move_arg = std::make_unique<MachineInstr>(RVOpcodes::FMV_S);
move_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
move_arg->addOperand(std::make_unique<RegOperand>(FP_SCRATCH_REG));
prologue_instrs.push_back(std::move(move_arg));
} else {
// 确定是加载32位(lw)还是64位(ld)
RVOpcodes load_op = arg_type->isPointer() ? RVOpcodes::LD : RVOpcodes::LW;
// 1. lw/ld t5, offset(s0)
auto load_arg = std::make_unique<MachineInstr>(load_op);
load_arg->addOperand(std::make_unique<RegOperand>(INT_SCRATCH_REG));
load_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
load_arg->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
prologue_instrs.push_back(std::move(load_arg));
// 2. mv dest_preg, t5
auto move_arg = std::make_unique<MachineInstr>(RVOpcodes::MV);
move_arg->addOperand(std::make_unique<RegOperand>(dest_preg));
move_arg->addOperand(std::make_unique<RegOperand>(INT_SCRATCH_REG));
prologue_instrs.push_back(std::move(move_arg));
}
}
}

View File

@@ -144,6 +144,9 @@ void RISCv64AsmPrinter::printInstruction(MachineInstr* instr, bool debug) {
case RVOpcodes::FRAME_STORE_F:
if (!debug) throw std::runtime_error("FRAME_STORE_F not eliminated before AsmPrinter");
*OS << "frame_store_f "; break;
case RVOpcodes::PSEUDO_KEEPALIVE:
if (!debug) throw std::runtime_error("PSEUDO_KEEPALIVE not eliminated before AsmPrinter");
*OS << "keepalive "; break;
default:
throw std::runtime_error("Unknown opcode in AsmPrinter");
}

View File

@@ -12,11 +12,28 @@ std::string RISCv64CodeGen::code_gen() {
return module_gen();
}
// 模块级代码生成
void printInitializer(std::stringstream& ss, const ValueCounter& init_values) {
for (size_t i = 0; i < init_values.getValues().size(); ++i) {
auto val = init_values.getValues()[i];
auto count = init_values.getNumbers()[i];
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
for (unsigned j = 0; j < count; ++j) {
if (constant->isInt()) {
ss << " .word " << constant->getInt() << "\n";
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
ss << " .word " << float_bits << "\n";
}
}
}
}
}
std::string RISCv64CodeGen::module_gen() {
std::stringstream ss;
// --- 步骤1将全局变量分为.data和.bss两组 ---
// --- 步骤1将全局变量(GlobalValue)分为.data和.bss两组 ---
std::vector<GlobalValue*> data_globals;
std::vector<GlobalValue*> bss_globals;
@@ -41,7 +58,7 @@ std::string RISCv64CodeGen::module_gen() {
}
}
// --- 步骤2生成 .bss 段的代码 ---
// --- 步骤2生成 .bss 段的代码 (这部分不变) ---
if (!bss_globals.empty()) {
ss << ".bss\n";
for (GlobalValue* global : bss_globals) {
@@ -57,28 +74,16 @@ std::string RISCv64CodeGen::module_gen() {
}
}
// --- 步骤3生成 .data 段的代码 ---
if (!data_globals.empty()) {
ss << ".data\n"; // 切换到 .data 段
// --- [修改] 步骤3生成 .data 段的代码 ---
// 我们需要检查 data_globals 和 常量列表是否都为空
if (!data_globals.empty() || !module->getConsts().empty()) {
ss << ".data\n";
// a. 先处理普通的全局变量 (GlobalValue)
for (GlobalValue* global : data_globals) {
ss << ".globl " << global->getName() << "\n";
ss << global->getName() << ":\n";
const auto& init_values = global->getInitValues();
for (size_t i = 0; i < init_values.getValues().size(); ++i) {
auto val = init_values.getValues()[i];
auto count = init_values.getNumbers()[i];
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
for (unsigned j = 0; j < count; ++j) {
if (constant->isInt()) {
ss << " .word " << constant->getInt() << "\n";
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
ss << " .word " << float_bits << "\n";
}
}
}
}
printInitializer(ss, global->getInitValues());
}
// b. [新增] 再处理全局常量 (ConstantVariable)
@@ -86,27 +91,11 @@ std::string RISCv64CodeGen::module_gen() {
ConstantVariable* cnst = const_ptr.get();
ss << ".globl " << cnst->getName() << "\n";
ss << cnst->getName() << ":\n";
const auto& init_values = cnst->getInitValues();
// 这部分逻辑和处理 GlobalValue 完全相同
for (size_t i = 0; i < init_values.getValues().size(); ++i) {
auto val = init_values.getValues()[i];
auto count = init_values.getNumbers()[i];
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
for (unsigned j = 0; j < count; ++j) {
if (constant->isInt()) {
ss << " .word " << constant->getInt() << "\n";
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
ss << " .word " << float_bits << "\n";
}
}
}
}
printInitializer(ss, cnst->getInitValues());
}
}
// --- 处理函数 (.text段) ---
// --- 处理函数 (.text段) 的逻辑保持不变 ---
if (!module->getFunctions().empty()) {
ss << ".text\n";
for (const auto& func_pair : module->getFunctions()) {

View File

@@ -152,20 +152,48 @@ void RISCv64ISel::selectBasicBlock(BasicBlock* bb) {
for (const auto& inst_ptr : bb->getInstructions()) {
DAGNode* node_to_select = nullptr;
if (value_to_node.count(inst_ptr.get())) {
node_to_select = value_to_node.at(inst_ptr.get());
auto it = value_to_node.find(inst_ptr.get());
if (it != value_to_node.end()) {
node_to_select = it->second;
} else {
for(const auto& node : dag) {
if(node->value == inst_ptr.get()) {
node_to_select = node.get();
break;
}
for(const auto& node : dag) {
if(node->value == inst_ptr.get()) {
node_to_select = node.get();
break;
}
}
if(node_to_select) {
}
if(node_to_select) {
select_recursive(node_to_select);
}
}
if (CurMBB == MFunc->getBlocks().front().get()) { // 只对入口块操作
auto keepalive = std::make_unique<MachineInstr>(RVOpcodes::PSEUDO_KEEPALIVE);
for (Argument* arg : F->getArguments()) {
keepalive->addOperand(std::make_unique<RegOperand>(getVReg(arg)));
}
auto& instrs = CurMBB->getInstructions();
auto insert_pos = instrs.end();
// 关键:检查基本块是否以一个“终止指令”结尾
if (!instrs.empty()) {
RVOpcodes last_op = instrs.back()->getOpcode();
// 扩充了判断条件,涵盖所有可能的终止指令
if (last_op == RVOpcodes::J || last_op == RVOpcodes::RET ||
last_op == RVOpcodes::BEQ || last_op == RVOpcodes::BNE ||
last_op == RVOpcodes::BLT || last_op == RVOpcodes::BGE ||
last_op == RVOpcodes::BLTU || last_op == RVOpcodes::BGEU)
{
// 如果是,插入点就在这个终止指令之前
insert_pos = std::prev(instrs.end());
}
}
// 在计算出的正确位置插入伪指令
instrs.insert(insert_pos, std::move(keepalive));
}
}
// 核心函数为DAG节点选择并生成MachineInstr (已修复和增强的完整版本)

View File

@@ -165,10 +165,6 @@ void RISCv64RegAlloc::handleCallingConvention() {
*/
void RISCv64RegAlloc::eliminateFrameIndices() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
// 初始偏移量为保存ra和s0留出空间。
// 假设序言是 addi sp, sp, -stack_size; sd ra, stack_size-8(sp); sd s0, stack_size-16(sp);
int current_offset = 16;
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
@@ -190,12 +186,15 @@ void RISCv64RegAlloc::eliminateFrameIndices() {
}
}
// 处理局部变量
// 遍历AllocaInst来计算局部变量所需的总空间
// [关键修改] 为局部变量分配空间时起始点必须考虑为ra, s0以及所有callee-saved寄存器预留的空间。
// 布局顺序为: [s0/ra, 16字节] -> [callee-saved, callee_saved_size字节] -> [局部变量...]
int local_var_offset = 16 + frame_info.callee_saved_size;
int locals_start_offset = local_var_offset; // 记录局部变量区域的起始点,用于计算总大小
// 处理局部变量 (AllocaInst)
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
// 获取Alloca指令指向的类型 (例如 alloca i32* 中,获取 i32)
Type* allocated_type = alloca->getType()->as<PointerType>()->getBaseType();
int size = getTypeSizeInBytes(allocated_type);
@@ -203,14 +202,17 @@ void RISCv64RegAlloc::eliminateFrameIndices() {
size = (size + 7) & ~7;
if (size == 0) size = 8; // 至少分配8字节
current_offset += size;
local_var_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
// 局部变量使用相对于s0的负向偏移
frame_info.alloca_offsets[alloca_vreg] = -current_offset;
frame_info.alloca_offsets[alloca_vreg] = -local_var_offset;
}
}
}
frame_info.locals_size = current_offset;
// [修复] 正确计算并设置locals_size
// 它只应该包含由AllocaInst分配的局部变量的总大小。
frame_info.locals_size = local_var_offset - locals_start_offset;
// 遍历所有机器指令,将伪指令展开为真实指令
for (auto& mbb : MFunc->getBlocks()) {
@@ -349,6 +351,18 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
bool first_reg_is_def = true; // 默认情况下,指令的第一个寄存器操作数是定义 (def)
auto opcode = instr->getOpcode();
if (opcode == RVOpcodes::PSEUDO_KEEPALIVE) {
for (auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual()) {
use.insert(reg_op->getVRegNum()); // 它的所有操作数都是 "use"
}
}
}
return; // 处理完毕
}
// 1. 特殊指令的 `is_def` 标志调整
// 这些指令的第一个寄存器操作数是源操作数 (use),而不是目标操作数 (def)。
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD || opcode == RVOpcodes::FSW ||
@@ -384,13 +398,15 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
}
} else { // [修复] CALL指令也可能定义物理寄存器如a0
if (first_reg_operand_is_def) {
if (preg_to_vreg_id_map.count(reg_op->getPReg())) {
def.insert(preg_to_vreg_id_map.at(reg_op->getPReg()));
}
auto it = preg_to_vreg_id_map.find(reg_op->getPReg());
if (it != preg_to_vreg_id_map.end()) {
def.insert(it->second);
}
first_reg_operand_is_def = false;
} else {
if (preg_to_vreg_id_map.count(reg_op->getPReg())) {
use.insert(preg_to_vreg_id_map.at(reg_op->getPReg()));
auto it = preg_to_vreg_id_map.find(reg_op->getPReg());
if (it != preg_to_vreg_id_map.end()) {
use.insert(it->second);
}
}
}
@@ -430,9 +446,10 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
if (reg_op->isVirtual()) {
def.insert(reg_op->getVRegNum());
} else { // 物理寄存器也可以是 Def
if (preg_to_vreg_id_map.count(reg_op->getPReg())) {
def.insert(preg_to_vreg_id_map.at(reg_op->getPReg()));
}
auto it = preg_to_vreg_id_map.find(reg_op->getPReg());
if (it != preg_to_vreg_id_map.end()) {
def.insert(it->second);
}
}
first_reg_is_def = false; // **关键**:处理完第一个寄存器后,立即更新标志
} else {
@@ -440,8 +457,9 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
if (reg_op->isVirtual()) {
use.insert(reg_op->getVRegNum());
} else { // 物理寄存器也可以是 Use
if (preg_to_vreg_id_map.count(reg_op->getPReg())) {
use.insert(preg_to_vreg_id_map.at(reg_op->getPReg()));
auto it = preg_to_vreg_id_map.find(reg_op->getPReg());
if (it != preg_to_vreg_id_map.end()) {
use.insert(it->second);
}
}
}
@@ -453,8 +471,9 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
use.insert(base_reg->getVRegNum());
} else {
PhysicalReg preg = base_reg->getPReg();
if (preg_to_vreg_id_map.count(preg)) {
use.insert(preg_to_vreg_id_map.at(preg));
auto it = preg_to_vreg_id_map.find(preg);
if (it != preg_to_vreg_id_map.end()) {
use.insert(it->second);
}
}
@@ -466,8 +485,9 @@ void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet&
if (src_reg_op->isVirtual()) {
use.insert(src_reg_op->getVRegNum());
} else {
if (preg_to_vreg_id_map.count(src_reg_op->getPReg())) {
use.insert(preg_to_vreg_id_map.at(src_reg_op->getPReg()));
auto it = preg_to_vreg_id_map.find(src_reg_op->getPReg());
if (it != preg_to_vreg_id_map.end()) {
use.insert(it->second);
}
}
}
@@ -671,6 +691,22 @@ void RISCv64RegAlloc::buildInterferenceGraph() {
}
}
// 所有在某一点上同时活跃的寄存器即live_out集合中的所有成员
// 它们之间必须两两互相干扰。
// 这会根据我们修正后的 liveness 信息在所有参数vreg之间构建一个完全图clique
std::vector<unsigned> live_out_vec(live_out.begin(), live_out.end());
for (size_t i = 0; i < live_out_vec.size(); ++i) {
for (size_t j = i + 1; j < live_out_vec.size(); ++j) {
unsigned u = live_out_vec[i];
unsigned v = live_out_vec[j];
if (DEEPDEBUG && interference_graph[u].find(v) == interference_graph[u].end()) {
std::cerr << " Edge (Live-Live): %vreg" << u << " <-> %vreg" << v << "\n";
}
interference_graph[u].insert(v);
interference_graph[v].insert(u);
}
}
// 在非move指令中def 与 use 互相干扰
if (instr->getOpcode() != RVOpcodes::MV) {
for (unsigned d : def) {

View File

@@ -92,7 +92,7 @@ enum class RVOpcodes {
FMV_X_W, // fmv.x.w rd, rs1 (浮点寄存器位模式 -> 整数寄存器)
FNEG_S, // fneg.s rd, rs (浮点取负)
// 新增伪指令,用于解耦栈帧处理
// 伪指令
FRAME_LOAD_W, // 从栈帧加载 32位 Word (对应 lw)
FRAME_LOAD_D, // 从栈帧加载 64位 Doubleword (对应 ld)
FRAME_STORE_W, // 保存 32位 Word 到栈帧 (对应 sw)
@@ -100,6 +100,7 @@ enum class RVOpcodes {
FRAME_LOAD_F, // 从栈帧加载单精度浮点数
FRAME_STORE_F, // 将单精度浮点数存入栈帧
FRAME_ADDR, // 获取栈帧变量的地址
PSEUDO_KEEPALIVE, // 保持寄存器活跃,防止优化器删除
};
inline bool isGPR(PhysicalReg reg) {
@@ -279,7 +280,8 @@ struct StackFrameInfo {
std::map<unsigned, int> alloca_offsets; // <AllocaInst的vreg, 栈偏移>
std::map<unsigned, int> spill_offsets; // <溢出vreg, 栈偏移>
std::set<PhysicalReg> used_callee_saved_regs; // 使用的保存寄存器
std::map<unsigned, PhysicalReg> vreg_to_preg_map;
std::map<unsigned, PhysicalReg> vreg_to_preg_map;
std::vector<PhysicalReg> callee_saved_regs; // 用于存储需要保存的被调用者保存寄存器列表
};
// 机器函数

View File

@@ -126,7 +126,7 @@ class IRBuilder {
UnaryInst * createFNotInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kFNot, Type::getIntType(), operand, name);
} ///< 创建浮点取非指令
UnaryInst * createIToFInst(Value *operand, const std::string &name = "") {
UnaryInst * createItoFInst(Value *operand, const std::string &name = "") {
return createUnaryInst(Instruction::kItoF, Type::getFloatType(), operand, name);
} ///< 创建整型转浮点指令
UnaryInst * createBitItoFInst(Value *operand, const std::string &name = "") {

View File

@@ -59,6 +59,35 @@ private:
std::unique_ptr<Module> module;
IRBuilder builder;
using ValueOrOperator = std::variant<Value*, int>;
std::vector<ValueOrOperator> BinaryExpStack; ///< 用于存储二元表达式的中缀表达式
std::vector<int> BinaryExpLenStack; ///< 用于存储该层次的二元表达式的长度
// 下面是用于后缀表达式的计算的数据结构
std::vector<ValueOrOperator> BinaryRPNStack; ///< 用于存储二元表达式的后缀表达式
std::vector<int> BinaryOpStack; ///< 用于存储二元表达式中缀表达式转换到后缀表达式的操作符栈
std::vector<Value *> BinaryValueStack; ///< 用于存储后缀表达式计算的操作数栈
// 约定操作符:
// 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT', 9: 'LPAREN', 10: 'RPAREN'
// 这里的操作符是为了方便后缀表达式的计算而设计
// 其中,'ADD', 'SUB', 'MUL', 'DIV', '%'
// 分别对应加法、减法、乘法、除法和取模
// 'PLUS' 和 'NEG' 分别对应一元加法和一元减法
// 'NOT' 对应逻辑非
// 'LPAREN' 和 'RPAREN' 分别对应左括号和右括号
enum BinaryOp {
ADD = 1, SUB = 2, MUL = 3, DIV = 4, MOD = 5, PLUS = 6, NEG = 7, NOT = 8, LPAREN = 9, RPAREN = 10,
};
int getOperatorPrecedence(int op) {
switch (op) {
case MUL: case DIV: case MOD: return 2;
case ADD: case SUB: return 1;
case PLUS: case NEG: case NOT: return 3;
case LPAREN: case RPAREN: return 0; // Parentheses have lowest precedence for stack logic
default: return -1; // Unknown operator
}
}
public:
SysYIRGenerator() = default;
@@ -97,7 +126,7 @@ public:
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
// std::any visitStmt(SysYParser::StmtContext *ctx) override;
std::any visitAssignStmt(SysYParser::AssignStmtContext *ctx) override;
// std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
std::any visitExpStmt(SysYParser::ExpStmtContext *ctx) override;
// std::any visitBlkStmt(SysYParser::BlkStmtContext *ctx) override;
std::any visitIfStmt(SysYParser::IfStmtContext *ctx) override;
std::any visitWhileStmt(SysYParser::WhileStmtContext *ctx) override;
@@ -131,8 +160,13 @@ public:
std::any visitLAndExp(SysYParser::LAndExpContext *ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext *ctx) override;
// std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
std::any visitConstExp(SysYParser::ConstExpContext *ctx) override;
bool isRightAssociative(int op);
Value* promoteType(Value* value, Type* targetType);
Value* computeExp(SysYParser::ExpContext *ctx, Type* targetType = nullptr);
Value* computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType = nullptr);
void compute();
public:
// 获取GEP指令的地址
Value* getGEPAddressInst(Value* basePointer, const std::vector<Value*>& indices);
@@ -141,6 +175,7 @@ public:
unsigned countArrayDimensions(Type* type);
}; // class SysYIRGenerator
} // namespace sysy

View File

@@ -16,6 +16,441 @@ using namespace std;
namespace sysy {
// std::vector<Value*> BinaryValueStack; ///< 用于存储value的栈
// std::vector<int> BinaryOpStack; ///< 用于存储二元表达式的操作符栈
// // 约定操作符:
// // 1: 'ADD', 2: 'SUB', 3: 'MUL', 4: 'DIV', 5: '%', 6: 'PLUS', 7: 'NEG', 8: 'NOT'
// enum BinaryOp {
// ADD = 1,
// SUB = 2,
// MUL = 3,
// DIV = 4,
// MOD = 5,
// PLUS = 6,
// NEG = 7,
// NOT = 8
// };
Value *SysYIRGenerator::promoteType(Value *value, Type *targetType) {
//如果是常量则直接返回相应的值
if (targetType == nullptr) {
return value; // 如果值为空,那就不需要转换
}
ConstantInteger* constInt = dynamic_cast<ConstantInteger *>(value);
ConstantFloating *constFloat = dynamic_cast<ConstantFloating *>(value);
if (constInt) {
if (targetType->isFloat()) {
return ConstantFloating::get(static_cast<float>(constInt->getInt()));
}
return constInt; // 如果目标类型是int直接返回原值
} else if (constFloat) {
if (targetType->isInt()) {
return ConstantInteger::get(static_cast<int>(constFloat->getFloat()));
}
return constFloat; // 如果目标类型是float直接返回原值
}
if (value->getType()->isInt() && targetType->isFloat()) {
return builder.createItoFInst(value);
} else if (value->getType()->isFloat() && targetType->isInt()) {
return builder.createFtoIInst(value);
}
// 如果类型已经匹配,直接返回原值
return value;
}
bool SysYIRGenerator::isRightAssociative(int op) {
return (op == BinaryOp::PLUS || op == BinaryOp::NEG || op == BinaryOp::NOT);
}
void SysYIRGenerator::compute() {
// 先将中缀表达式转换为后缀表达式
BinaryRPNStack.clear();
BinaryOpStack.clear();
int begin = BinaryExpStack.size() - BinaryExpLenStack.back(), end = BinaryExpStack.size();
for (int i = begin; i < end; i++) {
auto item = BinaryExpStack[i];
if (std::holds_alternative<sysy::Value *>(item)) {
// 如果是操作数 (Value*),直接推入后缀表达式栈
BinaryRPNStack.push_back(item); // 直接 push_back item (ValueOrOperator类型)
} else {
// 如果是操作符
int currentOp = std::get<int>(item);
if (currentOp == LPAREN) {
// 左括号直接入栈
BinaryOpStack.push_back(currentOp);
} else if (currentOp == RPAREN) {
// 右括号:将操作符栈中的操作符弹出并添加到后缀表达式栈,直到遇到左括号
while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) {
BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int
BinaryOpStack.pop_back();
}
if (!BinaryOpStack.empty() && BinaryOpStack.back() == LPAREN) {
BinaryOpStack.pop_back(); // 弹出左括号,但不添加到后缀表达式栈
} else {
// 错误:不匹配的右括号
std::cerr << "Error: Mismatched parentheses in expression." << std::endl;
return;
}
} else {
// 普通操作符
while (!BinaryOpStack.empty() && BinaryOpStack.back() != LPAREN) {
int stackTopOp = BinaryOpStack.back();
// 如果当前操作符优先级低于栈顶操作符优先级
// 或者 (当前操作符优先级等于栈顶操作符优先级 并且 栈顶操作符是左结合)
if (getOperatorPrecedence(currentOp) < getOperatorPrecedence(stackTopOp) ||
(getOperatorPrecedence(currentOp) == getOperatorPrecedence(stackTopOp) &&
!isRightAssociative(stackTopOp))) {
BinaryRPNStack.push_back(stackTopOp);
BinaryOpStack.pop_back();
} else {
break; // 否则当前操作符入栈
}
}
BinaryOpStack.push_back(currentOp); // 当前操作符入栈
}
}
}
// 遍历结束后,将操作符栈中剩余的所有操作符弹出并添加到后缀表达式栈
while (!BinaryOpStack.empty()) {
if (BinaryOpStack.back() == LPAREN) {
// 错误:不匹配的左括号
std::cerr << "Error: Mismatched parentheses in expression (unclosed parenthesis)." << std::endl;
return;
}
BinaryRPNStack.push_back(BinaryOpStack.back()); // 直接 push_back int
BinaryOpStack.pop_back();
}
// 弹出BinaryExpStack的表达式
while(begin < end) {
BinaryExpStack.pop_back();
BinaryExpLenStack.back()--;
end--;
}
// 计算后缀表达式
// 每次计算前清空操作数栈
BinaryValueStack.clear();
// 遍历后缀表达式栈
Type *commonType = nullptr;
for(const auto &item : BinaryRPNStack) {
if (std::holds_alternative<Value *>(item)) {
// 如果是操作数 (Value*) 检测他的类型
Value *value = std::get<Value *>(item);
if (commonType == nullptr) {
commonType = value->getType();
}
else if (value->getType() != commonType && value->getType()->isFloat()) {
// 如果当前值的类型与commonType不同且是float类型则提升为float
commonType = Type::getFloatType();
break;
}
} else {
continue;
}
}
for (const auto &item : BinaryRPNStack) {
if (std::holds_alternative<sysy::Value *>(item)) {
// 如果是操作数 (Value*),直接推入操作数栈
BinaryValueStack.push_back(std::get<sysy::Value *>(item));
} else {
// 如果是操作符
int op = std::get<int>(item);
Value *resultValue = nullptr;
Value *lhs = nullptr;
Value *rhs = nullptr;
Value *operand = nullptr;
switch (op) {
case BinaryOp::ADD:
case BinaryOp::SUB:
case BinaryOp::MUL:
case BinaryOp::DIV:
case BinaryOp::MOD: {
// 二元操作符需要两个操作数
if (BinaryValueStack.size() < 2) {
std::cerr << "Error: Not enough operands for binary operation: " << op << std::endl;
return; // 或者抛出异常
}
rhs = BinaryValueStack.back();
BinaryValueStack.pop_back();
lhs = BinaryValueStack.back();
BinaryValueStack.pop_back();
// 类型转换
lhs = promoteType(lhs, commonType);
rhs = promoteType(rhs, commonType);
// 尝试常量折叠
ConstantValue *lhsConst = dynamic_cast<ConstantValue *>(lhs);
ConstantValue *rhsConst = dynamic_cast<ConstantValue *>(rhs);
if (lhsConst && rhsConst) {
// 如果都是常量,直接计算结果
if (commonType == Type::getIntType()) {
int lhsVal = lhsConst->getInt();
int rhsVal = rhsConst->getInt();
switch (op) {
case BinaryOp::ADD: resultValue = ConstantInteger::get(lhsVal + rhsVal); break;
case BinaryOp::SUB: resultValue = ConstantInteger::get(lhsVal - rhsVal); break;
case BinaryOp::MUL: resultValue = ConstantInteger::get(lhsVal * rhsVal); break;
case BinaryOp::DIV:
if (rhsVal == 0) {
std::cerr << "Error: Division by zero." << std::endl;
return;
}
resultValue = sysy::ConstantInteger::get(lhsVal / rhsVal); break;
case BinaryOp::MOD:
if (rhsVal == 0) {
std::cerr << "Error: Modulo by zero." << std::endl;
return;
}
resultValue = sysy::ConstantInteger::get(lhsVal % rhsVal); break;
default:
std::cerr << "Error: Unknown binary operator for constants: " << op << std::endl;
return;
}
} else if (commonType == Type::getFloatType()) {
float lhsVal = lhsConst->getFloat();
float rhsVal = rhsConst->getFloat();
switch (op) {
case BinaryOp::ADD: resultValue = ConstantFloating::get(lhsVal + rhsVal); break;
case BinaryOp::SUB: resultValue = ConstantFloating::get(lhsVal - rhsVal); break;
case BinaryOp::MUL: resultValue = ConstantFloating::get(lhsVal * rhsVal); break;
case BinaryOp::DIV:
if (rhsVal == 0.0f) {
std::cerr << "Error: Division by zero." << std::endl;
return;
}
resultValue = sysy::ConstantFloating::get(lhsVal / rhsVal); break;
case BinaryOp::MOD:
std::cerr << "Error: Modulo operator not supported for float types." << std::endl;
return;
default:
std::cerr << "Error: Unknown binary operator for float constants: " << op << std::endl;
return;
}
} else {
std::cerr << "Error: Unsupported type for binary constant operation." << std::endl;
return;
}
} else {
// 否则创建相应的IR指令
if (commonType == Type::getIntType()) {
switch (op) {
case BinaryOp::ADD: resultValue = builder.createAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createDivInst(lhs, rhs); break;
case BinaryOp::MOD: resultValue = builder.createRemInst(lhs, rhs); break;
}
} else if (commonType == Type::getFloatType()) {
switch (op) {
case BinaryOp::ADD: resultValue = builder.createFAddInst(lhs, rhs); break;
case BinaryOp::SUB: resultValue = builder.createFSubInst(lhs, rhs); break;
case BinaryOp::MUL: resultValue = builder.createFMulInst(lhs, rhs); break;
case BinaryOp::DIV: resultValue = builder.createFDivInst(lhs, rhs); break;
case BinaryOp::MOD:
std::cerr << "Error: Modulo operator not supported for float types." << std::endl;
return;
}
} else {
std::cerr << "Error: Unsupported type for binary instruction." << std::endl;
return;
}
}
break;
}
case BinaryOp::PLUS:
case BinaryOp::NEG:
case BinaryOp::NOT: {
// 一元操作符需要一个操作数
if (BinaryValueStack.empty()) {
std::cerr << "Error: Not enough operands for unary operation: " << op << std::endl;
return;
}
operand = BinaryValueStack.back();
BinaryValueStack.pop_back();
operand = promoteType(operand, commonType);
// 尝试常量折叠
ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(operand);
ConstantFloating *constFloat = dynamic_cast<ConstantFloating *>(operand);
if (constInt || constFloat) {
// 如果是常量,直接计算结果
switch (op) {
case BinaryOp::PLUS: resultValue = operand; break;
case BinaryOp::NEG: {
if (constInt) {
resultValue = constInt->getNeg();
} else if (constFloat) {
resultValue = constFloat->getNeg();
} else {
std::cerr << "Error: Negation not supported for constant operand type." << std::endl;
return;
}
break;
}
case BinaryOp::NOT:
if (constInt) {
resultValue = sysy::ConstantInteger::get(constInt->getInt() == 0 ? 1 : 0);
} else if (constFloat) {
resultValue = sysy::ConstantInteger::get(constFloat->getFloat() == 0.0f ? 1 : 0);
} else {
std::cerr << "Error: Logical NOT not supported for constant operand type." << std::endl;
return;
}
break;
default:
std::cerr << "Error: Unknown unary operator for constants: " << op << std::endl;
return;
}
} else {
// 否则创建相应的IR指令
switch (op) {
case BinaryOp::PLUS:
resultValue = operand; // 一元加指令通常直接返回操作数
break;
case BinaryOp::NEG: {
if (commonType == sysy::Type::getIntType()) {
resultValue = builder.createNegInst(operand);
} else if (commonType == sysy::Type::getFloatType()) {
resultValue = builder.createFNegInst(operand);
} else {
std::cerr << "Error: Negation not supported for operand type." << std::endl;
return;
}
break;
}
case BinaryOp::NOT:
// 逻辑非
if (commonType == sysy::Type::getIntType()) {
resultValue = builder.createNotInst(operand);
} else if (commonType == sysy::Type::getFloatType()) {
resultValue = builder.createFNotInst(operand);
} else {
std::cerr << "Error: Logical NOT not supported for operand type." << std::endl;
return;
}
break;
default:
std::cerr << "Error: Unknown unary operator for instructions: " << op << std::endl;
return;
}
}
break;
}
default:
std::cerr << "Error: Unknown operator " << op << " encountered in RPN stack." << std::endl;
return;
}
// 将计算结果或指令结果推入操作数栈
if (resultValue) {
BinaryValueStack.push_back(resultValue);
} else {
std::cerr << "Error: Result value is null after processing operator " << op << "!" << std::endl;
return;
}
}
}
// 后缀表达式处理完毕,操作数栈的栈顶就是最终结果
if (BinaryValueStack.empty()) {
std::cerr << "Error: No values left in BinaryValueStack after processing RPN." << std::endl;
return;
}
if (BinaryValueStack.size() > 1) {
std::cerr
<< "Warning: Multiple values left in BinaryValueStack after processing RPN. Expression might be malformed."
<< std::endl;
}
BinaryRPNStack.clear(); // 清空后缀表达式栈
BinaryOpStack.clear(); // 清空操作符栈
return;
}
Value* SysYIRGenerator::computeExp(SysYParser::ExpContext *ctx, Type* targetType){
if (ctx->addExp() == nullptr) {
assert(false && "ExpContext should have an addExp child!");
}
BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0
visitAddExp(ctx->addExp());
// if(targetType == nullptr) {
// targetType = Type::getIntType(); // 默认目标类型为int
// }
compute();
// 最后一个Value应该是最终结果
Value* result = BinaryValueStack.back();
BinaryValueStack.pop_back(); // 移除结果值
result = promoteType(result, targetType); // 确保结果类型符合目标类型
// 检查当前层次的操作符数量
int ExpLen = BinaryExpLenStack.back();
BinaryExpLenStack.pop_back(); // 离开层次时将该层次
if (ExpLen > 0) {
std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl;
return nullptr;
}
return result;
}
Value* SysYIRGenerator::computeAddExp(SysYParser::AddExpContext *ctx, Type* targetType){
// 根据AddExpContext中的操作符和操作数计算加法表达式
// 这里假设AddExpContext已经被正确填充
if (ctx->mulExp().size() == 0) {
assert(false && "AddExpContext should have a mulExp child!");
}
BinaryExpLenStack.push_back(0); // 进入新的层次时Push 0
visitMulExp(ctx->mulExp(0));
// BinaryValueStack.push_back(result);
for (int i = 1; i < ctx->mulExp().size(); i++) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
int opType = opNode->getSymbol()->getType();
switch(opType) {
case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break;
case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break;
default: assert(false && "Unexpected operator in AddExp.");
}
// BinaryExpStack.push_back(opType);
visitMulExp(ctx->mulExp(i));
// BinaryValueStack.push_back(operand);
}
// if(targetType == nullptr) {
// targetType = Type::getIntType(); // 默认目标类型为int
// }
// 根据后缀表达式的逻辑计算
compute();
// 最后一个Value应该是最终结果
Value* result = BinaryValueStack.back();
BinaryValueStack.pop_back(); // 移除最后一个值,因为它已经被计算
result = promoteType(result, targetType); // 确保结果类型符合目标类型
int ExpLen = BinaryExpLenStack.back();
BinaryExpLenStack.pop_back(); // 离开层次时将该层次
if (ExpLen > 0) {
std::cerr << "Warning: There are still " << ExpLen << " binary val or op left unprocessed in this level!" << std::endl;
return nullptr;
}
return result;
}
Type* SysYIRGenerator::buildArrayType(Type* baseType, const std::vector<Value*>& dims){
Type* currentType = baseType;
// 从最内层维度开始构建 ArrayType
@@ -393,7 +828,8 @@ std::any SysYIRGenerator::visitBType(SysYParser::BTypeContext *ctx) {
}
std::any SysYIRGenerator::visitScalarInitValue(SysYParser::ScalarInitValueContext *ctx) {
Value* value = std::any_cast<Value *>(visitExp(ctx->exp()));
// Value* value = std::any_cast<Value *>(visitExp(ctx->exp()));
Value* value = computeExp(ctx->exp());
ArrayValueTree* result = new ArrayValueTree();
result->setValue(value);
return result;
@@ -408,13 +844,17 @@ std::any SysYIRGenerator::visitArrayInitValue(SysYParser::ArrayInitValueContext
return result;
}
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
std::any SysYIRGenerator::visitConstScalarInitValue(SysYParser::ConstScalarInitValueContext *ctx) {
Value* value = std::any_cast<Value *>(visitConstExp(ctx->constExp()));
ArrayValueTree* result = new ArrayValueTree();
result->setValue(value);
return result;
}
std::any SysYIRGenerator::visitConstExp(SysYParser::ConstExpContext *ctx){
return computeAddExp(ctx->addExp());
}
std::any SysYIRGenerator::visitConstArrayInitValue(SysYParser::ConstArrayInitValueContext *ctx) {
std::vector<ArrayValueTree *> children;
for (const auto &constInitVal : ctx->constInitVal())
@@ -570,8 +1010,8 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
vector<Value *> indices;
if (lVal->exp().size() > 0) {
// 如果有下标,访问表达式获取下标值
for (const auto &exp : lVal->exp()) {
Value* indexValue = std::any_cast<Value *>(visitExp(exp));
for (auto &exp : lVal->exp()) {
Value* indexValue = std::any_cast<Value *>(computeExp(exp));
indices.push_back(indexValue);
}
}
@@ -610,15 +1050,18 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
LValue = getGEPAddressInst(gepBasePointer, gepIndices);
}
Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
// Value* RValue = std::any_cast<Value *>(visitExp(ctx->exp())); // 右值
// 先推断 LValue 的类型
// 如果 LValue 是指向数组的指针,则需要根据 indices 获取正确的类型
// 如果 LValue 是标量,则直接使用其类型
// 注意LValue 的类型可能是指向数组的指针 (e.g., int(*)[3]) 或者指向标量的指针 (e.g., int*) 也能推断
Type* LType = builder.getIndexedType(variable->getType(), indices);
Value* RValue = computeExp(ctx->exp(), LType); // 右值计算
Type* RType = RValue->getType();
// TODO:computeExp处理了类型转换可以考虑删除判断逻辑
if (LType != RType) {
ConstantValue *constValue = dynamic_cast<ConstantValue *>(RValue);
if (constValue != nullptr) {
@@ -642,7 +1085,7 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
}
} else {
if (LType == Type::getFloatType()) {
RValue = builder.createIToFInst(RValue);
RValue = builder.createItoFInst(RValue);
} else { // 假设如果不是浮点型,就是整型
RValue = builder.createFtoIInst(RValue);
}
@@ -655,6 +1098,14 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
}
std::any SysYIRGenerator::visitExpStmt(SysYParser::ExpStmtContext *ctx) {
// 访问表达式
if (ctx->exp() != nullptr) {
computeExp(ctx->exp());
}
return std::any();
}
std::any SysYIRGenerator::visitIfStmt(SysYParser::IfStmtContext *ctx) {
// labels string stream
@@ -822,11 +1273,11 @@ std::any SysYIRGenerator::visitContinueStmt(SysYParser::ContinueStmtContext *ctx
std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
Value* returnValue = nullptr;
if (ctx->exp() != nullptr) {
returnValue = std::any_cast<Value *>(visitExp(ctx->exp()));
}
Type* funcType = builder.getBasicBlock()->getParent()->getReturnType();
if (ctx->exp() != nullptr) {
returnValue = computeExp(ctx->exp(), funcType);
}
// TODOL 考虑删除类型转换判断逻辑
if (returnValue != nullptr && funcType!= returnValue->getType()) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(returnValue);
if (constValue != nullptr) {
@@ -849,7 +1300,7 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
}
} else {
if (funcType == Type::getFloatType()) {
returnValue = builder.createIToFInst(returnValue);
returnValue = builder.createItoFInst(returnValue);
} else {
returnValue = builder.createFtoIInst(returnValue);
}
@@ -891,7 +1342,8 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
std::vector<Value *> dims;
for (const auto &exp : ctx->exp()) {
dims.push_back(std::any_cast<Value *>(visitExp(exp)));
Value* expValue = std::any_cast<Value *>(computeExp(exp));
dims.push_back(expValue);
}
// 1. 获取变量的声明维度数量
@@ -995,16 +1447,23 @@ std::any SysYIRGenerator::visitLValue(SysYParser::LValueContext *ctx) {
}
std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
if (ctx->exp() != nullptr)
return visitExp(ctx->exp());
if (ctx->lValue() != nullptr)
return visitLValue(ctx->lValue());
if (ctx->number() != nullptr)
return visitNumber(ctx->number());
if (ctx->exp() != nullptr) {
BinaryExpStack.push_back(BinaryOp::LPAREN);BinaryExpLenStack.back()++;
visitExp(ctx->exp());
BinaryExpStack.push_back(BinaryOp::RPAREN);BinaryExpLenStack.back()++;
}
if (ctx->lValue() != nullptr) {
// 如果是 lValue将value压入栈中
BinaryExpStack.push_back(std::any_cast<Value *>(visitLValue(ctx->lValue())));BinaryExpLenStack.back()++;
}
if (ctx->number() != nullptr) {
BinaryExpStack.push_back(std::any_cast<Value *>(visitNumber(ctx->number())));BinaryExpLenStack.back()++;
}
if (ctx->string() != nullptr) {
cout << "String literal not supported in SysYIRGenerator." << endl;
}
return visitNumber(ctx->number());
return std::any();
}
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
@@ -1074,7 +1533,7 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
if (formalParamExpectedValueType->isInt() && actualArgType->isFloat()) {
args[i] = builder.createFtoIInst(args[i]);
} else if (formalParamExpectedValueType->isFloat() && actualArgType->isInt()) {
args[i] = builder.createIToFInst(args[i]);
args[i] = builder.createItoFInst(args[i]);
}
// 2. 指针类型转换 (例如数组退化:`[N x T]*` 到 `T*`,或兼容指针类型之间) TODO不清楚有没有这种样例
// 这种情况常见于数组参数,实参可能是一个更具体的数组指针类型,
@@ -1099,235 +1558,78 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
}
std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
if (ctx->primaryExp() != nullptr)
return visitPrimaryExp(ctx->primaryExp());
if (ctx->call() != nullptr)
return visitCall(ctx->call());
Value* value = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp()));
Value* result = value;
if (ctx->unaryOp()->SUB() != nullptr) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) {
if (constValue->isFloat()) {
result = ConstantFloating::get(-constValue->getFloat());
} else {
result = ConstantInteger::get(-constValue->getInt());
if (ctx->primaryExp() != nullptr) {
visitPrimaryExp(ctx->primaryExp());
} else if (ctx->call() != nullptr) {
BinaryExpStack.push_back(std::any_cast<Value *>(visitCall(ctx->call())));BinaryExpLenStack.back()++;
} else if (ctx->unaryOp() != nullptr) {
// 遇到一元操作符,将其压入 BinaryExpStack
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->unaryOp()->children[0]);
int opType = opNode->getSymbol()->getType();
switch(opType) {
case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::PLUS); BinaryExpLenStack.back()++; break;
case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::NEG); BinaryExpLenStack.back()++; break;
case SysYParser::NOT: BinaryExpStack.push_back(BinaryOp::NOT); BinaryExpLenStack.back()++; break;
default: assert(false && "Unexpected operator in UnaryExp.");
}
} else if (value != nullptr) {
if (value->getType() == Type::getIntType()) {
result = builder.createNegInst(value);
} else {
result = builder.createFNegInst(value);
}
} else {
std::cout << "UnExp: value is nullptr." << std::endl;
assert(false);
}
} else if (ctx->unaryOp()->NOT() != nullptr) {
auto constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) {
if (constValue->isFloat()) {
result =
ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0));
} else {
result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0));
}
} else if (value != nullptr) {
if (value->getType() == Type::getIntType()) {
result = builder.createNotInst(value);
} else {
result = builder.createFNotInst(value);
}
} else {
std::cout << "UnExp: value is nullptr." << std::endl;
assert(false);
}
visitUnaryExp(ctx->unaryExp());
}
return result;
return std::any();
}
std::any SysYIRGenerator::visitFuncRParams(SysYParser::FuncRParamsContext *ctx) {
std::vector<Value *> params;
for (const auto &exp : ctx->exp())
params.push_back(std::any_cast<Value *>(visitExp(exp)));
for (const auto &exp : ctx->exp()) {
auto param = std::any_cast<Value *>(computeExp(exp));
params.push_back(param);
}
return params;
}
std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
Value * result = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(0)));
visitUnaryExp(ctx->unaryExp(0));
for (int i = 1; i < ctx->unaryExp().size(); i++) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
int opType = opNode->getSymbol()->getType();
Value* operand = std::any_cast<Value *>(visitUnaryExp(ctx->unaryExp(i)));
Type* resultType = result->getType();
Type* operandType = operand->getType();
Type* floatType = Type::getFloatType();
if (resultType == floatType || operandType == floatType) {
// 如果有一个操作数是浮点数,则将两个操作数都转换为浮点数
if (operandType != floatType) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(operand);
if (constValue != nullptr) {
if(dynamic_cast<ConstantInteger *>(constValue)) {
// 如果是整型常量,转换为浮点型
operand = ConstantFloating::get(static_cast<float>(constValue->getInt()));
} else if (dynamic_cast<ConstantFloating *>(constValue)) {
// 如果是浮点型常量,直接使用
operand = ConstantFloating::get(static_cast<float>(constValue->getFloat()));
}
}
else
operand = builder.createIToFInst(operand);
} else if (resultType != floatType) {
ConstantValue* constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr) {
if(dynamic_cast<ConstantInteger *>(constResult)) {
// 如果是整型常量,转换为浮点型
result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
} else if (dynamic_cast<ConstantFloating *>(constResult)) {
// 如果是浮点型常量,直接使用
result = ConstantFloating::get(static_cast<float>(constResult->getFloat()));
}
}
else
result = builder.createIToFInst(result);
}
ConstantFloating* constResult = dynamic_cast<ConstantFloating *>(result);
ConstantFloating* constOperand = dynamic_cast<ConstantFloating *>(operand);
if (opType == SysYParser::MUL) {
if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantFloating::get(constResult->getFloat() *
constOperand->getFloat());
} else {
result = builder.createFMulInst(result, operand);
}
} else if (opType == SysYParser::DIV) {
if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantFloating::get(constResult->getFloat() /
constOperand->getFloat());
} else {
result = builder.createFDivInst(result, operand);
}
} else {
// float类型的取模操作不允许
std::cout << "MulExp: float type mod operation is not allowed." << std::endl;
assert(false);
}
} else {
ConstantInteger *constResult = dynamic_cast<ConstantInteger *>(result);
ConstantInteger *constOperand = dynamic_cast<ConstantInteger *>(operand);
if (opType == SysYParser::MUL) {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantInteger::get(constResult->getInt() * constOperand->getInt());
else
result = builder.createMulInst(result, operand);
} else if (opType == SysYParser::DIV) {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantInteger::get(constResult->getInt() / constOperand->getInt());
else
result = builder.createDivInst(result, operand);
} else {
if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantInteger::get(constResult->getInt() % constOperand->getInt());
else
result = builder.createRemInst(result, operand);
}
switch(opType) {
case SysYParser::MUL: BinaryExpStack.push_back(BinaryOp::MUL); BinaryExpLenStack.back()++; break;
case SysYParser::DIV: BinaryExpStack.push_back(BinaryOp::DIV); BinaryExpLenStack.back()++; break;
case SysYParser::MOD: BinaryExpStack.push_back(BinaryOp::MOD); BinaryExpLenStack.back()++; break;
default: assert(false && "Unexpected operator in MulExp.");
}
visitUnaryExp(ctx->unaryExp(i));
}
return result;
return std::any();
}
std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
Value* result = std::any_cast<Value *>(visitMulExp(ctx->mulExp(0)));
visitMulExp(ctx->mulExp(0));
for (int i = 1; i < ctx->mulExp().size(); i++) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
int opType = opNode->getSymbol()->getType();
Value* operand = std::any_cast<Value *>(visitMulExp(ctx->mulExp(i)));
Type* resultType = result->getType();
Type* operandType = operand->getType();
Type* floatType = Type::getFloatType();
if (resultType == floatType || operandType == floatType) {
// 类型转换
if (operandType != floatType) {
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (constOperand != nullptr) {
if(dynamic_cast<ConstantInteger *>(constOperand)) {
// 如果是整型常量,转换为浮点型
operand = ConstantFloating::get(static_cast<float>(constOperand->getInt()));
} else if (dynamic_cast<ConstantFloating *>(constOperand)) {
// 如果是浮点型常量,直接使用
operand = ConstantFloating::get(static_cast<float>(constOperand->getFloat()));
}
}
else
operand = builder.createIToFInst(operand);
} else if (resultType != floatType) {
ConstantValue * constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr) {
if(dynamic_cast<ConstantInteger *>(constResult)) {
// 如果是整型常量,转换为浮点型
result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
} else if (dynamic_cast<ConstantFloating *>(constResult)) {
// 如果是浮点型常量,直接使用
result = ConstantFloating::get(static_cast<float>(constResult->getFloat()));
}
}
else
result = builder.createIToFInst(result);
}
ConstantFloating *constResult = dynamic_cast<ConstantFloating *>(result);
ConstantFloating *constOperand = dynamic_cast<ConstantFloating *>(operand);
if (opType == SysYParser::ADD) {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat());
else
result = builder.createFAddInst(result, operand);
} else {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat());
else
result = builder.createFSubInst(result, operand);
}
} else {
ConstantInteger *constResult = dynamic_cast<ConstantInteger *>(result);
ConstantInteger *constOperand = dynamic_cast<ConstantInteger *>(operand);
if (opType == SysYParser::ADD) {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantInteger::get(constResult->getInt() + constOperand->getInt());
else
result = builder.createAddInst(result, operand);
} else {
if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantInteger::get(constResult->getInt() - constOperand->getInt());
else
result = builder.createSubInst(result, operand);
}
switch(opType) {
case SysYParser::ADD: BinaryExpStack.push_back(BinaryOp::ADD); BinaryExpLenStack.back()++; break;
case SysYParser::SUB: BinaryExpStack.push_back(BinaryOp::SUB); BinaryExpLenStack.back()++; break;
default: assert(false && "Unexpected operator in AddExp.");
}
visitMulExp(ctx->mulExp(i));
}
return result;
return std::any();
}
std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
Value* result = std::any_cast<Value *>(visitAddExp(ctx->addExp(0)));
Value* result = computeAddExp(ctx->addExp(0));
for (int i = 1; i < ctx->addExp().size(); i++) {
auto opNode = dynamic_cast<antlr4::tree::TerminalNode*>(ctx->children[2*i-1]);
int opType = opNode->getSymbol()->getType();
Value* operand = std::any_cast<Value *>(visitAddExp(ctx->addExp(i)));
Value* operand = computeAddExp(ctx->addExp(i));
Type* resultType = result->getType();
Type* operandType = operand->getType();
@@ -1366,7 +1668,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
}
}
else
result = builder.createIToFInst(result);
result = builder.createItoFInst(result);
}
if (operandType != floatType) {
@@ -1380,7 +1682,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
}
}
else
operand = builder.createIToFInst(operand);
operand = builder.createItoFInst(operand);
}
@@ -1407,6 +1709,7 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
// TODO其实已经保证了result是一个int类型的值可以删除冗余判断逻辑
Value * result = std::any_cast<Value *>(visitRelExp(ctx->relExp(0)));
for (int i = 1; i < ctx->relExp().size(); i++) {
@@ -1445,7 +1748,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
}
}
else
result = builder.createIToFInst(result);
result = builder.createItoFInst(result);
}
if (operandType != floatType) {
if (constOperand != nullptr) {
@@ -1458,7 +1761,7 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
}
}
else
operand = builder.createIToFInst(operand);
operand = builder.createItoFInst(operand);
}
if (opType == SysYParser::EQ) result = builder.createFCmpEQInst(result, operand);
@@ -1567,7 +1870,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root,
assert(false && "Unknown constant type for float conversion.");
}
else
result.push_back(builder->createIToFInst(value));
result.push_back(builder->createItoFInst(value));
} else {
ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);

View File

@@ -1,7 +1,10 @@
#include "SysYIRPrinter.h"
#include <cassert>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#include "IR.h" // 确保IR.h包含了ArrayType、GetElementPtrInst等的定义
@@ -61,16 +64,21 @@ std::string SysYPrinter::getValueName(Value *value) {
} else if (auto constInt = dynamic_cast<ConstantInteger*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constInt->getInt());
} else if (auto constFloat = dynamic_cast<ConstantFloating*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constFloat->getFloat());
std::ostringstream oss;
oss << std::scientific << std::setprecision(std::numeric_limits<float>::max_digits10) << constFloat->getFloat();
return oss.str();
} else if (auto constUndef = dynamic_cast<UndefinedValue*>(value)) { // 如果有Undef类型
return "undef";
} else if (auto constVal = dynamic_cast<ConstantValue*>(value)) { // fallback for generic ConstantValue
// 这里的逻辑可能需要根据你ConstantValue的实际设计调整
// 确保它能处理所有可能的ConstantValue
if (constVal->getType()->isFloat()) {
return std::to_string(constVal->getFloat());
if (auto constInt = dynamic_cast<ConstantInteger*>(value)) { // 优先匹配具体的常量类型
return std::to_string(constInt->getInt());
} else if (auto constFloat = dynamic_cast<ConstantFloating*>(value)) { // 优先匹配具体的常量类型
std::ostringstream oss;
oss << std::scientific << std::setprecision(std::numeric_limits<float>::max_digits10) << constFloat->getFloat();
return oss.str();
}
return std::to_string(constVal->getInt());
} else if (auto constVar = dynamic_cast<ConstantVariable*>(value)) {
return constVar->getName(); // 假设ConstantVariable有自己的名字或通过getByIndices获取值
} else if (auto argVar = dynamic_cast<Argument*>(value)) {