Compare commits

...

14 Commits

Author SHA1 Message Date
Lixuanwang
b2b88ee511 [backend-beta] saving for simpler implementation for register allocation 2025-06-24 05:02:11 +08:00
Lixuanwang
395e6e4003 [backend] fixed many bugs 2025-06-24 03:23:45 +08:00
Lixuanwang
20cc08708a [backend] introduced debug option 2025-06-24 02:56:17 +08:00
Lixuanwang
942cb32976 [backend] fixed bugs 2025-06-24 00:42:14 +08:00
Lixuanwang
ac7569d890 Merge branch 'IROptPre' into backend 2025-06-24 00:40:36 +08:00
Lixuanwang
11cd32e6df [backend] fixed some bugs 2025-06-24 00:35:38 +08:00
Lixuanwang
617244fae7 [backend] switch to simpler implementation for inst selection 2025-06-24 00:30:33 +08:00
Lixuanwang
3c3f48ee87 [backend] fixed 1 segmentation fault 2025-06-23 22:38:29 +08:00
Lixuanwang
ab3eb253f9 [backend] debugging segmentation fault caused by branch instr 2025-06-23 17:02:29 +08:00
Lixuanwang
7d37bd7528 [backend] introduced DAG, GraphAlloc 2025-06-23 15:38:01 +08:00
Lixuanwang
af00612376 [backend] supported if 2025-06-23 06:16:19 +08:00
ladev789
10e1476ba1 [backend] test01 passed 2025-06-22 20:05:34 +08:00
ladev789
b94e87637a Merge remote-tracking branch 'origin/IRPrinter' into backend 2025-06-22 20:00:29 +08:00
ladev789
88a561177d [backend] incorrect asm output 2025-06-22 20:00:03 +08:00
6 changed files with 667 additions and 165 deletions

View File

@@ -1,30 +1,10 @@
#include "RISCv32Backend.h" #include "RISCv32Backend.h"
#include <sstream> #include <sstream>
#include <algorithm> #include <algorithm>
#include <stack>
namespace sysy { namespace sysy {
const std::vector<RISCv32CodeGen::PhysicalReg> RISCv32CodeGen::allocable_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7
};
std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) {
switch (reg) {
case PhysicalReg::T0: return "t0"; case PhysicalReg::T1: return "t1";
case PhysicalReg::T2: return "t2"; case PhysicalReg::T3: return "t3";
case PhysicalReg::T4: return "t4"; case PhysicalReg::T5: return "t5";
case PhysicalReg::T6: return "t6"; case PhysicalReg::A0: return "a0";
case PhysicalReg::A1: return "a1"; case PhysicalReg::A2: return "a2";
case PhysicalReg::A3: return "a3"; case PhysicalReg::A4: return "a4";
case PhysicalReg::A5: return "a5"; case PhysicalReg::A6: return "a6";
case PhysicalReg::A7: return "a7";
default: return "";
}
}
std::string RISCv32CodeGen::code_gen() { std::string RISCv32CodeGen::code_gen() {
std::stringstream ss; std::stringstream ss;
ss << ".text\n"; ss << ".text\n";
@@ -34,16 +14,22 @@ std::string RISCv32CodeGen::code_gen() {
std::string RISCv32CodeGen::module_gen() { std::string RISCv32CodeGen::module_gen() {
std::stringstream ss; std::stringstream ss;
// 生成全局变量(数据段) for (auto& global : module->getGlobals()) {
for (const auto& global : module->getGlobals()) { ss << ".global " << global->getName() << "\n";
ss << ".data\n"; ss << ".section .data\n";
ss << ".globl " << global->getName() << "\n"; ss << ".align 2\n";
ss << global->getName() << ":\n"; ss << global->getName() << ":\n";
ss << " .word 0\n"; // 假设初始化为0 for (auto value : global->getInitValues().getValues()) {
auto const_val = dynamic_cast<ConstantValue*>(value);
if (const_val->isInt()) {
ss << ".word " << const_val->getInt() << "\n";
} else {
ss << ".float " << const_val->getFloat() << "\n";
}
}
} }
// 生成函数(文本段) ss << ".section .text\n";
ss << ".text\n"; for (auto& func : module->getFunctions()) {
for (const auto& func : module->getFunctions()) {
ss << function_gen(func.second.get()); ss << function_gen(func.second.get());
} }
return ss.str(); return ss.str();
@@ -51,107 +37,582 @@ std::string RISCv32CodeGen::module_gen() {
std::string RISCv32CodeGen::function_gen(Function* func) { std::string RISCv32CodeGen::function_gen(Function* func) {
std::stringstream ss; std::stringstream ss;
// 函数标签 ss << ".global " << func->getName() << "\n";
ss << ".globl " << func->getName() << "\n"; ss << ".type " << func->getName() << ", @function\n";
ss << func->getName() << ":\n"; ss << func->getName() << ":\n";
// 序言:保存 ra分配堆栈
bool is_leaf = true; // 简化假设 // Perform register allocation
ss << " addi sp, sp, -16\n"; auto live_sets = liveness_analysis(func);
ss << " sw ra, 12(sp)\n"; auto interference_graph = build_interference_graph(live_sets);
// 寄存器分配 auto alloc = color_graph(func, interference_graph);
auto alloc = register_allocation(func);
// 生成基本块代码 // Prologue: Adjust stack and save callee-saved registers
for (const auto& bb : func->getBasicBlocks()) { if (alloc.stack_size > 0) {
ss << basicBlock_gen(bb.get(), alloc); ss << " addi sp, sp, -" << alloc.stack_size << "\n";
ss << " sw ra, " << (alloc.stack_size - 4) << "(sp)\n";
}
for (auto preg : callee_saved) {
if (std::find_if(alloc.vreg_to_preg.begin(), alloc.vreg_to_preg.end(),
[preg](const auto& pair) { return pair.second == preg; }) != alloc.vreg_to_preg.end()) {
ss << " sw " << get_preg_str(preg) << ", " << (alloc.stack_size - 8) << "(sp)\n";
}
}
int block_idx = 0;
for (auto& bb : func->getBasicBlocks()) {
ss << basicBlock_gen(bb.get(), alloc, block_idx++);
}
// Epilogue: Restore callee-saved registers and stack
for (auto preg : callee_saved) {
if (std::find_if(alloc.vreg_to_preg.begin(), alloc.vreg_to_preg.end(),
[preg](const auto& pair) { return pair.second == preg; }) != alloc.vreg_to_preg.end()) {
ss << " lw " << get_preg_str(preg) << ", " << (alloc.stack_size - 8) << "(sp)\n";
}
}
if (alloc.stack_size > 0) {
ss << " lw ra, " << (alloc.stack_size - 4) << "(sp)\n";
ss << " addi sp, sp, " << alloc.stack_size << "\n";
} }
// 结尾:恢复 ra释放堆栈
ss << " lw ra, 12(sp)\n";
ss << " addi sp, sp, 16\n";
ss << " ret\n"; ss << " ret\n";
return ss.str(); return ss.str();
} }
std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) { std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx) {
std::stringstream ss; std::stringstream ss;
ss << bb->getName() << ":\n"; ss << ".L" << block_idx << ":\n";
for (const auto& inst : bb->getInstructions()) { auto dag_nodes = build_dag(bb);
auto riscv_insts = instruction_gen(inst.get()); for (auto& node : dag_nodes) {
for (const auto& riscv_inst : riscv_insts) { select_instructions(node.get(), alloc);
ss << " " << riscv_inst.opcode; }
for (size_t i = 0; i < riscv_inst.operands.size(); ++i) { std::set<DAGNode*> emitted_nodes;
if (i > 0) ss << ", "; for (auto& node : dag_nodes) {
if (riscv_inst.operands[i].kind == Operand::Kind::Reg) { emit_instructions(node.get(), ss, alloc, emitted_nodes);
auto it = alloc.reg_map.find(riscv_inst.operands[i].value);
if (it != alloc.reg_map.end()) {
ss << reg_to_string(it->second);
} else {
auto stack_it = alloc.stack_map.find(riscv_inst.operands[i].value);
if (stack_it != alloc.stack_map.end()) {
ss << stack_it->second << "(sp)";
} else {
ss << "%" << riscv_inst.operands[i].value->getName();
}
}
} else if (riscv_inst.operands[i].kind == Operand::Kind::Imm) {
ss << riscv_inst.operands[i].label;
} else {
ss << riscv_inst.operands[i].label;
}
}
ss << "\n";
}
} }
return ss.str(); return ss.str();
} }
std::vector<RISCv32CodeGen::RISCv32Inst> RISCv32CodeGen::instruction_gen(Instruction* inst) { std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(BasicBlock* bb) {
std::vector<RISCv32Inst> insts; std::vector<std::unique_ptr<DAGNode>> nodes;
if (auto bin = dynamic_cast<BinaryInst*>(inst)) { std::map<Value*, DAGNode*> value_to_node;
std::string opcode; int vreg_counter = 0;
if (bin->getKind() == BinaryInst::kAdd) opcode = "add";
else if (bin->getKind() == BinaryInst::kSub) opcode = "sub"; for (auto& inst : bb->getInstructions()) {
else if (bin->getKind() == BinaryInst::kMul) opcode = "mul"; if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
else return insts; // 其他操作未实现 auto node = std::make_unique<DAGNode>(DAGNode::ALLOCA_ADDR);
insts.emplace_back(opcode, std::vector<Operand>{ node->value = alloca;
{Operand::Kind::Reg, bin}, node->result_vreg = "%" + inst->getName(); // Use IR name (%a(0), %b(0))
{Operand::Kind::Reg, bin->getLhs()}, value_to_node[alloca] = node.get();
{Operand::Kind::Reg, bin->getRhs()} nodes.push_back(std::move(node));
}); } else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
} else if (auto load = dynamic_cast<LoadInst*>(inst)) { auto node = std::make_unique<DAGNode>(DAGNode::LOAD);
insts.emplace_back("lw", std::vector<Operand>{ node->value = load;
{Operand::Kind::Reg, load}, node->result_vreg = "%" + inst->getName(); // Use IR name (%0, %1)
{Operand::Kind::Label, load->getPointer()->getName()} auto pointer = load->getPointer();
}); if (value_to_node.count(pointer)) {
} else if (auto store = dynamic_cast<StoreInst*>(inst)) { node->operands.push_back(value_to_node[pointer]);
insts.emplace_back("sw", std::vector<Operand>{ value_to_node[pointer]->users.push_back(node.get());
{Operand::Kind::Reg, store->getValue()}, }
{Operand::Kind::Label, store->getPointer()->getName()} value_to_node[load] = node.get();
}); nodes.push_back(std::move(node));
} else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
auto node = std::make_unique<DAGNode>(DAGNode::STORE);
node->value = store;
auto value_operand = store->getValue();
auto pointer = store->getPointer();
if (value_to_node.count(value_operand)) {
node->operands.push_back(value_to_node[value_operand]);
value_to_node[value_operand]->users.push_back(node.get());
} else if (auto const_val = dynamic_cast<ConstantValue*>(value_operand)) {
auto const_node = std::make_unique<DAGNode>(DAGNode::CONSTANT);
const_node->value = const_val;
const_node->result_vreg = "%" + std::to_string(vreg_counter++); // Use simple %N for constants
value_to_node[value_operand] = const_node.get();
node->operands.push_back(const_node.get());
const_node->users.push_back(node.get());
nodes.push_back(std::move(const_node));
}
if (value_to_node.count(pointer)) {
node->operands.push_back(value_to_node[pointer]);
value_to_node[pointer]->users.push_back(node.get());
}
nodes.push_back(std::move(node));
} else if (auto binary = dynamic_cast<BinaryInst*>(inst.get())) {
auto node = std::make_unique<DAGNode>(DAGNode::BINARY);
node->value = binary;
node->result_vreg = "%" + inst->getName(); // Use IR name (%2)
for (auto operand : binary->getOperands()) {
auto op_value = operand->getValue();
if (value_to_node.count(op_value)) {
node->operands.push_back(value_to_node[op_value]);
value_to_node[op_value]->users.push_back(node.get());
} else if (auto const_val = dynamic_cast<ConstantValue*>(op_value)) {
auto const_node = std::make_unique<DAGNode>(DAGNode::CONSTANT);
const_node->value = const_val;
const_node->result_vreg = "%" + std::to_string(vreg_counter++);
value_to_node[op_value] = const_node.get();
node->operands.push_back(const_node.get());
const_node->users.push_back(node.get());
nodes.push_back(std::move(const_node));
}
}
value_to_node[binary] = node.get();
nodes.push_back(std::move(node));
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
auto node = std::make_unique<DAGNode>(DAGNode::RETURN);
node->value = ret;
if (ret->hasReturnValue()) {
auto value_operand = ret->getReturnValue();
if (value_to_node.count(value_operand)) {
node->operands.push_back(value_to_node[value_operand]);
value_to_node[value_operand]->users.push_back(node.get());
} else if (auto const_val = dynamic_cast<ConstantValue*>(value_operand)) {
auto const_node = std::make_unique<DAGNode>(DAGNode::CONSTANT);
const_node->value = const_val;
const_node->result_vreg = "%" + std::to_string(vreg_counter++);
value_to_node[value_operand] = const_node.get();
node->operands.push_back(const_node.get());
const_node->users.push_back(node.get());
nodes.push_back(std::move(const_node));
}
}
nodes.push_back(std::move(node));
} else if (auto cond_br = dynamic_cast<CondBrInst*>(inst.get())) {
auto node = std::make_unique<DAGNode>(DAGNode::BRANCH);
node->value = cond_br;
auto condition = cond_br->getCondition();
if (value_to_node.count(condition)) {
node->operands.push_back(value_to_node[condition]);
value_to_node[condition]->users.push_back(node.get());
} else if (auto const_val = dynamic_cast<ConstantValue*>(condition)) {
auto const_node = std::make_unique<DAGNode>(DAGNode::CONSTANT);
const_node->value = const_val;
const_node->result_vreg = "%" + std::to_string(vreg_counter++);
value_to_node[condition] = const_node.get();
node->operands.push_back(const_node.get());
const_node->users.push_back(node.get());
nodes.push_back(std::move(const_node));
}
nodes.push_back(std::move(node));
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(inst.get())) {
auto node = std::make_unique<DAGNode>(DAGNode::BRANCH);
node->value = uncond_br;
nodes.push_back(std::move(node));
}
} }
return insts; return nodes;
} }
void RISCv32CodeGen::eliminate_phi(Function* func) { void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) {
// TODO: 实现 phi 指令消除 if (node->inst.empty()) {
switch (node->kind) {
case DAGNode::CONSTANT: {
auto const_val = dynamic_cast<ConstantValue*>(node->value);
if (const_val->isInt()) {
node->inst = "li " + node->result_vreg + ", " + std::to_string(const_val->getInt());
} else {
node->inst = "# float constant not implemented";
}
break;
}
case DAGNode::LOAD: {
auto load = dynamic_cast<LoadInst*>(node->value);
auto pointer = load->getPointer();
if (auto alloca = dynamic_cast<AllocaInst*>(pointer)) {
if (alloc.stack_map.count(alloca)) {
node->inst = "lw " + node->result_vreg + ", " + std::to_string(alloc.stack_map.at(alloca)) + "(sp)";
}
} else if (auto global = dynamic_cast<GlobalValue*>(pointer)) {
node->inst = "lw " + node->result_vreg + ", " + global->getName() + "(gp)";
}
break;
}
case DAGNode::STORE: {
auto store = dynamic_cast<StoreInst*>(node->value);
auto pointer = store->getPointer();
auto value_vreg = node->operands[0]->result_vreg;
if (auto alloca = dynamic_cast<AllocaInst*>(pointer)) {
if (alloc.stack_map.count(alloca)) {
node->inst = "sw " + value_vreg + ", " + std::to_string(alloc.stack_map.at(alloca)) + "(sp)";
}
} else if (auto global = dynamic_cast<GlobalValue*>(pointer)) {
node->inst = "sw " + value_vreg + ", " + global->getName() + "(gp)";
}
break;
}
case DAGNode::BINARY: {
auto binary = dynamic_cast<BinaryInst*>(node->value);
auto lhs_vreg = node->operands[0]->result_vreg;
auto rhs_vreg = node->operands[1]->result_vreg;
std::string op;
switch (binary->getKind()) {
case Instruction::kAdd: op = "add"; break;
case Instruction::kSub: op = "sub"; break;
case Instruction::kMul: op = "mul"; break;
case Instruction::kDiv: op = "div"; break;
case Instruction::kICmpEQ: op = "seq"; break;
case Instruction::kICmpNE: op = "sne"; break;
case Instruction::kICmpLT: op = "slt"; break;
case Instruction::kICmpGT: op = "sgt"; break;
case Instruction::kICmpLE: op = "sle"; break;
case Instruction::kICmpGE: op = "sge"; break;
default: op = "# unknown"; break;
}
node->inst = op + " " + node->result_vreg + ", " + lhs_vreg + ", " + rhs_vreg;
break;
}
case DAGNode::RETURN: {
auto ret = dynamic_cast<ReturnInst*>(node->value);
if (ret->hasReturnValue()) {
auto value_vreg = node->operands[0]->result_vreg;
node->inst = "mv a0, " + value_vreg;
} else {
node->inst = "ret";
}
break;
}
case DAGNode::BRANCH: {
if (auto cond_br = dynamic_cast<CondBrInst*>(node->value)) {
auto condition_vreg = node->operands[0]->result_vreg;
auto then_block = cond_br->getThenBlock();
auto else_block = cond_br->getElseBlock();
int then_idx = 0, else_idx = 0;
int idx = 0;
for (auto& bb : cond_br->getFunction()->getBasicBlocks()) {
if (bb.get() == then_block) then_idx = idx;
if (bb.get() == else_block) else_idx = idx;
idx++;
}
node->inst = "bne " + condition_vreg + ", zero, .L" + std::to_string(then_idx) + "\n j .L" + std::to_string(else_idx);
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(node->value)) {
auto target_block = uncond_br->getBlock();
int target_idx = 0;
int idx = 0;
for (auto& bb : uncond_br->getFunction()->getBasicBlocks()) {
if (bb.get() == target_block) target_idx = idx;
idx++;
}
node->inst = "j .L" + std::to_string(target_idx);
}
break;
}
default:
node->inst = "# unimplemented";
break;
}
}
} }
std::map<Instruction*, std::set<Value*>> RISCv32CodeGen::liveness_analysis(Function* func) { void RISCv32CodeGen::emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes) {
std::map<Instruction*, std::set<Value*>> live_sets; if (emitted_nodes.count(node)) return;
// TODO: 实现活跃性分析 for (auto operand : node->operands) {
return live_sets; emit_instructions(operand, ss, alloc, emitted_nodes);
}
if (!node->inst.empty() && node->inst != "# unimplemented" && node->inst.find("# alloca") == std::string::npos) {
std::string inst = node->inst;
std::vector<std::pair<std::string, std::string>> replacements;
// Collect replacements for result and operand virtual registers
if (node->result_vreg != "" && node->kind != DAGNode::ALLOCA_ADDR) {
if (alloc.vreg_to_preg.count(node->result_vreg)) {
replacements.emplace_back(node->result_vreg, get_preg_str(alloc.vreg_to_preg.at(node->result_vreg)));
} else if (alloc.spill_map.count(node->result_vreg)) {
auto temp_reg = PhysicalReg::T0;
replacements.emplace_back(node->result_vreg, get_preg_str(temp_reg));
inst = inst.substr(0, inst.find('\n')); // Handle multi-line instructions
ss << " " << inst << "\n";
ss << " sw " << get_preg_str(temp_reg) << ", " << alloc.spill_map.at(node->result_vreg) << "(sp)\n";
emitted_nodes.insert(node);
return;
} else {
ss << "# Error: Virtual register " << node->result_vreg << " not allocated (kind: " << node->getNodeKindString() << ")\n";
}
}
for (auto operand : node->operands) {
if (operand->result_vreg != "" && operand->kind != DAGNode::ALLOCA_ADDR) {
if (alloc.vreg_to_preg.count(operand->result_vreg)) {
replacements.emplace_back(operand->result_vreg, get_preg_str(alloc.vreg_to_preg.at(operand->result_vreg)));
} else if (alloc.spill_map.count(operand->result_vreg)) {
auto temp_reg = PhysicalReg::T1;
ss << " lw " << get_preg_str(temp_reg) << ", " << alloc.spill_map.at(operand->result_vreg) << "(sp)\n";
replacements.emplace_back(operand->result_vreg, get_preg_str(temp_reg));
} else {
ss << "# Error: Operand virtual register " << operand->result_vreg << " not allocated (kind: " << operand->getNodeKindString() << ")\n";
}
}
}
// Perform all replacements only if vreg exists in inst
for (const auto& [vreg, preg] : replacements) {
size_t pos = inst.find(vreg);
while (pos != std::string::npos) {
inst.replace(pos, vreg.length(), preg);
pos = inst.find(vreg, pos + preg.length());
}
}
// Emit the instruction
if (node->kind == DAGNode::BRANCH || inst.find('\n') != std::string::npos) {
ss << inst << "\n";
} else if (inst != "ret") {
ss << " " << inst << "\n";
}
}
emitted_nodes.insert(node);
} }
std::map<Value*, std::set<Value*>> RISCv32CodeGen::build_interference_graph( std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(Function* func) {
const std::map<Instruction*, std::set<Value*>>& live_sets) { std::map<Instruction*, std::set<std::string>> live_in, live_out;
std::map<Value*, std::set<Value*>> graph; bool changed;
// TODO: 实现干扰图构建
return graph; // Build DAG for all basic blocks
std::map<BasicBlock*, std::vector<std::unique_ptr<DAGNode>>> bb_dags;
for (auto& bb : func->getBasicBlocks()) {
bb_dags[bb.get()] = build_dag(bb.get());
}
// Initialize live_in and live_out
for (auto& bb : func->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
live_in[inst.get()];
live_out[inst.get()];
}
}
do {
changed = false;
for (auto& bb : func->getBasicBlocks()) {
// Reverse iterate for backward analysis
for (auto it = bb->getInstructions().rbegin(); it != bb->getInstructions().rend(); ++it) {
auto inst = it->get();
std::set<std::string> new_live_in, new_live_out;
// live_out = union of live_in of successors
for (auto succ : bb->getSuccessors()) {
if (!succ->getInstructions().empty()) {
auto succ_inst = succ->getInstructions().front().get();
new_live_out.insert(live_in[succ_inst].begin(), live_in[succ_inst].end());
}
}
// Collect def and use
std::set<std::string> def, use;
// IR instruction def
if (inst->getName() != "" && !dynamic_cast<AllocaInst*>(inst)) {
def.insert("%" + inst->getName());
}
// IR instruction use
for (auto operand : inst->getOperands()) {
auto value = operand->getValue();
if (auto op_inst = dynamic_cast<Instruction*>(value)) {
if (op_inst->getName() != "" && !dynamic_cast<AllocaInst*>(op_inst)) {
use.insert("%" + op_inst->getName());
}
}
}
// DAG node def and use
for (auto& node : bb_dags[bb.get()]) {
if (node->value == inst && node->kind != DAGNode::ALLOCA_ADDR) {
if (node->result_vreg != "") {
def.insert(node->result_vreg);
}
for (auto operand : node->operands) {
if (operand->result_vreg != "" && operand->kind != DAGNode::ALLOCA_ADDR) {
use.insert(operand->result_vreg);
}
}
}
// Constants
if (node->kind == DAGNode::CONSTANT) {
for (auto user : node->users) {
if (user->value == inst) {
use.insert(node->result_vreg);
}
}
}
}
// live_in = use U (live_out - def)
std::set<std::string> live_out_minus_def;
std::set_difference(new_live_out.begin(), new_live_out.end(),
def.begin(), def.end(),
std::inserter(live_out_minus_def, live_out_minus_def.begin()));
new_live_in.insert(use.begin(), use.end());
new_live_in.insert(live_out_minus_def.begin(), live_out_minus_def.end());
// Debug
std::cerr << "Instruction: " << (inst->getName() != "" ? "%" + inst->getName() : "none") << "\n";
std::cerr << " def: "; for (const auto& d : def) std::cerr << d << " "; std::cerr << "\n";
std::cerr << " use: "; for (const auto& u : use) std::cerr << u << " "; std::cerr << "\n";
std::cerr << " live_in: "; for (const auto& v : new_live_in) std::cerr << v << " "; std::cerr << "\n";
std::cerr << " live_out: "; for (const auto& v : new_live_out) std::cerr << v << " "; std::cerr << "\n";
if (live_in[inst] != new_live_in || live_out[inst] != new_live_out) {
live_in[inst] = new_live_in;
live_out[inst] = new_live_out;
changed = true;
}
}
}
} while (changed);
// Debug live_out
for (const auto& [inst, live_vars] : live_out) {
std::cerr << "Instruction: " << (inst->getName() != "" ? "%" + inst->getName() : "none") << " live_out: ";
for (const auto& var : live_vars) {
std::cerr << var << " ";
}
std::cerr << "\n";
}
return live_out;
} }
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) { std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_graph(
RegAllocResult result; const std::map<Instruction*, std::set<std::string>>& live_sets) {
// TODO: 实现寄存器分配 std::map<std::string, std::set<std::string>> interference_graph;
return result; for (const auto& [inst, live_vars] : live_sets) {
std::string def_var = inst->getName() != "" && !dynamic_cast<AllocaInst*>(inst) ? "%" + inst->getName() : "";
if (def_var != "") {
interference_graph[def_var]; // Initialize
for (const auto& live_var : live_vars) {
if (live_var != def_var && live_var.find("%a(") != 0 && live_var.find("%b(") != 0) {
interference_graph[def_var].insert(live_var);
interference_graph[live_var].insert(def_var);
}
}
}
// Initialize all live variables
for (const auto& live_var : live_vars) {
if (live_var.find("%a(") != 0 && live_var.find("%b(") != 0) {
interference_graph[live_var];
}
}
// Live variables interfere with each other
for (auto it1 = live_vars.begin(); it1 != live_vars.end(); ++it1) {
if (it1->find("%a(") == 0 || it1->find("%b(") == 0) continue;
for (auto it2 = std::next(it1); it2 != live_vars.end(); ++it2) {
if (it2->find("%a(") == 0 || it2->find("%b(") == 0) continue;
interference_graph[*it1].insert(*it2);
interference_graph[*it2].insert(*it1);
}
}
}
// Debug
for (const auto& [vreg, neighbors] : interference_graph) {
std::cerr << "Vreg " << vreg << " interferes with: ";
for (const auto& neighbor : neighbors) {
std::cerr << neighbor << " ";
}
std::cerr << "\n";
}
return interference_graph;
}
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::color_graph(Function* func, const std::map<std::string, std::set<std::string>>& interference_graph) {
RegAllocResult alloc;
std::map<std::string, std::set<std::string>> ig = interference_graph;
std::stack<std::string> stack;
std::set<std::string> spilled;
// Available physical registers
std::vector<PhysicalReg> available_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5,
PhysicalReg::S6, PhysicalReg::S7, PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11
};
// Simplify: Push nodes with degree < number of registers
while (!ig.empty()) {
bool simplified = false;
for (auto it = ig.begin(); it != ig.end();) {
if (it->second.size() < available_regs.size()) {
stack.push(it->first);
for (auto& [vreg, neighbors] : ig) {
neighbors.erase(it->first);
}
it = ig.erase(it);
simplified = true;
} else {
++it;
}
}
if (!simplified) {
// Spill the node with the highest degree
auto max_it = ig.begin();
for (auto it = ig.begin(); it != ig.end(); ++it) {
if (it->second.size() > max_it->second.size()) {
max_it = it;
}
}
spilled.insert(max_it->first);
for (auto& [vreg, neighbors] : ig) {
neighbors.erase(max_it->first);
}
ig.erase(max_it);
}
}
// Assign colors (physical registers)
while (!stack.empty()) {
auto vreg = stack.top();
stack.pop();
std::set<PhysicalReg> used_colors;
if (interference_graph.count(vreg)) {
for (const auto& neighbor : interference_graph.at(vreg)) {
if (alloc.vreg_to_preg.count(neighbor)) {
used_colors.insert(alloc.vreg_to_preg.at(neighbor));
}
}
}
bool assigned = false;
for (auto preg : available_regs) {
if (!used_colors.count(preg)) {
alloc.vreg_to_preg[vreg] = preg;
assigned = true;
break;
}
}
if (!assigned) {
spilled.insert(vreg);
}
}
// Allocate stack space for AllocaInst and spilled virtual registers
int stack_offset = 0;
for (auto& bb : func->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
alloc.stack_map[alloca] = stack_offset;
stack_offset += 4; // 4 bytes per variable
}
}
}
for (const auto& vreg : spilled) {
alloc.spill_map[vreg] = stack_offset;
stack_offset += 4;
}
alloc.stack_size = stack_offset + 8; // Extra space for ra and callee-saved
// Debug output to verify register allocation
for (const auto& [vreg, preg] : alloc.vreg_to_preg) {
std::cerr << "Vreg " << vreg << " assigned to " << get_preg_str(preg) << "\n";
}
for (const auto& vreg : spilled) {
std::cerr << "Vreg " << vreg << " spilled to stack offset " << alloc.spill_map.at(vreg) << "\n";
}
return alloc;
}
RISCv32CodeGen::PhysicalReg RISCv32CodeGen::get_preg_or_temp(const std::string& vreg, const RegAllocResult& alloc) const {
if (alloc.vreg_to_preg.count(vreg)) {
return alloc.vreg_to_preg.at(vreg);
}
return PhysicalReg::T0; // Fallback for spilled registers, handled in emit_instructions
} }
} // namespace sysy } // namespace sysy

View File

@@ -6,60 +6,96 @@
#include <vector> #include <vector>
#include <map> #include <map>
#include <set> #include <set>
#include <memory>
#include <iostream>
#include <functional>
#include <stack>
namespace sysy { namespace sysy {
class RISCv32CodeGen { class RISCv32CodeGen {
public: public:
explicit RISCv32CodeGen(Module* mod) : module(mod) {} enum class PhysicalReg {
std::string code_gen(); // 生成模块的汇编代码 ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6
};
struct DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR };
NodeKind kind;
Value* value = nullptr;
std::string inst;
std::string result_vreg;
std::vector<DAGNode*> operands;
std::vector<DAGNode*> users;
DAGNode(NodeKind k) : kind(k) {}
std::string getNodeKindString() const {
switch (kind) {
case CONSTANT: return "CONSTANT";
case LOAD: return "LOAD";
case STORE: return "STORE";
case BINARY: return "BINARY";
case CALL: return "CALL";
case RETURN: return "RETURN";
case BRANCH: return "BRANCH";
case ALLOCA_ADDR: return "ALLOCA_ADDR";
default: return "UNKNOWN";
}
}
};
struct RegAllocResult {
std::map<std::string, PhysicalReg> vreg_to_preg; // 虚拟寄存器到物理寄存器的映射
std::map<Value*, int> stack_map; // AllocaInst到栈偏移的映射
std::map<std::string, int> spill_map; // 溢出的虚拟寄存器到栈偏移的映射
int stack_size = 0; // 总栈帧大小
};
RISCv32CodeGen(Module* mod) : module(mod) {}
std::string code_gen();
std::string module_gen();
std::string function_gen(Function* func);
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx);
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
void select_instructions(DAGNode* node, const RegAllocResult& alloc);
void emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes);
std::map<Instruction*, std::set<std::string>> liveness_analysis(Function* func);
std::map<std::string, std::set<std::string>> build_interference_graph(
const std::map<Instruction*, std::set<std::string>>& live_sets);
RegAllocResult color_graph(Function* func, const std::map<std::string, std::set<std::string>>& interference_graph);
private: private:
Module* module; Module* module;
std::map<PhysicalReg, std::string> preg_to_str = {
// 物理寄存器 {PhysicalReg::ZERO, "zero"}, {PhysicalReg::RA, "ra"}, {PhysicalReg::SP, "sp"},
enum class PhysicalReg { {PhysicalReg::GP, "gp"}, {PhysicalReg::TP, "tp"}, {PhysicalReg::T0, "t0"},
T0, T1, T2, T3, T4, T5, T6, // x5-x7, x28-x31 {PhysicalReg::T1, "t1"}, {PhysicalReg::T2, "t2"}, {PhysicalReg::S0, "s0"},
A0, A1, A2, A3, A4, A5, A6, A7 // x10-x17 {PhysicalReg::S1, "s1"}, {PhysicalReg::A0, "a0"}, {PhysicalReg::A1, "a1"},
}; {PhysicalReg::A2, "a2"}, {PhysicalReg::A3, "a3"}, {PhysicalReg::A4, "a4"},
static const std::vector<PhysicalReg> allocable_regs; {PhysicalReg::A5, "a5"}, {PhysicalReg::A6, "a6"}, {PhysicalReg::A7, "a7"},
{PhysicalReg::S2, "s2"}, {PhysicalReg::S3, "s3"}, {PhysicalReg::S4, "s4"},
// 操作数 {PhysicalReg::S5, "s5"}, {PhysicalReg::S6, "s6"}, {PhysicalReg::S7, "s7"},
struct Operand { {PhysicalReg::S8, "s8"}, {PhysicalReg::S9, "s9"}, {PhysicalReg::S10, "s10"},
enum class Kind { Reg, Imm, Label }; {PhysicalReg::S11, "s11"}, {PhysicalReg::T3, "t3"}, {PhysicalReg::T4, "t4"},
Kind kind; {PhysicalReg::T5, "t5"}, {PhysicalReg::T6, "t6"}
Value* value; // 用于寄存器
std::string label; // 用于标签或立即数
Operand(Kind k, Value* v) : kind(k), value(v), label("") {}
Operand(Kind k, const std::string& l) : kind(k), value(nullptr), label(l) {}
}; };
// RISC-V 指令 std::vector<PhysicalReg> caller_saved = {
struct RISCv32Inst { PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3, PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
std::string opcode; PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3, PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7
std::vector<Operand> operands;
RISCv32Inst(const std::string& op, const std::vector<Operand>& ops)
: opcode(op), operands(ops) {}
}; };
// 寄存器分配结果 std::vector<PhysicalReg> callee_saved = {
struct RegAllocResult { PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3, PhysicalReg::S4, PhysicalReg::S5,
std::map<Value*, PhysicalReg> reg_map; // 虚拟寄存器到物理寄存器的映射 PhysicalReg::S6, PhysicalReg::S7, PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11
std::map<Value*, int> stack_map; // 虚拟寄存器到堆栈槽的映射
int stack_size; // 堆栈帧大小
}; };
// 后端方法 std::string get_preg_str(PhysicalReg preg) const {
std::string module_gen(); return preg_to_str.at(preg);
std::string function_gen(Function* func); }
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc);
std::vector<RISCv32Inst> instruction_gen(Instruction* inst); PhysicalReg get_preg_or_temp(const std::string& vreg, const RegAllocResult& alloc) const;
RegAllocResult register_allocation(Function* func);
void eliminate_phi(Function* func);
std::map<Instruction*, std::set<Value*>> liveness_analysis(Function* func);
std::map<Value*, std::set<Value*>> build_interference_graph(
const std::map<Instruction*, std::set<Value*>>& live_sets);
std::string reg_to_string(PhysicalReg reg);
}; };
} // namespace sysy } // namespace sysy

View File

@@ -10,6 +10,7 @@ using namespace antlr4;
#include "SysYIRGenerator.h" #include "SysYIRGenerator.h"
#include "SysYIRPrinter.h" #include "SysYIRPrinter.h"
#include "SysYIROptPre.h" #include "SysYIROptPre.h"
#include "RISCv32Backend.h"
// #include "LLVMIRGenerator.h" // #include "LLVMIRGenerator.h"
using namespace sysy; using namespace sysy;
@@ -73,18 +74,26 @@ int main(int argc, char **argv) {
// visit AST to generate IR // visit AST to generate IR
SysYIRGenerator generator;
generator.visitCompUnit(moduleAST);
if (argStopAfter == "ir") { if (argStopAfter == "ir") {
SysYIRGenerator generator;
generator.visitCompUnit(moduleAST);
auto moduleIR = generator.get(); auto moduleIR = generator.get();
SysYPrinter printer(moduleIR); SysYPrinter printer(moduleIR);
printer.printIR(); printer.printIR();
auto builder = generator.getBuilder(); auto builder = generator.getBuilder();
SysYOptPre optPre(moduleIR, builder); SysYOptPre optPre(moduleIR, builder);
optPre.SysYOptimizateAfterIR(); optPre.SysYOptimizateAfterIR();
printer.printIR();
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }
// generate assembly
auto module = generator.get();
sysy::RISCv32CodeGen codegen(module);
string asmCode = codegen.code_gen();
if (argStopAfter == "asm") {
cout << asmCode << endl;
return EXIT_SUCCESS;
}
return EXIT_SUCCESS; return EXIT_SUCCESS;
} }

View File

@@ -1,12 +1,8 @@
//test add //test add
int main(){ int main(){
int a, b; int a, b;
float d;
a = 10; a = 10;
b = 2; b = 2;
int c = a; return a + b;
d = 1.1 ;
return a + b + c;
} }

View File

@@ -5,10 +5,10 @@ int main() {
const int b = 2; const int b = 2;
int c; int c;
if (a == b) if (a != b)
c = a + b; c = b - a + 20; // 21 <- this
else else
c = a * b; c = a * b + b + b + 10; // 16
return c; return c;
} }

View File

@@ -7,7 +7,7 @@ int mul(int x, int y) {
int main(){ int main(){
int a, b; int a, b;
a = 10; a = 10;
b = 0; b = 3;
a = mul(a, b); a = mul(a, b); //60
return a + b; return a + b; //66
} }