[backend] introduced DAG, GraphAlloc

This commit is contained in:
Lixuanwang
2025-06-23 15:38:01 +08:00
parent af00612376
commit 7d37bd7528
2 changed files with 408 additions and 214 deletions

View File

@@ -1,6 +1,8 @@
#include "RISCv32Backend.h"
#include <sstream>
#include <algorithm>
#include <stdexcept>
#include <regex>
namespace sysy {
@@ -32,17 +34,6 @@ std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) {
default: return "";
}
}
// 简单的临时寄存器分配器
class TempRegAllocator {
std::vector<std::string> regs = {"t0", "t1", "t2", "t3", "t4", "t5", "t6"};
size_t current = 0;
public:
std::string get_next() {
if (current >= regs.size()) throw std::runtime_error("临时寄存器不足");
return regs[current++];
}
void reset() { current = 0; }
};
std::string RISCv32CodeGen::code_gen() {
std::stringstream ss;
@@ -112,179 +103,370 @@ std::string RISCv32CodeGen::function_gen(Function* func) {
std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) {
std::stringstream ss;
ss << bb->getName() << ":\n";
for (const auto& inst : bb->getInstructions()) {
auto riscv_insts = instruction_gen(inst.get(), alloc);
for (const auto& riscv_inst : riscv_insts) {
ss << " " << riscv_inst << "\n";
}
auto dag = build_dag(bb);
std::vector<std::string> insts;
for (auto& node : dag) {
select_instructions(node.get(), alloc);
emit_instructions(node.get(), insts, alloc);
}
for (const auto& inst : insts) {
ss << " " << inst << "\n";
}
return ss.str();
}
std::vector<std::string> RISCv32CodeGen::instruction_gen(Instruction* inst, const RegAllocResult& alloc) {
std::vector<std::string> insts;
// DAG 构建
std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(BasicBlock* bb) {
std::vector<std::unique_ptr<DAGNode>> nodes;
std::map<Value*, DAGNode*> value_to_node;
static int vreg_counter = 0; // Counter for unique vreg names
auto load_operand = [&](Value* val, const std::string& reg) {
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
if (constant->isInt()) {
insts.push_back("li " + reg + ", " + std::to_string(constant->getInt()));
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
insts.push_back("li " + reg + ", " + std::to_string(float_bits));
insts.push_back("fmv.w.x " + reg + ", " + reg);
}
} else if (alloc.stack_map.find(val) != alloc.stack_map.end()) {
insts.push_back("lw " + reg + ", " + std::to_string(alloc.stack_map.at(val)) + "(s0)");
} else if (auto global = dynamic_cast<GlobalValue*>(val)) {
insts.push_back("la " + reg + ", " + global->getName());
}
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) {
auto node = std::make_unique<DAGNode>(kind);
node->value = val;
node->result_reg = val ? "v" + std::to_string(vreg_counter++) : "";
if (val) value_to_node[val] = node.get();
nodes.push_back(std::move(node));
return nodes.back().get();
};
if (auto alloca = dynamic_cast<AllocaInst*>(inst)) {
// 栈空间已在 register_allocation 中分配
}
else if (auto store = dynamic_cast<StoreInst*>(inst)) {
std::string val_reg = "t0";
load_operand(store->getValue(), val_reg);
auto ptr = store->getPointer();
if (auto alloca = dynamic_cast<AllocaInst*>(ptr)) {
int offset = alloc.stack_map.at(alloca);
insts.push_back("sw " + val_reg + ", " + std::to_string(offset) + "(s0)");
} else if (auto global = dynamic_cast<GlobalValue*>(ptr)) {
std::string ptr_reg = "t1";
insts.push_back("la " + ptr_reg + ", " + global->getName());
insts.push_back("sw " + val_reg + ", 0(" + ptr_reg + ")");
for (const auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
create_node(DAGNode::CONSTANT, alloca); // Allocate stack space
} else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
auto store_node = create_node(DAGNode::STORE);
auto val_node = value_to_node.find(store->getValue()) != value_to_node.end()
? value_to_node[store->getValue()]
: create_node(DAGNode::CONSTANT, store->getValue());
auto ptr_node = value_to_node.find(store->getPointer()) != value_to_node.end()
? value_to_node[store->getPointer()]
: create_node(DAGNode::CONSTANT, store->getPointer());
store_node->operands.push_back(val_node);
store_node->operands.push_back(ptr_node);
val_node->users.push_back(store_node);
ptr_node->users.push_back(store_node);
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
auto load_node = create_node(DAGNode::LOAD, load);
auto ptr_node = value_to_node.find(load->getPointer()) != value_to_node.end()
? value_to_node[load->getPointer()]
: create_node(DAGNode::CONSTANT, load->getPointer());
load_node->operands.push_back(ptr_node);
ptr_node->users.push_back(load_node);
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
auto bin_node = create_node(DAGNode::BINARY, bin);
auto lhs_node = value_to_node.find(bin->getLhs()) != value_to_node.end()
? value_to_node[bin->getLhs()]
: create_node(DAGNode::CONSTANT, bin->getLhs());
auto rhs_node = value_to_node.find(bin->getRhs()) != value_to_node.end()
? value_to_node[bin->getRhs()]
: create_node(DAGNode::CONSTANT, bin->getRhs());
bin_node->operands.push_back(lhs_node);
bin_node->operands.push_back(rhs_node);
lhs_node->users.push_back(bin_node);
rhs_node->users.push_back(bin_node);
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
auto call_node = create_node(DAGNode::CALL, call);
for (auto arg : call->getArguments()) {
auto arg_node = value_to_node.find(arg->getValue()) != value_to_node.end()
? value_to_node[arg->getValue()]
: create_node(DAGNode::CONSTANT, arg->getValue());
call_node->operands.push_back(arg_node);
arg_node->users.push_back(call_node);
}
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
auto ret_node = create_node(DAGNode::RETURN);
if (ret->hasReturnValue()) {
auto val_node = value_to_node.find(ret->getReturnValue()) != value_to_node.end()
? value_to_node[ret->getReturnValue()]
: create_node(DAGNode::CONSTANT, ret->getReturnValue());
ret_node->operands.push_back(val_node);
val_node->users.push_back(ret_node);
}
}
}
else if (auto load = dynamic_cast<LoadInst*>(inst)) {
std::string dst_reg = "t0";
auto ptr = load->getPointer();
if (auto alloca = dynamic_cast<AllocaInst*>(ptr)) {
int offset = alloc.stack_map.at(alloca);
insts.push_back("lw " + dst_reg + ", " + std::to_string(offset) + "(s0)");
} else if (auto global = dynamic_cast<GlobalValue*>(ptr)) {
std::string ptr_reg = "t1";
insts.push_back("la " + ptr_reg + ", " + global->getName());
insts.push_back("lw " + dst_reg + ", 0(" + ptr_reg + ")");
return nodes;
}
// 指令选择
void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) {
if (!node->inst.empty()) return;
for (auto operand : node->operands) {
select_instructions(operand, alloc);
}
switch (node->kind) {
case DAGNode::CONSTANT: {
if (auto constant = dynamic_cast<ConstantValue*>(node->value)) {
if (constant->isInt()) {
node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt());
} else {
float f = constant->getFloat();
uint32_t float_bits = *(uint32_t*)&f;
node->inst = "li " + node->result_reg + ", " + std::to_string(float_bits) + "\nfmv.w.x " + node->result_reg + ", " + node->result_reg;
}
} else if (auto global = dynamic_cast<GlobalValue*>(node->value)) {
node->inst = "la " + node->result_reg + ", " + global->getName();
} else if (auto alloca = dynamic_cast<AllocaInst*>(node->value)) {
if (alloc.stack_map.find(alloca) != alloc.stack_map.end()) {
node->inst = ""; // Stack address handled in LOAD/STORE
}
}
break;
}
if (alloc.stack_map.find(load) != alloc.stack_map.end()) {
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(load)) + "(s0)");
case DAGNode::LOAD: {
auto ptr_reg = node->operands[0]->result_reg;
if (alloc.stack_map.find(node->operands[0]->value) != alloc.stack_map.end()) {
int offset = alloc.stack_map.at(node->operands[0]->value);
node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)";
} else {
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
}
break;
}
case DAGNode::STORE: {
auto val_reg = node->operands[0]->result_reg;
auto ptr_reg = node->operands[1]->result_reg;
if (alloc.stack_map.find(node->operands[1]->value) != alloc.stack_map.end()) {
int offset = alloc.stack_map.at(node->operands[1]->value);
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
} else {
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
}
break;
}
case DAGNode::BINARY: {
auto bin = dynamic_cast<BinaryInst*>(node->value);
auto lhs_reg = node->operands[0]->result_reg;
auto rhs_reg = node->operands[1]->result_reg;
std::string opcode;
switch (bin->getKind()) {
case BinaryInst::kAdd: opcode = "add"; break;
case BinaryInst::kMul: opcode = "mul"; break;
default: break;
}
if (!opcode.empty()) {
node->inst = opcode + " " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg;
}
break;
}
case DAGNode::CALL: {
auto call = dynamic_cast<CallInst*>(node->value);
std::string insts;
for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n";
}
insts += "jal " + call->getCallee()->getName();
if (call->getType()->isInt() || call->getType()->isFloat()) {
insts += "\nmv " + node->result_reg + ", a0";
}
node->inst = insts;
break;
}
case DAGNode::RETURN: {
if (!node->operands.empty()) {
node->inst = "mv a0, " + node->operands[0]->result_reg;
}
break;
}
default: break;
}
}
// 指令发射
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) {
for (auto operand : node->operands) {
emit_instructions(operand, insts, alloc);
}
if (!node->inst.empty()) {
std::stringstream ss(node->inst);
std::string line;
while (std::getline(ss, line, '\n')) {
if (!line.empty()) {
// Replace virtual registers with physical registers
if (!node->result_reg.empty() && alloc.vreg_to_preg.find(node->result_reg) != alloc.vreg_to_preg.end()) {
line = std::regex_replace(line, std::regex("\\b" + node->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(node->result_reg)));
}
for (auto operand : node->operands) {
if (!operand->result_reg.empty() && alloc.vreg_to_preg.find(operand->result_reg) != alloc.vreg_to_preg.end()) {
line = std::regex_replace(line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
}
}
insts.push_back(line);
}
}
}
else if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
std::string lhs_reg = "t0";
std::string rhs_reg = "t1";
std::string dst_reg = "t2";
load_operand(bin->getLhs(), lhs_reg);
load_operand(bin->getRhs(), rhs_reg);
std::string opcode;
switch (bin->getKind()) {
case BinaryInst::kAdd: opcode = "add"; break;
case BinaryInst::kSub: opcode = "sub"; break;
case BinaryInst::kMul: opcode = "mul"; break;
case BinaryInst::kDiv: opcode = "div"; break;
case BinaryInst::kRem: opcode = "rem"; break;
case BinaryInst::kFAdd: opcode = "fadd.s"; break;
case BinaryInst::kFSub: opcode = "fsub.s"; break;
case BinaryInst::kFMul: opcode = "fmul.s"; break;
case BinaryInst::kFDiv: opcode = "fdiv.s"; break;
case BinaryInst::kICmpEQ: insts.push_back("seqz " + dst_reg + ", " + lhs_reg); break;
case BinaryInst::kICmpNE: insts.push_back("snez " + dst_reg + ", " + lhs_reg); break;
case BinaryInst::kICmpLT: insts.push_back("slt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
case BinaryInst::kICmpGT: insts.push_back("sgt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
case BinaryInst::kICmpLE: insts.push_back("sle " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
case BinaryInst::kICmpGE: insts.push_back("sge " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
case BinaryInst::kAnd: opcode = "and"; break;
case BinaryInst::kOr: opcode = "or"; break;
default: return insts;
}
if (!opcode.empty()) {
insts.push_back(opcode + " " + dst_reg + ", " + lhs_reg + ", " + rhs_reg);
}
if (alloc.stack_map.find(bin) != alloc.stack_map.end()) {
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(bin)) + "(s0)");
}
// 活跃性分析
std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(Function* func) {
std::map<Instruction*, std::set<std::string>> live_in, live_out;
bool changed = true;
while (changed) {
changed = false;
for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) {
auto bb = it->get();
for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) {
auto inst = inst_it->get();
std::set<std::string> new_in, new_out;
// Calculate live_out
if (auto br = dynamic_cast<CondBrInst*>(inst)) {
new_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(),
live_in[br->getThenBlock()->getInstructions().front().get()].end());
new_out.insert(live_in[br->getElseBlock()->getInstructions().front().get()].begin(),
live_in[br->getElseBlock()->getInstructions().front().get()].end());
} else if (auto uncond = dynamic_cast<UncondBrInst*>(inst)) {
new_out.insert(live_in[uncond->getBlock()->getInstructions().front().get()].begin(),
live_in[uncond->getBlock()->getInstructions().front().get()].end());
} else {
auto next_inst = std::next(inst_it);
if (next_inst != bb->getInstructions().rend()) {
new_out = live_in[next_inst->get()];
}
}
// Calculate live_in = use (live_out - def)
std::set<std::string> use, def;
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end())
use.insert(value_vreg_map[bin->getLhs()]);
if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end())
use.insert(value_vreg_map[bin->getRhs()]);
if (value_vreg_map.find(bin) != value_vreg_map.end())
def.insert(value_vreg_map[bin]);
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
for (auto arg : call->getArguments()) {
if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end())
use.insert(value_vreg_map[arg->getValue()]);
}
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
def.insert(value_vreg_map[call]);
}
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end())
use.insert(value_vreg_map[load->getPointer()]);
if (value_vreg_map.find(load) != value_vreg_map.end())
def.insert(value_vreg_map[load]);
} else if (auto store = dynamic_cast<StoreInst*>(inst)) {
if (value_vreg_map.find(store->getValue()) != value_vreg_map.end())
use.insert(value_vreg_map[store->getValue()]);
if (value_vreg_map.find(store->getPointer()) != value_vreg_map.end())
use.insert(value_vreg_map[store->getPointer()]);
} else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end()) {
use.insert(value_vreg_map[ret->getReturnValue()]);
}
}
new_in = use;
for (const auto& vreg : new_out) {
if (def.find(vreg) == def.end()) {
new_in.insert(vreg);
}
}
if (live_in[inst] != new_in || live_out[inst] != new_out) {
live_in[inst] = new_in;
live_out[inst] = new_out;
changed = true;
}
}
}
}
else if (auto uny = dynamic_cast<UnaryInst*>(inst)) {
std::string src_reg = "t0";
std::string dst_reg = "t1";
load_operand(uny->getOperand(), src_reg);
switch (uny->getKind()) {
case UnaryInst::kNeg: insts.push_back("sub " + dst_reg + ", x0, " + src_reg); break;
case UnaryInst::kNot: insts.push_back("xori " + dst_reg + ", " + src_reg + ", -1"); break;
case UnaryInst::kFNeg: insts.push_back("fneg.s " + dst_reg + ", " + src_reg); break;
case UnaryInst::kFtoI: insts.push_back("fcvt.w.s " + dst_reg + ", " + src_reg); break;
case UnaryInst::kItoF: insts.push_back("fcvt.s.w " + dst_reg + ", " + src_reg); break;
case UnaryInst::kBitFtoI: insts.push_back("fmv.x.w " + dst_reg + ", " + src_reg); break;
case UnaryInst::kBitItoF: insts.push_back("fmv.w.x " + dst_reg + ", " + src_reg); break;
default: return insts;
return live_in;
}
// 干扰图构建
std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_graph(
const std::map<Instruction*, std::set<std::string>>& live_sets) {
std::map<std::string, std::set<std::string>> graph;
for (const auto& pair : live_sets) {
auto inst = pair.first;
const auto& live = pair.second;
std::string def;
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
if (value_vreg_map.find(bin) != value_vreg_map.end())
def = value_vreg_map[bin];
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
def = value_vreg_map[call];
}
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
if (value_vreg_map.find(load) != value_vreg_map.end())
def = value_vreg_map[load];
}
if (alloc.stack_map.find(uny) != alloc.stack_map.end()) {
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(uny)) + "(s0)");
if (!def.empty()) {
for (const auto& live_vreg : live) {
if (live_vreg != def) {
graph[def].insert(live_vreg);
graph[live_vreg].insert(def);
}
}
}
}
else if (auto call = dynamic_cast<CallInst*>(inst)) {
auto args = call->getArguments();
size_t i = 0;
for (auto it = args.begin(); it != args.end() && i < 8; ++it, ++i) {
load_operand((*it)->getValue(), "a" + std::to_string(i));
return graph;
}
// 图着色
void RISCv32CodeGen::color_graph(std::map<std::string, PhysicalReg>& vreg_to_preg,
const std::map<std::string, std::set<std::string>>& interference_graph) {
std::vector<std::string> stack;
std::map<std::string, std::set<std::string>> temp_graph = interference_graph;
while (!temp_graph.empty()) {
std::string node_to_remove;
for (const auto& pair : temp_graph) {
if (pair.second.size() < allocable_regs.size()) {
node_to_remove = pair.first;
break;
}
}
insts.push_back("jal " + call->getCallee()->getName());
if (alloc.stack_map.find(call) != alloc.stack_map.end()) {
insts.push_back("sw a0, " + std::to_string(alloc.stack_map.at(call)) + "(s0)");
if (node_to_remove.empty()) {
node_to_remove = temp_graph.begin()->first; // Spill if necessary
}
}
else if (auto condBr = dynamic_cast<CondBrInst*>(inst)) {
std::string cond_reg = "t0";
load_operand(condBr->getCondition(), cond_reg);
insts.push_back("bnez " + cond_reg + ", " + condBr->getThenBlock()->getName());
insts.push_back("j " + condBr->getElseBlock()->getName());
}
else if (auto br = dynamic_cast<UncondBrInst*>(inst)) {
insts.push_back("j " + br->getBlock()->getName());
}
else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
if (ret->hasReturnValue()) {
load_operand(ret->getReturnValue(), "a0");
stack.push_back(node_to_remove);
for (auto& pair : temp_graph) {
pair.second.erase(node_to_remove);
}
temp_graph.erase(node_to_remove);
}
else if (auto la = dynamic_cast<LaInst*>(inst)) {
std::string dst_reg = "t0";
load_operand(la->getPointer(), dst_reg);
for (size_t i = 0; i < la->getNumIndices(); ++i) {
std::string idx_reg = "t1";
load_operand(la->getIndex(i), idx_reg);
insts.push_back("slli " + idx_reg + ", " + idx_reg + ", 2");
insts.push_back("add " + dst_reg + ", " + dst_reg + ", " + idx_reg);
while (!stack.empty()) {
auto vreg = stack.back();
stack.pop_back();
std::set<std::string> used_colors;
for (const auto& neighbor : interference_graph.at(vreg)) {
if (vreg_to_preg.find(neighbor) != vreg_to_preg.end()) {
used_colors.insert(reg_to_string(vreg_to_preg[neighbor]));
}
}
if (alloc.stack_map.find(la) != alloc.stack_map.end()) {
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(la)) + "(s0)");
bool assigned = false;
for (auto preg : allocable_regs) {
if (used_colors.find(reg_to_string(preg)) == used_colors.end()) {
vreg_to_preg[vreg] = preg;
assigned = true;
break;
}
}
// If no register is available, spill to stack (handled in register_allocation)
}
else if (auto memset = dynamic_cast<MemsetInst*>(inst)) {
std::string ptr_reg = "t0";
std::string val_reg = "t1";
std::string size_reg = "t2";
load_operand(memset->getPointer(), ptr_reg);
load_operand(memset->getValue(), val_reg);
load_operand(memset->getSize(), size_reg);
insts.push_back("mv t3, " + ptr_reg);
insts.push_back("add t4, " + ptr_reg + ", " + size_reg);
insts.push_back("1: sw " + val_reg + ", 0(" + ptr_reg + ")");
insts.push_back("addi " + ptr_reg + ", " + ptr_reg + ", 4");
insts.push_back("blt " + ptr_reg + ", t4, 1b");
}
else if (auto phi = dynamic_cast<PhiInst*>(inst)) {
// Phi 指令由 eliminate_phi 处理
}
return insts;
}
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) {
RegAllocResult result;
int stack_offset = 0;
std::set<Value*> allocated;
value_vreg_map.clear(); // Clear vreg map for new function
static int vreg_counter = 0; // Counter for unique vreg names
// 分配局部变量栈空间
for (const auto& bb : func->getBasicBlocks()) {
@@ -292,44 +474,73 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
if (result.stack_map.find(alloca) == result.stack_map.end()) {
result.stack_map[alloca] = stack_offset;
value_vreg_map[alloca] = "v" + std::to_string(vreg_counter++);
stack_offset += 4;
}
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
if (value_vreg_map.find(load) == value_vreg_map.end()) {
value_vreg_map[load] = "v" + std::to_string(vreg_counter++);
}
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
if (value_vreg_map.find(bin) == value_vreg_map.end()) {
value_vreg_map[bin] = "v" + std::to_string(vreg_counter++);
}
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) == value_vreg_map.end()) {
value_vreg_map[call] = "v" + std::to_string(vreg_counter++);
}
}
}
}
// 分配函数参数栈空间(入口块的 arguments
// 分配函数参数栈空间
auto entry_block = func->getEntryBlock();
auto args = entry_block->getArguments();
for (size_t i = 0; i < args.size(); ++i) {
if (i >= 8) { // 超过 8 个参数需要栈空间
if (i >= 8) {
if (result.stack_map.find(args[i]) == result.stack_map.end()) {
result.stack_map[args[i]] = stack_offset;
value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++);
stack_offset += 4;
}
} else {
value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++);
}
}
// 分配中间结果栈空间(如 BinaryInst 和 CallInst
// 图着色寄存器分配
auto live_sets = liveness_analysis(func);
auto interference_graph = build_interference_graph(live_sets);
color_graph(result.vreg_to_preg, interference_graph);
// 分配溢出栈空间
for (const auto& bb : func->getBasicBlocks()) {
for (const auto& inst : bb->getInstructions()) {
if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
if (result.stack_map.find(bin) == result.stack_map.end() && allocated.find(bin) == allocated.end()) {
std::string vreg = value_vreg_map[bin];
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
result.stack_map[bin] = stack_offset;
stack_offset += 4;
allocated.insert(bin);
}
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
if (result.stack_map.find(call) == result.stack_map.end() && allocated.find(call) == allocated.end()) {
result.stack_map[call] = stack_offset;
if (call->getType()->isInt() || call->getType()->isFloat()) {
std::string vreg = value_vreg_map[call];
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
result.stack_map[call] = stack_offset;
stack_offset += 4;
}
}
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
std::string vreg = value_vreg_map[load];
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
result.stack_map[load] = stack_offset;
stack_offset += 4;
allocated.insert(call);
}
}
}
}
// 检查是否需要保存 ra 和 s0
// 保存 ra 和 s0
bool needs_caller_saved = false;
for (const auto& bb : func->getBasicBlocks()) {
for (const auto& inst : bb->getInstructions()) {
@@ -353,20 +564,7 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
}
void RISCv32CodeGen::eliminate_phi(Function* func) {
// Placeholder: Phi elimination requires inserting moves at predecessor blocks
}
std::map<Instruction*, std::set<Value*>> RISCv32CodeGen::liveness_analysis(Function* func) {
std::map<Instruction*, std::set<Value*>> live_sets;
// Placeholder: Implement liveness analysis
return live_sets;
}
std::map<Value*, std::set<Value*>> RISCv32CodeGen::build_interference_graph(
const std::map<Instruction*, std::set<Value*>>& live_sets) {
std::map<Value*, std::set<Value*>> graph;
// Placeholder: Implement interference graph
return graph;
// TODO: 插入 move 指令处理 phi
}
} // namespace sysy

View File

@@ -6,61 +6,57 @@
#include <vector>
#include <map>
#include <set>
#include <memory>
namespace sysy {
class RISCv32CodeGen {
public:
explicit RISCv32CodeGen(Module* mod) : module(mod) {}
std::string code_gen(); // 生成模块的汇编代码
private:
Module* module;
// 物理寄存器
enum class PhysicalReg {
S0, // x8, 帧指针
T0, T1, T2, T3, T4, T5, T6, // x5-x7, x28-x31
A0, A1, A2, A3, A4, A5, A6, A7 // x10-x17
};
static const std::vector<PhysicalReg> allocable_regs;
// 操作数
struct Operand {
enum class Kind { Reg, Imm, Label };
Kind kind;
Value* value; // 用于寄存器
std::string label; // 用于标签或立即数
Operand(Kind k, Value* v) : kind(k), value(v), label("") {}
Operand(Kind k, const std::string& l) : kind(k), value(nullptr), label(l) {}
S0, T0, T1, T2, T3, T4, T5, T6,
A0, A1, A2, A3, A4, A5, A6, A7
};
// RISC-V 指令
struct RISCv32Inst {
std::string opcode;
std::vector<Operand> operands;
RISCv32Inst(const std::string& op, const std::vector<Operand>& ops)
: opcode(op), operands(ops) {}
// Move DAGNode and RegAllocResult to public section
struct DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN };
NodeKind kind;
Value* value = nullptr;
std::string inst;
std::string result_reg;
std::vector<DAGNode*> operands;
std::vector<DAGNode*> users;
DAGNode(NodeKind k) : kind(k) {}
};
// 寄存器分配结果
struct RegAllocResult {
std::map<Value*, PhysicalReg> reg_map; // 虚拟寄存器到物理寄存器的映射
std::map<Value*, int> stack_map; // 虚拟寄存器到堆栈槽的映射
int stack_size; // 堆栈帧大小
std::map<std::string, PhysicalReg> vreg_to_preg;
std::map<Value*, int> stack_map;
int stack_size = 0;
};
// 后端方法
RISCv32CodeGen(Module* mod) : module(mod) {}
std::string code_gen();
std::string module_gen();
std::string function_gen(Function* func);
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc);
std::vector<std::string> instruction_gen(Instruction* inst, const RegAllocResult& alloc);
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
void select_instructions(DAGNode* node, const RegAllocResult& alloc); // Use const
void emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc); // Add alloc
std::map<Instruction*, std::set<std::string>> liveness_analysis(Function* func);
std::map<std::string, std::set<std::string>> build_interference_graph(
const std::map<Instruction*, std::set<std::string>>& live_sets);
void color_graph(std::map<std::string, PhysicalReg>& vreg_to_preg,
const std::map<std::string, std::set<std::string>>& interference_graph);
RegAllocResult register_allocation(Function* func);
void eliminate_phi(Function* func);
std::map<Instruction*, std::set<Value*>> liveness_analysis(Function* func);
std::map<Value*, std::set<Value*>> build_interference_graph(
const std::map<Instruction*, std::set<Value*>>& live_sets);
std::string reg_to_string(PhysicalReg reg);
private:
static const std::vector<PhysicalReg> allocable_regs;
std::map<Value*, std::string> value_vreg_map;
Module* module;
};
} // namespace sysy