diff --git a/src/RISCv32Backend.cpp b/src/RISCv32Backend.cpp index 46e91cd..1056c7c 100644 --- a/src/RISCv32Backend.cpp +++ b/src/RISCv32Backend.cpp @@ -1,6 +1,8 @@ #include "RISCv32Backend.h" #include #include +#include +#include namespace sysy { @@ -32,17 +34,6 @@ std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) { default: return ""; } } -// 简单的临时寄存器分配器 -class TempRegAllocator { - std::vector regs = {"t0", "t1", "t2", "t3", "t4", "t5", "t6"}; - size_t current = 0; -public: - std::string get_next() { - if (current >= regs.size()) throw std::runtime_error("临时寄存器不足"); - return regs[current++]; - } - void reset() { current = 0; } -}; std::string RISCv32CodeGen::code_gen() { std::stringstream ss; @@ -112,179 +103,370 @@ std::string RISCv32CodeGen::function_gen(Function* func) { std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) { std::stringstream ss; ss << bb->getName() << ":\n"; - for (const auto& inst : bb->getInstructions()) { - auto riscv_insts = instruction_gen(inst.get(), alloc); - for (const auto& riscv_inst : riscv_insts) { - ss << " " << riscv_inst << "\n"; - } + auto dag = build_dag(bb); + std::vector insts; + for (auto& node : dag) { + select_instructions(node.get(), alloc); + emit_instructions(node.get(), insts, alloc); + } + for (const auto& inst : insts) { + ss << " " << inst << "\n"; } return ss.str(); } -std::vector RISCv32CodeGen::instruction_gen(Instruction* inst, const RegAllocResult& alloc) { - std::vector insts; +// DAG 构建 +std::vector> RISCv32CodeGen::build_dag(BasicBlock* bb) { + std::vector> nodes; + std::map value_to_node; + static int vreg_counter = 0; // Counter for unique vreg names - auto load_operand = [&](Value* val, const std::string& reg) { - if (auto constant = dynamic_cast(val)) { - if (constant->isInt()) { - insts.push_back("li " + reg + ", " + std::to_string(constant->getInt())); - } else { - float f = constant->getFloat(); - uint32_t float_bits = *(uint32_t*)&f; - insts.push_back("li " + reg + ", " + std::to_string(float_bits)); - insts.push_back("fmv.w.x " + reg + ", " + reg); - } - } else if (alloc.stack_map.find(val) != alloc.stack_map.end()) { - insts.push_back("lw " + reg + ", " + std::to_string(alloc.stack_map.at(val)) + "(s0)"); - } else if (auto global = dynamic_cast(val)) { - insts.push_back("la " + reg + ", " + global->getName()); - } + auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) { + auto node = std::make_unique(kind); + node->value = val; + node->result_reg = val ? "v" + std::to_string(vreg_counter++) : ""; + if (val) value_to_node[val] = node.get(); + nodes.push_back(std::move(node)); + return nodes.back().get(); }; - if (auto alloca = dynamic_cast(inst)) { - // 栈空间已在 register_allocation 中分配 - } - else if (auto store = dynamic_cast(inst)) { - std::string val_reg = "t0"; - load_operand(store->getValue(), val_reg); - auto ptr = store->getPointer(); - if (auto alloca = dynamic_cast(ptr)) { - int offset = alloc.stack_map.at(alloca); - insts.push_back("sw " + val_reg + ", " + std::to_string(offset) + "(s0)"); - } else if (auto global = dynamic_cast(ptr)) { - std::string ptr_reg = "t1"; - insts.push_back("la " + ptr_reg + ", " + global->getName()); - insts.push_back("sw " + val_reg + ", 0(" + ptr_reg + ")"); + for (const auto& inst : bb->getInstructions()) { + if (auto alloca = dynamic_cast(inst.get())) { + create_node(DAGNode::CONSTANT, alloca); // Allocate stack space + } else if (auto store = dynamic_cast(inst.get())) { + auto store_node = create_node(DAGNode::STORE); + auto val_node = value_to_node.find(store->getValue()) != value_to_node.end() + ? value_to_node[store->getValue()] + : create_node(DAGNode::CONSTANT, store->getValue()); + auto ptr_node = value_to_node.find(store->getPointer()) != value_to_node.end() + ? value_to_node[store->getPointer()] + : create_node(DAGNode::CONSTANT, store->getPointer()); + store_node->operands.push_back(val_node); + store_node->operands.push_back(ptr_node); + val_node->users.push_back(store_node); + ptr_node->users.push_back(store_node); + } else if (auto load = dynamic_cast(inst.get())) { + auto load_node = create_node(DAGNode::LOAD, load); + auto ptr_node = value_to_node.find(load->getPointer()) != value_to_node.end() + ? value_to_node[load->getPointer()] + : create_node(DAGNode::CONSTANT, load->getPointer()); + load_node->operands.push_back(ptr_node); + ptr_node->users.push_back(load_node); + } else if (auto bin = dynamic_cast(inst.get())) { + auto bin_node = create_node(DAGNode::BINARY, bin); + auto lhs_node = value_to_node.find(bin->getLhs()) != value_to_node.end() + ? value_to_node[bin->getLhs()] + : create_node(DAGNode::CONSTANT, bin->getLhs()); + auto rhs_node = value_to_node.find(bin->getRhs()) != value_to_node.end() + ? value_to_node[bin->getRhs()] + : create_node(DAGNode::CONSTANT, bin->getRhs()); + bin_node->operands.push_back(lhs_node); + bin_node->operands.push_back(rhs_node); + lhs_node->users.push_back(bin_node); + rhs_node->users.push_back(bin_node); + } else if (auto call = dynamic_cast(inst.get())) { + auto call_node = create_node(DAGNode::CALL, call); + for (auto arg : call->getArguments()) { + auto arg_node = value_to_node.find(arg->getValue()) != value_to_node.end() + ? value_to_node[arg->getValue()] + : create_node(DAGNode::CONSTANT, arg->getValue()); + call_node->operands.push_back(arg_node); + arg_node->users.push_back(call_node); + } + } else if (auto ret = dynamic_cast(inst.get())) { + auto ret_node = create_node(DAGNode::RETURN); + if (ret->hasReturnValue()) { + auto val_node = value_to_node.find(ret->getReturnValue()) != value_to_node.end() + ? value_to_node[ret->getReturnValue()] + : create_node(DAGNode::CONSTANT, ret->getReturnValue()); + ret_node->operands.push_back(val_node); + val_node->users.push_back(ret_node); + } } } - else if (auto load = dynamic_cast(inst)) { - std::string dst_reg = "t0"; - auto ptr = load->getPointer(); - if (auto alloca = dynamic_cast(ptr)) { - int offset = alloc.stack_map.at(alloca); - insts.push_back("lw " + dst_reg + ", " + std::to_string(offset) + "(s0)"); - } else if (auto global = dynamic_cast(ptr)) { - std::string ptr_reg = "t1"; - insts.push_back("la " + ptr_reg + ", " + global->getName()); - insts.push_back("lw " + dst_reg + ", 0(" + ptr_reg + ")"); + + return nodes; +} + +// 指令选择 +void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) { + if (!node->inst.empty()) return; + + for (auto operand : node->operands) { + select_instructions(operand, alloc); + } + + switch (node->kind) { + case DAGNode::CONSTANT: { + if (auto constant = dynamic_cast(node->value)) { + if (constant->isInt()) { + node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt()); + } else { + float f = constant->getFloat(); + uint32_t float_bits = *(uint32_t*)&f; + node->inst = "li " + node->result_reg + ", " + std::to_string(float_bits) + "\nfmv.w.x " + node->result_reg + ", " + node->result_reg; + } + } else if (auto global = dynamic_cast(node->value)) { + node->inst = "la " + node->result_reg + ", " + global->getName(); + } else if (auto alloca = dynamic_cast(node->value)) { + if (alloc.stack_map.find(alloca) != alloc.stack_map.end()) { + node->inst = ""; // Stack address handled in LOAD/STORE + } + } + break; } - if (alloc.stack_map.find(load) != alloc.stack_map.end()) { - insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(load)) + "(s0)"); + case DAGNode::LOAD: { + auto ptr_reg = node->operands[0]->result_reg; + if (alloc.stack_map.find(node->operands[0]->value) != alloc.stack_map.end()) { + int offset = alloc.stack_map.at(node->operands[0]->value); + node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)"; + } else { + node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")"; + } + break; + } + case DAGNode::STORE: { + auto val_reg = node->operands[0]->result_reg; + auto ptr_reg = node->operands[1]->result_reg; + if (alloc.stack_map.find(node->operands[1]->value) != alloc.stack_map.end()) { + int offset = alloc.stack_map.at(node->operands[1]->value); + node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)"; + } else { + node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")"; + } + break; + } + case DAGNode::BINARY: { + auto bin = dynamic_cast(node->value); + auto lhs_reg = node->operands[0]->result_reg; + auto rhs_reg = node->operands[1]->result_reg; + std::string opcode; + switch (bin->getKind()) { + case BinaryInst::kAdd: opcode = "add"; break; + case BinaryInst::kMul: opcode = "mul"; break; + default: break; + } + if (!opcode.empty()) { + node->inst = opcode + " " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg; + } + break; + } + case DAGNode::CALL: { + auto call = dynamic_cast(node->value); + std::string insts; + for (size_t i = 0; i < node->operands.size() && i < 8; ++i) { + insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n"; + } + insts += "jal " + call->getCallee()->getName(); + if (call->getType()->isInt() || call->getType()->isFloat()) { + insts += "\nmv " + node->result_reg + ", a0"; + } + node->inst = insts; + break; + } + case DAGNode::RETURN: { + if (!node->operands.empty()) { + node->inst = "mv a0, " + node->operands[0]->result_reg; + } + break; + } + default: break; + } +} + +// 指令发射 +void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc) { + for (auto operand : node->operands) { + emit_instructions(operand, insts, alloc); + } + if (!node->inst.empty()) { + std::stringstream ss(node->inst); + std::string line; + while (std::getline(ss, line, '\n')) { + if (!line.empty()) { + // Replace virtual registers with physical registers + if (!node->result_reg.empty() && alloc.vreg_to_preg.find(node->result_reg) != alloc.vreg_to_preg.end()) { + line = std::regex_replace(line, std::regex("\\b" + node->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(node->result_reg))); + } + for (auto operand : node->operands) { + if (!operand->result_reg.empty() && alloc.vreg_to_preg.find(operand->result_reg) != alloc.vreg_to_preg.end()) { + line = std::regex_replace(line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg))); + } + } + insts.push_back(line); + } } } - else if (auto bin = dynamic_cast(inst)) { - std::string lhs_reg = "t0"; - std::string rhs_reg = "t1"; - std::string dst_reg = "t2"; - load_operand(bin->getLhs(), lhs_reg); - load_operand(bin->getRhs(), rhs_reg); - std::string opcode; - switch (bin->getKind()) { - case BinaryInst::kAdd: opcode = "add"; break; - case BinaryInst::kSub: opcode = "sub"; break; - case BinaryInst::kMul: opcode = "mul"; break; - case BinaryInst::kDiv: opcode = "div"; break; - case BinaryInst::kRem: opcode = "rem"; break; - case BinaryInst::kFAdd: opcode = "fadd.s"; break; - case BinaryInst::kFSub: opcode = "fsub.s"; break; - case BinaryInst::kFMul: opcode = "fmul.s"; break; - case BinaryInst::kFDiv: opcode = "fdiv.s"; break; - case BinaryInst::kICmpEQ: insts.push_back("seqz " + dst_reg + ", " + lhs_reg); break; - case BinaryInst::kICmpNE: insts.push_back("snez " + dst_reg + ", " + lhs_reg); break; - case BinaryInst::kICmpLT: insts.push_back("slt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break; - case BinaryInst::kICmpGT: insts.push_back("sgt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break; - case BinaryInst::kICmpLE: insts.push_back("sle " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break; - case BinaryInst::kICmpGE: insts.push_back("sge " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break; - case BinaryInst::kAnd: opcode = "and"; break; - case BinaryInst::kOr: opcode = "or"; break; - default: return insts; - } - if (!opcode.empty()) { - insts.push_back(opcode + " " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); - } - if (alloc.stack_map.find(bin) != alloc.stack_map.end()) { - insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(bin)) + "(s0)"); +} + +// 活跃性分析 +std::map> RISCv32CodeGen::liveness_analysis(Function* func) { + std::map> live_in, live_out; + bool changed = true; + + while (changed) { + changed = false; + for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) { + auto bb = it->get(); + for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) { + auto inst = inst_it->get(); + std::set new_in, new_out; + + // Calculate live_out + if (auto br = dynamic_cast(inst)) { + new_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(), + live_in[br->getThenBlock()->getInstructions().front().get()].end()); + new_out.insert(live_in[br->getElseBlock()->getInstructions().front().get()].begin(), + live_in[br->getElseBlock()->getInstructions().front().get()].end()); + } else if (auto uncond = dynamic_cast(inst)) { + new_out.insert(live_in[uncond->getBlock()->getInstructions().front().get()].begin(), + live_in[uncond->getBlock()->getInstructions().front().get()].end()); + } else { + auto next_inst = std::next(inst_it); + if (next_inst != bb->getInstructions().rend()) { + new_out = live_in[next_inst->get()]; + } + } + + // Calculate live_in = use ∪ (live_out - def) + std::set use, def; + if (auto bin = dynamic_cast(inst)) { + if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end()) + use.insert(value_vreg_map[bin->getLhs()]); + if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end()) + use.insert(value_vreg_map[bin->getRhs()]); + if (value_vreg_map.find(bin) != value_vreg_map.end()) + def.insert(value_vreg_map[bin]); + } else if (auto call = dynamic_cast(inst)) { + for (auto arg : call->getArguments()) { + if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end()) + use.insert(value_vreg_map[arg->getValue()]); + } + if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) { + def.insert(value_vreg_map[call]); + } + } else if (auto load = dynamic_cast(inst)) { + if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end()) + use.insert(value_vreg_map[load->getPointer()]); + if (value_vreg_map.find(load) != value_vreg_map.end()) + def.insert(value_vreg_map[load]); + } else if (auto store = dynamic_cast(inst)) { + if (value_vreg_map.find(store->getValue()) != value_vreg_map.end()) + use.insert(value_vreg_map[store->getValue()]); + if (value_vreg_map.find(store->getPointer()) != value_vreg_map.end()) + use.insert(value_vreg_map[store->getPointer()]); + } else if (auto ret = dynamic_cast(inst)) { + if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end()) { + use.insert(value_vreg_map[ret->getReturnValue()]); + } + } + + new_in = use; + for (const auto& vreg : new_out) { + if (def.find(vreg) == def.end()) { + new_in.insert(vreg); + } + } + + if (live_in[inst] != new_in || live_out[inst] != new_out) { + live_in[inst] = new_in; + live_out[inst] = new_out; + changed = true; + } + } } } - else if (auto uny = dynamic_cast(inst)) { - std::string src_reg = "t0"; - std::string dst_reg = "t1"; - load_operand(uny->getOperand(), src_reg); - switch (uny->getKind()) { - case UnaryInst::kNeg: insts.push_back("sub " + dst_reg + ", x0, " + src_reg); break; - case UnaryInst::kNot: insts.push_back("xori " + dst_reg + ", " + src_reg + ", -1"); break; - case UnaryInst::kFNeg: insts.push_back("fneg.s " + dst_reg + ", " + src_reg); break; - case UnaryInst::kFtoI: insts.push_back("fcvt.w.s " + dst_reg + ", " + src_reg); break; - case UnaryInst::kItoF: insts.push_back("fcvt.s.w " + dst_reg + ", " + src_reg); break; - case UnaryInst::kBitFtoI: insts.push_back("fmv.x.w " + dst_reg + ", " + src_reg); break; - case UnaryInst::kBitItoF: insts.push_back("fmv.w.x " + dst_reg + ", " + src_reg); break; - default: return insts; + + return live_in; +} + +// 干扰图构建 +std::map> RISCv32CodeGen::build_interference_graph( + const std::map>& live_sets) { + std::map> graph; + + for (const auto& pair : live_sets) { + auto inst = pair.first; + const auto& live = pair.second; + std::string def; + if (auto bin = dynamic_cast(inst)) { + if (value_vreg_map.find(bin) != value_vreg_map.end()) + def = value_vreg_map[bin]; + } else if (auto call = dynamic_cast(inst)) { + if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) { + def = value_vreg_map[call]; + } + } else if (auto load = dynamic_cast(inst)) { + if (value_vreg_map.find(load) != value_vreg_map.end()) + def = value_vreg_map[load]; } - if (alloc.stack_map.find(uny) != alloc.stack_map.end()) { - insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(uny)) + "(s0)"); + + if (!def.empty()) { + for (const auto& live_vreg : live) { + if (live_vreg != def) { + graph[def].insert(live_vreg); + graph[live_vreg].insert(def); + } + } } } - else if (auto call = dynamic_cast(inst)) { - auto args = call->getArguments(); - size_t i = 0; - for (auto it = args.begin(); it != args.end() && i < 8; ++it, ++i) { - load_operand((*it)->getValue(), "a" + std::to_string(i)); + + return graph; +} + +// 图着色 +void RISCv32CodeGen::color_graph(std::map& vreg_to_preg, + const std::map>& interference_graph) { + std::vector stack; + std::map> temp_graph = interference_graph; + + while (!temp_graph.empty()) { + std::string node_to_remove; + for (const auto& pair : temp_graph) { + if (pair.second.size() < allocable_regs.size()) { + node_to_remove = pair.first; + break; + } } - insts.push_back("jal " + call->getCallee()->getName()); - if (alloc.stack_map.find(call) != alloc.stack_map.end()) { - insts.push_back("sw a0, " + std::to_string(alloc.stack_map.at(call)) + "(s0)"); + + if (node_to_remove.empty()) { + node_to_remove = temp_graph.begin()->first; // Spill if necessary } - } - else if (auto condBr = dynamic_cast(inst)) { - std::string cond_reg = "t0"; - load_operand(condBr->getCondition(), cond_reg); - insts.push_back("bnez " + cond_reg + ", " + condBr->getThenBlock()->getName()); - insts.push_back("j " + condBr->getElseBlock()->getName()); - } - else if (auto br = dynamic_cast(inst)) { - insts.push_back("j " + br->getBlock()->getName()); - } - else if (auto ret = dynamic_cast(inst)) { - if (ret->hasReturnValue()) { - load_operand(ret->getReturnValue(), "a0"); + + stack.push_back(node_to_remove); + for (auto& pair : temp_graph) { + pair.second.erase(node_to_remove); } + temp_graph.erase(node_to_remove); } - else if (auto la = dynamic_cast(inst)) { - std::string dst_reg = "t0"; - load_operand(la->getPointer(), dst_reg); - for (size_t i = 0; i < la->getNumIndices(); ++i) { - std::string idx_reg = "t1"; - load_operand(la->getIndex(i), idx_reg); - insts.push_back("slli " + idx_reg + ", " + idx_reg + ", 2"); - insts.push_back("add " + dst_reg + ", " + dst_reg + ", " + idx_reg); + + while (!stack.empty()) { + auto vreg = stack.back(); + stack.pop_back(); + std::set used_colors; + for (const auto& neighbor : interference_graph.at(vreg)) { + if (vreg_to_preg.find(neighbor) != vreg_to_preg.end()) { + used_colors.insert(reg_to_string(vreg_to_preg[neighbor])); + } } - if (alloc.stack_map.find(la) != alloc.stack_map.end()) { - insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(la)) + "(s0)"); + + bool assigned = false; + for (auto preg : allocable_regs) { + if (used_colors.find(reg_to_string(preg)) == used_colors.end()) { + vreg_to_preg[vreg] = preg; + assigned = true; + break; + } } + // If no register is available, spill to stack (handled in register_allocation) } - else if (auto memset = dynamic_cast(inst)) { - std::string ptr_reg = "t0"; - std::string val_reg = "t1"; - std::string size_reg = "t2"; - load_operand(memset->getPointer(), ptr_reg); - load_operand(memset->getValue(), val_reg); - load_operand(memset->getSize(), size_reg); - insts.push_back("mv t3, " + ptr_reg); - insts.push_back("add t4, " + ptr_reg + ", " + size_reg); - insts.push_back("1: sw " + val_reg + ", 0(" + ptr_reg + ")"); - insts.push_back("addi " + ptr_reg + ", " + ptr_reg + ", 4"); - insts.push_back("blt " + ptr_reg + ", t4, 1b"); - } - else if (auto phi = dynamic_cast(inst)) { - // Phi 指令由 eliminate_phi 处理 - } - return insts; } RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) { RegAllocResult result; int stack_offset = 0; - std::set allocated; + value_vreg_map.clear(); // Clear vreg map for new function + static int vreg_counter = 0; // Counter for unique vreg names // 分配局部变量栈空间 for (const auto& bb : func->getBasicBlocks()) { @@ -292,44 +474,73 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun if (auto alloca = dynamic_cast(inst.get())) { if (result.stack_map.find(alloca) == result.stack_map.end()) { result.stack_map[alloca] = stack_offset; + value_vreg_map[alloca] = "v" + std::to_string(vreg_counter++); stack_offset += 4; } + } else if (auto load = dynamic_cast(inst.get())) { + if (value_vreg_map.find(load) == value_vreg_map.end()) { + value_vreg_map[load] = "v" + std::to_string(vreg_counter++); + } + } else if (auto bin = dynamic_cast(inst.get())) { + if (value_vreg_map.find(bin) == value_vreg_map.end()) { + value_vreg_map[bin] = "v" + std::to_string(vreg_counter++); + } + } else if (auto call = dynamic_cast(inst.get())) { + if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) == value_vreg_map.end()) { + value_vreg_map[call] = "v" + std::to_string(vreg_counter++); + } } } } - // 分配函数参数栈空间(入口块的 arguments) + // 分配函数参数栈空间 auto entry_block = func->getEntryBlock(); auto args = entry_block->getArguments(); for (size_t i = 0; i < args.size(); ++i) { - if (i >= 8) { // 超过 8 个参数需要栈空间 + if (i >= 8) { if (result.stack_map.find(args[i]) == result.stack_map.end()) { result.stack_map[args[i]] = stack_offset; + value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++); stack_offset += 4; } + } else { + value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++); } } - // 分配中间结果栈空间(如 BinaryInst 和 CallInst) + // 图着色寄存器分配 + auto live_sets = liveness_analysis(func); + auto interference_graph = build_interference_graph(live_sets); + color_graph(result.vreg_to_preg, interference_graph); + + // 分配溢出栈空间 for (const auto& bb : func->getBasicBlocks()) { for (const auto& inst : bb->getInstructions()) { if (auto bin = dynamic_cast(inst.get())) { - if (result.stack_map.find(bin) == result.stack_map.end() && allocated.find(bin) == allocated.end()) { + std::string vreg = value_vreg_map[bin]; + if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { result.stack_map[bin] = stack_offset; stack_offset += 4; - allocated.insert(bin); } } else if (auto call = dynamic_cast(inst.get())) { - if (result.stack_map.find(call) == result.stack_map.end() && allocated.find(call) == allocated.end()) { - result.stack_map[call] = stack_offset; + if (call->getType()->isInt() || call->getType()->isFloat()) { + std::string vreg = value_vreg_map[call]; + if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { + result.stack_map[call] = stack_offset; + stack_offset += 4; + } + } + } else if (auto load = dynamic_cast(inst.get())) { + std::string vreg = value_vreg_map[load]; + if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) { + result.stack_map[load] = stack_offset; stack_offset += 4; - allocated.insert(call); } } } } - // 检查是否需要保存 ra 和 s0 + // 保存 ra 和 s0 bool needs_caller_saved = false; for (const auto& bb : func->getBasicBlocks()) { for (const auto& inst : bb->getInstructions()) { @@ -353,20 +564,7 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun } void RISCv32CodeGen::eliminate_phi(Function* func) { - // Placeholder: Phi elimination requires inserting moves at predecessor blocks -} - -std::map> RISCv32CodeGen::liveness_analysis(Function* func) { - std::map> live_sets; - // Placeholder: Implement liveness analysis - return live_sets; -} - -std::map> RISCv32CodeGen::build_interference_graph( - const std::map>& live_sets) { - std::map> graph; - // Placeholder: Implement interference graph - return graph; + // TODO: 插入 move 指令处理 phi } } // namespace sysy \ No newline at end of file diff --git a/src/RISCv32Backend.h b/src/RISCv32Backend.h index 7e1239c..14be11f 100644 --- a/src/RISCv32Backend.h +++ b/src/RISCv32Backend.h @@ -6,61 +6,57 @@ #include #include #include +#include namespace sysy { class RISCv32CodeGen { public: - explicit RISCv32CodeGen(Module* mod) : module(mod) {} - std::string code_gen(); // 生成模块的汇编代码 - -private: - Module* module; - - // 物理寄存器 enum class PhysicalReg { - S0, // x8, 帧指针 - T0, T1, T2, T3, T4, T5, T6, // x5-x7, x28-x31 - A0, A1, A2, A3, A4, A5, A6, A7 // x10-x17 - }; - static const std::vector allocable_regs; - - // 操作数 - struct Operand { - enum class Kind { Reg, Imm, Label }; - Kind kind; - Value* value; // 用于寄存器 - std::string label; // 用于标签或立即数 - Operand(Kind k, Value* v) : kind(k), value(v), label("") {} - Operand(Kind k, const std::string& l) : kind(k), value(nullptr), label(l) {} + S0, T0, T1, T2, T3, T4, T5, T6, + A0, A1, A2, A3, A4, A5, A6, A7 }; - // RISC-V 指令 - struct RISCv32Inst { - std::string opcode; - std::vector operands; - RISCv32Inst(const std::string& op, const std::vector& ops) - : opcode(op), operands(ops) {} + // Move DAGNode and RegAllocResult to public section + struct DAGNode { + enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN }; + NodeKind kind; + Value* value = nullptr; + std::string inst; + std::string result_reg; + std::vector operands; + std::vector users; + DAGNode(NodeKind k) : kind(k) {} }; - // 寄存器分配结果 struct RegAllocResult { - std::map reg_map; // 虚拟寄存器到物理寄存器的映射 - std::map stack_map; // 虚拟寄存器到堆栈槽的映射 - int stack_size; // 堆栈帧大小 + std::map vreg_to_preg; + std::map stack_map; + int stack_size = 0; }; - // 后端方法 + RISCv32CodeGen(Module* mod) : module(mod) {} + + std::string code_gen(); std::string module_gen(); std::string function_gen(Function* func); std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc); - std::vector instruction_gen(Instruction* inst, const RegAllocResult& alloc); + std::vector> build_dag(BasicBlock* bb); + void select_instructions(DAGNode* node, const RegAllocResult& alloc); // Use const + void emit_instructions(DAGNode* node, std::vector& insts, const RegAllocResult& alloc); // Add alloc + std::map> liveness_analysis(Function* func); + std::map> build_interference_graph( + const std::map>& live_sets); + void color_graph(std::map& vreg_to_preg, + const std::map>& interference_graph); RegAllocResult register_allocation(Function* func); void eliminate_phi(Function* func); - std::map> liveness_analysis(Function* func); - std::map> build_interference_graph( - const std::map>& live_sets); std::string reg_to_string(PhysicalReg reg); + +private: + static const std::vector allocable_regs; + std::map value_vreg_map; + Module* module; }; } // namespace sysy