[backend] introduced DAG, GraphAlloc
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
#include "RISCv32Backend.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <regex>
#include <sstream>
#include <stdexcept>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
@@ -32,17 +34,6 @@ std::string RISCv32CodeGen::reg_to_string(PhysicalReg reg) {
|
||||
default: return "";
|
||||
}
|
||||
}
|
||||
// 简单的临时寄存器分配器
|
||||
class TempRegAllocator {
|
||||
std::vector<std::string> regs = {"t0", "t1", "t2", "t3", "t4", "t5", "t6"};
|
||||
size_t current = 0;
|
||||
public:
|
||||
std::string get_next() {
|
||||
if (current >= regs.size()) throw std::runtime_error("临时寄存器不足");
|
||||
return regs[current++];
|
||||
}
|
||||
void reset() { current = 0; }
|
||||
};
|
||||
|
||||
std::string RISCv32CodeGen::code_gen() {
|
||||
std::stringstream ss;
|
||||
@@ -112,179 +103,370 @@ std::string RISCv32CodeGen::function_gen(Function* func) {
|
||||
// Emits the assembly text for one basic block.
//
// Writes the block label, builds the block's DAG, and then, for every node in
// creation order, runs instruction selection followed by emission. Each
// collected line is written after a one-space indent.
//
// bb:     block to lower.
// alloc:  register-allocation result of the enclosing function.
// return: the label line followed by the block's instructions.
//
// NOTE(review): emit_instructions recurses into a node's operands, and every
// node is also visited here at top level, so a node that is an operand of
// another node has its text appended more than once. Fixing that needs a
// visited set shared with emit_instructions.
std::string RISCv32CodeGen::basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc) {
    std::stringstream ss;
    ss << bb->getName() << ":\n";

    const auto dag = build_dag(bb);
    std::vector<std::string> insts;
    for (const auto& node : dag) {
        select_instructions(node.get(), alloc);
        emit_instructions(node.get(), insts, alloc);
    }

    for (const auto& inst : insts) {
        ss << " " << inst << "\n";
    }
    return ss.str();
}
|
||||
|
||||
std::vector<std::string> RISCv32CodeGen::instruction_gen(Instruction* inst, const RegAllocResult& alloc) {
|
||||
std::vector<std::string> insts;
|
||||
// DAG 构建
|
||||
std::vector<std::unique_ptr<RISCv32CodeGen::DAGNode>> RISCv32CodeGen::build_dag(BasicBlock* bb) {
|
||||
std::vector<std::unique_ptr<DAGNode>> nodes;
|
||||
std::map<Value*, DAGNode*> value_to_node;
|
||||
static int vreg_counter = 0; // Counter for unique vreg names
|
||||
|
||||
auto load_operand = [&](Value* val, const std::string& reg) {
|
||||
if (auto constant = dynamic_cast<ConstantValue*>(val)) {
|
||||
if (constant->isInt()) {
|
||||
insts.push_back("li " + reg + ", " + std::to_string(constant->getInt()));
|
||||
} else {
|
||||
float f = constant->getFloat();
|
||||
uint32_t float_bits = *(uint32_t*)&f;
|
||||
insts.push_back("li " + reg + ", " + std::to_string(float_bits));
|
||||
insts.push_back("fmv.w.x " + reg + ", " + reg);
|
||||
}
|
||||
} else if (alloc.stack_map.find(val) != alloc.stack_map.end()) {
|
||||
insts.push_back("lw " + reg + ", " + std::to_string(alloc.stack_map.at(val)) + "(s0)");
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(val)) {
|
||||
insts.push_back("la " + reg + ", " + global->getName());
|
||||
}
|
||||
auto create_node = [&](DAGNode::NodeKind kind, Value* val = nullptr) {
|
||||
auto node = std::make_unique<DAGNode>(kind);
|
||||
node->value = val;
|
||||
node->result_reg = val ? "v" + std::to_string(vreg_counter++) : "";
|
||||
if (val) value_to_node[val] = node.get();
|
||||
nodes.push_back(std::move(node));
|
||||
return nodes.back().get();
|
||||
};
|
||||
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(inst)) {
|
||||
// 栈空间已在 register_allocation 中分配
|
||||
}
|
||||
else if (auto store = dynamic_cast<StoreInst*>(inst)) {
|
||||
std::string val_reg = "t0";
|
||||
load_operand(store->getValue(), val_reg);
|
||||
auto ptr = store->getPointer();
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(ptr)) {
|
||||
int offset = alloc.stack_map.at(alloca);
|
||||
insts.push_back("sw " + val_reg + ", " + std::to_string(offset) + "(s0)");
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(ptr)) {
|
||||
std::string ptr_reg = "t1";
|
||||
insts.push_back("la " + ptr_reg + ", " + global->getName());
|
||||
insts.push_back("sw " + val_reg + ", 0(" + ptr_reg + ")");
|
||||
for (const auto& inst : bb->getInstructions()) {
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
|
||||
create_node(DAGNode::CONSTANT, alloca); // Allocate stack space
|
||||
} else if (auto store = dynamic_cast<StoreInst*>(inst.get())) {
|
||||
auto store_node = create_node(DAGNode::STORE);
|
||||
auto val_node = value_to_node.find(store->getValue()) != value_to_node.end()
|
||||
? value_to_node[store->getValue()]
|
||||
: create_node(DAGNode::CONSTANT, store->getValue());
|
||||
auto ptr_node = value_to_node.find(store->getPointer()) != value_to_node.end()
|
||||
? value_to_node[store->getPointer()]
|
||||
: create_node(DAGNode::CONSTANT, store->getPointer());
|
||||
store_node->operands.push_back(val_node);
|
||||
store_node->operands.push_back(ptr_node);
|
||||
val_node->users.push_back(store_node);
|
||||
ptr_node->users.push_back(store_node);
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
auto load_node = create_node(DAGNode::LOAD, load);
|
||||
auto ptr_node = value_to_node.find(load->getPointer()) != value_to_node.end()
|
||||
? value_to_node[load->getPointer()]
|
||||
: create_node(DAGNode::CONSTANT, load->getPointer());
|
||||
load_node->operands.push_back(ptr_node);
|
||||
ptr_node->users.push_back(load_node);
|
||||
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
|
||||
auto bin_node = create_node(DAGNode::BINARY, bin);
|
||||
auto lhs_node = value_to_node.find(bin->getLhs()) != value_to_node.end()
|
||||
? value_to_node[bin->getLhs()]
|
||||
: create_node(DAGNode::CONSTANT, bin->getLhs());
|
||||
auto rhs_node = value_to_node.find(bin->getRhs()) != value_to_node.end()
|
||||
? value_to_node[bin->getRhs()]
|
||||
: create_node(DAGNode::CONSTANT, bin->getRhs());
|
||||
bin_node->operands.push_back(lhs_node);
|
||||
bin_node->operands.push_back(rhs_node);
|
||||
lhs_node->users.push_back(bin_node);
|
||||
rhs_node->users.push_back(bin_node);
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
auto call_node = create_node(DAGNode::CALL, call);
|
||||
for (auto arg : call->getArguments()) {
|
||||
auto arg_node = value_to_node.find(arg->getValue()) != value_to_node.end()
|
||||
? value_to_node[arg->getValue()]
|
||||
: create_node(DAGNode::CONSTANT, arg->getValue());
|
||||
call_node->operands.push_back(arg_node);
|
||||
arg_node->users.push_back(call_node);
|
||||
}
|
||||
} else if (auto ret = dynamic_cast<ReturnInst*>(inst.get())) {
|
||||
auto ret_node = create_node(DAGNode::RETURN);
|
||||
if (ret->hasReturnValue()) {
|
||||
auto val_node = value_to_node.find(ret->getReturnValue()) != value_to_node.end()
|
||||
? value_to_node[ret->getReturnValue()]
|
||||
: create_node(DAGNode::CONSTANT, ret->getReturnValue());
|
||||
ret_node->operands.push_back(val_node);
|
||||
val_node->users.push_back(ret_node);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
std::string dst_reg = "t0";
|
||||
auto ptr = load->getPointer();
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(ptr)) {
|
||||
int offset = alloc.stack_map.at(alloca);
|
||||
insts.push_back("lw " + dst_reg + ", " + std::to_string(offset) + "(s0)");
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(ptr)) {
|
||||
std::string ptr_reg = "t1";
|
||||
insts.push_back("la " + ptr_reg + ", " + global->getName());
|
||||
insts.push_back("lw " + dst_reg + ", 0(" + ptr_reg + ")");
|
||||
|
||||
return nodes;
|
||||
}
|
||||
|
||||
// 指令选择
|
||||
void RISCv32CodeGen::select_instructions(DAGNode* node, const RegAllocResult& alloc) {
|
||||
if (!node->inst.empty()) return;
|
||||
|
||||
for (auto operand : node->operands) {
|
||||
select_instructions(operand, alloc);
|
||||
}
|
||||
|
||||
switch (node->kind) {
|
||||
case DAGNode::CONSTANT: {
|
||||
if (auto constant = dynamic_cast<ConstantValue*>(node->value)) {
|
||||
if (constant->isInt()) {
|
||||
node->inst = "li " + node->result_reg + ", " + std::to_string(constant->getInt());
|
||||
} else {
|
||||
float f = constant->getFloat();
|
||||
uint32_t float_bits = *(uint32_t*)&f;
|
||||
node->inst = "li " + node->result_reg + ", " + std::to_string(float_bits) + "\nfmv.w.x " + node->result_reg + ", " + node->result_reg;
|
||||
}
|
||||
} else if (auto global = dynamic_cast<GlobalValue*>(node->value)) {
|
||||
node->inst = "la " + node->result_reg + ", " + global->getName();
|
||||
} else if (auto alloca = dynamic_cast<AllocaInst*>(node->value)) {
|
||||
if (alloc.stack_map.find(alloca) != alloc.stack_map.end()) {
|
||||
node->inst = ""; // Stack address handled in LOAD/STORE
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (alloc.stack_map.find(load) != alloc.stack_map.end()) {
|
||||
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(load)) + "(s0)");
|
||||
case DAGNode::LOAD: {
|
||||
auto ptr_reg = node->operands[0]->result_reg;
|
||||
if (alloc.stack_map.find(node->operands[0]->value) != alloc.stack_map.end()) {
|
||||
int offset = alloc.stack_map.at(node->operands[0]->value);
|
||||
node->inst = "lw " + node->result_reg + ", " + std::to_string(offset) + "(s0)";
|
||||
} else {
|
||||
node->inst = "lw " + node->result_reg + ", 0(" + ptr_reg + ")";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::STORE: {
|
||||
auto val_reg = node->operands[0]->result_reg;
|
||||
auto ptr_reg = node->operands[1]->result_reg;
|
||||
if (alloc.stack_map.find(node->operands[1]->value) != alloc.stack_map.end()) {
|
||||
int offset = alloc.stack_map.at(node->operands[1]->value);
|
||||
node->inst = "sw " + val_reg + ", " + std::to_string(offset) + "(s0)";
|
||||
} else {
|
||||
node->inst = "sw " + val_reg + ", 0(" + ptr_reg + ")";
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::BINARY: {
|
||||
auto bin = dynamic_cast<BinaryInst*>(node->value);
|
||||
auto lhs_reg = node->operands[0]->result_reg;
|
||||
auto rhs_reg = node->operands[1]->result_reg;
|
||||
std::string opcode;
|
||||
switch (bin->getKind()) {
|
||||
case BinaryInst::kAdd: opcode = "add"; break;
|
||||
case BinaryInst::kMul: opcode = "mul"; break;
|
||||
default: break;
|
||||
}
|
||||
if (!opcode.empty()) {
|
||||
node->inst = opcode + " " + node->result_reg + ", " + lhs_reg + ", " + rhs_reg;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case DAGNode::CALL: {
|
||||
auto call = dynamic_cast<CallInst*>(node->value);
|
||||
std::string insts;
|
||||
for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
|
||||
insts += "mv a" + std::to_string(i) + ", " + node->operands[i]->result_reg + "\n";
|
||||
}
|
||||
insts += "jal " + call->getCallee()->getName();
|
||||
if (call->getType()->isInt() || call->getType()->isFloat()) {
|
||||
insts += "\nmv " + node->result_reg + ", a0";
|
||||
}
|
||||
node->inst = insts;
|
||||
break;
|
||||
}
|
||||
case DAGNode::RETURN: {
|
||||
if (!node->operands.empty()) {
|
||||
node->inst = "mv a0, " + node->operands[0]->result_reg;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: break;
|
||||
}
|
||||
}
|
||||
|
||||
// 指令发射
|
||||
void RISCv32CodeGen::emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc) {
|
||||
for (auto operand : node->operands) {
|
||||
emit_instructions(operand, insts, alloc);
|
||||
}
|
||||
if (!node->inst.empty()) {
|
||||
std::stringstream ss(node->inst);
|
||||
std::string line;
|
||||
while (std::getline(ss, line, '\n')) {
|
||||
if (!line.empty()) {
|
||||
// Replace virtual registers with physical registers
|
||||
if (!node->result_reg.empty() && alloc.vreg_to_preg.find(node->result_reg) != alloc.vreg_to_preg.end()) {
|
||||
line = std::regex_replace(line, std::regex("\\b" + node->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(node->result_reg)));
|
||||
}
|
||||
for (auto operand : node->operands) {
|
||||
if (!operand->result_reg.empty() && alloc.vreg_to_preg.find(operand->result_reg) != alloc.vreg_to_preg.end()) {
|
||||
line = std::regex_replace(line, std::regex("\\b" + operand->result_reg + "\\b"), reg_to_string(alloc.vreg_to_preg.at(operand->result_reg)));
|
||||
}
|
||||
}
|
||||
insts.push_back(line);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
std::string lhs_reg = "t0";
|
||||
std::string rhs_reg = "t1";
|
||||
std::string dst_reg = "t2";
|
||||
load_operand(bin->getLhs(), lhs_reg);
|
||||
load_operand(bin->getRhs(), rhs_reg);
|
||||
std::string opcode;
|
||||
switch (bin->getKind()) {
|
||||
case BinaryInst::kAdd: opcode = "add"; break;
|
||||
case BinaryInst::kSub: opcode = "sub"; break;
|
||||
case BinaryInst::kMul: opcode = "mul"; break;
|
||||
case BinaryInst::kDiv: opcode = "div"; break;
|
||||
case BinaryInst::kRem: opcode = "rem"; break;
|
||||
case BinaryInst::kFAdd: opcode = "fadd.s"; break;
|
||||
case BinaryInst::kFSub: opcode = "fsub.s"; break;
|
||||
case BinaryInst::kFMul: opcode = "fmul.s"; break;
|
||||
case BinaryInst::kFDiv: opcode = "fdiv.s"; break;
|
||||
case BinaryInst::kICmpEQ: insts.push_back("seqz " + dst_reg + ", " + lhs_reg); break;
|
||||
case BinaryInst::kICmpNE: insts.push_back("snez " + dst_reg + ", " + lhs_reg); break;
|
||||
case BinaryInst::kICmpLT: insts.push_back("slt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
|
||||
case BinaryInst::kICmpGT: insts.push_back("sgt " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
|
||||
case BinaryInst::kICmpLE: insts.push_back("sle " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
|
||||
case BinaryInst::kICmpGE: insts.push_back("sge " + dst_reg + ", " + lhs_reg + ", " + rhs_reg); break;
|
||||
case BinaryInst::kAnd: opcode = "and"; break;
|
||||
case BinaryInst::kOr: opcode = "or"; break;
|
||||
default: return insts;
|
||||
}
|
||||
if (!opcode.empty()) {
|
||||
insts.push_back(opcode + " " + dst_reg + ", " + lhs_reg + ", " + rhs_reg);
|
||||
}
|
||||
if (alloc.stack_map.find(bin) != alloc.stack_map.end()) {
|
||||
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(bin)) + "(s0)");
|
||||
}
|
||||
|
||||
// 活跃性分析
|
||||
std::map<Instruction*, std::set<std::string>> RISCv32CodeGen::liveness_analysis(Function* func) {
|
||||
std::map<Instruction*, std::set<std::string>> live_in, live_out;
|
||||
bool changed = true;
|
||||
|
||||
while (changed) {
|
||||
changed = false;
|
||||
for (auto it = func->getBasicBlocks_NoRange().rbegin(); it != func->getBasicBlocks_NoRange().rend(); ++it) {
|
||||
auto bb = it->get();
|
||||
for (auto inst_it = bb->getInstructions().rbegin(); inst_it != bb->getInstructions().rend(); ++inst_it) {
|
||||
auto inst = inst_it->get();
|
||||
std::set<std::string> new_in, new_out;
|
||||
|
||||
// Calculate live_out
|
||||
if (auto br = dynamic_cast<CondBrInst*>(inst)) {
|
||||
new_out.insert(live_in[br->getThenBlock()->getInstructions().front().get()].begin(),
|
||||
live_in[br->getThenBlock()->getInstructions().front().get()].end());
|
||||
new_out.insert(live_in[br->getElseBlock()->getInstructions().front().get()].begin(),
|
||||
live_in[br->getElseBlock()->getInstructions().front().get()].end());
|
||||
} else if (auto uncond = dynamic_cast<UncondBrInst*>(inst)) {
|
||||
new_out.insert(live_in[uncond->getBlock()->getInstructions().front().get()].begin(),
|
||||
live_in[uncond->getBlock()->getInstructions().front().get()].end());
|
||||
} else {
|
||||
auto next_inst = std::next(inst_it);
|
||||
if (next_inst != bb->getInstructions().rend()) {
|
||||
new_out = live_in[next_inst->get()];
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate live_in = use ∪ (live_out - def)
|
||||
std::set<std::string> use, def;
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
if (value_vreg_map.find(bin->getLhs()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[bin->getLhs()]);
|
||||
if (value_vreg_map.find(bin->getRhs()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[bin->getRhs()]);
|
||||
if (value_vreg_map.find(bin) != value_vreg_map.end())
|
||||
def.insert(value_vreg_map[bin]);
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
for (auto arg : call->getArguments()) {
|
||||
if (value_vreg_map.find(arg->getValue()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[arg->getValue()]);
|
||||
}
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
|
||||
def.insert(value_vreg_map[call]);
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
if (value_vreg_map.find(load->getPointer()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[load->getPointer()]);
|
||||
if (value_vreg_map.find(load) != value_vreg_map.end())
|
||||
def.insert(value_vreg_map[load]);
|
||||
} else if (auto store = dynamic_cast<StoreInst*>(inst)) {
|
||||
if (value_vreg_map.find(store->getValue()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[store->getValue()]);
|
||||
if (value_vreg_map.find(store->getPointer()) != value_vreg_map.end())
|
||||
use.insert(value_vreg_map[store->getPointer()]);
|
||||
} else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
|
||||
if (ret->hasReturnValue() && value_vreg_map.find(ret->getReturnValue()) != value_vreg_map.end()) {
|
||||
use.insert(value_vreg_map[ret->getReturnValue()]);
|
||||
}
|
||||
}
|
||||
|
||||
new_in = use;
|
||||
for (const auto& vreg : new_out) {
|
||||
if (def.find(vreg) == def.end()) {
|
||||
new_in.insert(vreg);
|
||||
}
|
||||
}
|
||||
|
||||
if (live_in[inst] != new_in || live_out[inst] != new_out) {
|
||||
live_in[inst] = new_in;
|
||||
live_out[inst] = new_out;
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto uny = dynamic_cast<UnaryInst*>(inst)) {
|
||||
std::string src_reg = "t0";
|
||||
std::string dst_reg = "t1";
|
||||
load_operand(uny->getOperand(), src_reg);
|
||||
switch (uny->getKind()) {
|
||||
case UnaryInst::kNeg: insts.push_back("sub " + dst_reg + ", x0, " + src_reg); break;
|
||||
case UnaryInst::kNot: insts.push_back("xori " + dst_reg + ", " + src_reg + ", -1"); break;
|
||||
case UnaryInst::kFNeg: insts.push_back("fneg.s " + dst_reg + ", " + src_reg); break;
|
||||
case UnaryInst::kFtoI: insts.push_back("fcvt.w.s " + dst_reg + ", " + src_reg); break;
|
||||
case UnaryInst::kItoF: insts.push_back("fcvt.s.w " + dst_reg + ", " + src_reg); break;
|
||||
case UnaryInst::kBitFtoI: insts.push_back("fmv.x.w " + dst_reg + ", " + src_reg); break;
|
||||
case UnaryInst::kBitItoF: insts.push_back("fmv.w.x " + dst_reg + ", " + src_reg); break;
|
||||
default: return insts;
|
||||
|
||||
return live_in;
|
||||
}
|
||||
|
||||
// 干扰图构建
|
||||
std::map<std::string, std::set<std::string>> RISCv32CodeGen::build_interference_graph(
|
||||
const std::map<Instruction*, std::set<std::string>>& live_sets) {
|
||||
std::map<std::string, std::set<std::string>> graph;
|
||||
|
||||
for (const auto& pair : live_sets) {
|
||||
auto inst = pair.first;
|
||||
const auto& live = pair.second;
|
||||
std::string def;
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
|
||||
if (value_vreg_map.find(bin) != value_vreg_map.end())
|
||||
def = value_vreg_map[bin];
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) != value_vreg_map.end()) {
|
||||
def = value_vreg_map[call];
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
|
||||
if (value_vreg_map.find(load) != value_vreg_map.end())
|
||||
def = value_vreg_map[load];
|
||||
}
|
||||
if (alloc.stack_map.find(uny) != alloc.stack_map.end()) {
|
||||
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(uny)) + "(s0)");
|
||||
|
||||
if (!def.empty()) {
|
||||
for (const auto& live_vreg : live) {
|
||||
if (live_vreg != def) {
|
||||
graph[def].insert(live_vreg);
|
||||
graph[live_vreg].insert(def);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (auto call = dynamic_cast<CallInst*>(inst)) {
|
||||
auto args = call->getArguments();
|
||||
size_t i = 0;
|
||||
for (auto it = args.begin(); it != args.end() && i < 8; ++it, ++i) {
|
||||
load_operand((*it)->getValue(), "a" + std::to_string(i));
|
||||
|
||||
return graph;
|
||||
}
|
||||
|
||||
// 图着色
|
||||
void RISCv32CodeGen::color_graph(std::map<std::string, PhysicalReg>& vreg_to_preg,
|
||||
const std::map<std::string, std::set<std::string>>& interference_graph) {
|
||||
std::vector<std::string> stack;
|
||||
std::map<std::string, std::set<std::string>> temp_graph = interference_graph;
|
||||
|
||||
while (!temp_graph.empty()) {
|
||||
std::string node_to_remove;
|
||||
for (const auto& pair : temp_graph) {
|
||||
if (pair.second.size() < allocable_regs.size()) {
|
||||
node_to_remove = pair.first;
|
||||
break;
|
||||
}
|
||||
}
|
||||
insts.push_back("jal " + call->getCallee()->getName());
|
||||
if (alloc.stack_map.find(call) != alloc.stack_map.end()) {
|
||||
insts.push_back("sw a0, " + std::to_string(alloc.stack_map.at(call)) + "(s0)");
|
||||
|
||||
if (node_to_remove.empty()) {
|
||||
node_to_remove = temp_graph.begin()->first; // Spill if necessary
|
||||
}
|
||||
}
|
||||
else if (auto condBr = dynamic_cast<CondBrInst*>(inst)) {
|
||||
std::string cond_reg = "t0";
|
||||
load_operand(condBr->getCondition(), cond_reg);
|
||||
insts.push_back("bnez " + cond_reg + ", " + condBr->getThenBlock()->getName());
|
||||
insts.push_back("j " + condBr->getElseBlock()->getName());
|
||||
}
|
||||
else if (auto br = dynamic_cast<UncondBrInst*>(inst)) {
|
||||
insts.push_back("j " + br->getBlock()->getName());
|
||||
}
|
||||
else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
|
||||
if (ret->hasReturnValue()) {
|
||||
load_operand(ret->getReturnValue(), "a0");
|
||||
|
||||
stack.push_back(node_to_remove);
|
||||
for (auto& pair : temp_graph) {
|
||||
pair.second.erase(node_to_remove);
|
||||
}
|
||||
temp_graph.erase(node_to_remove);
|
||||
}
|
||||
else if (auto la = dynamic_cast<LaInst*>(inst)) {
|
||||
std::string dst_reg = "t0";
|
||||
load_operand(la->getPointer(), dst_reg);
|
||||
for (size_t i = 0; i < la->getNumIndices(); ++i) {
|
||||
std::string idx_reg = "t1";
|
||||
load_operand(la->getIndex(i), idx_reg);
|
||||
insts.push_back("slli " + idx_reg + ", " + idx_reg + ", 2");
|
||||
insts.push_back("add " + dst_reg + ", " + dst_reg + ", " + idx_reg);
|
||||
|
||||
while (!stack.empty()) {
|
||||
auto vreg = stack.back();
|
||||
stack.pop_back();
|
||||
std::set<std::string> used_colors;
|
||||
for (const auto& neighbor : interference_graph.at(vreg)) {
|
||||
if (vreg_to_preg.find(neighbor) != vreg_to_preg.end()) {
|
||||
used_colors.insert(reg_to_string(vreg_to_preg[neighbor]));
|
||||
}
|
||||
}
|
||||
if (alloc.stack_map.find(la) != alloc.stack_map.end()) {
|
||||
insts.push_back("sw " + dst_reg + ", " + std::to_string(alloc.stack_map.at(la)) + "(s0)");
|
||||
|
||||
bool assigned = false;
|
||||
for (auto preg : allocable_regs) {
|
||||
if (used_colors.find(reg_to_string(preg)) == used_colors.end()) {
|
||||
vreg_to_preg[vreg] = preg;
|
||||
assigned = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
// If no register is available, spill to stack (handled in register_allocation)
|
||||
}
|
||||
else if (auto memset = dynamic_cast<MemsetInst*>(inst)) {
|
||||
std::string ptr_reg = "t0";
|
||||
std::string val_reg = "t1";
|
||||
std::string size_reg = "t2";
|
||||
load_operand(memset->getPointer(), ptr_reg);
|
||||
load_operand(memset->getValue(), val_reg);
|
||||
load_operand(memset->getSize(), size_reg);
|
||||
insts.push_back("mv t3, " + ptr_reg);
|
||||
insts.push_back("add t4, " + ptr_reg + ", " + size_reg);
|
||||
insts.push_back("1: sw " + val_reg + ", 0(" + ptr_reg + ")");
|
||||
insts.push_back("addi " + ptr_reg + ", " + ptr_reg + ", 4");
|
||||
insts.push_back("blt " + ptr_reg + ", t4, 1b");
|
||||
}
|
||||
else if (auto phi = dynamic_cast<PhiInst*>(inst)) {
|
||||
// Phi 指令由 eliminate_phi 处理
|
||||
}
|
||||
return insts;
|
||||
}
|
||||
|
||||
RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* func) {
|
||||
RegAllocResult result;
|
||||
int stack_offset = 0;
|
||||
std::set<Value*> allocated;
|
||||
value_vreg_map.clear(); // Clear vreg map for new function
|
||||
static int vreg_counter = 0; // Counter for unique vreg names
|
||||
|
||||
// 分配局部变量栈空间
|
||||
for (const auto& bb : func->getBasicBlocks()) {
|
||||
@@ -292,44 +474,73 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
|
||||
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
|
||||
if (result.stack_map.find(alloca) == result.stack_map.end()) {
|
||||
result.stack_map[alloca] = stack_offset;
|
||||
value_vreg_map[alloca] = "v" + std::to_string(vreg_counter++);
|
||||
stack_offset += 4;
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
if (value_vreg_map.find(load) == value_vreg_map.end()) {
|
||||
value_vreg_map[load] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
} else if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
|
||||
if (value_vreg_map.find(bin) == value_vreg_map.end()) {
|
||||
value_vreg_map[bin] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
if ((call->getType()->isInt() || call->getType()->isFloat()) && value_vreg_map.find(call) == value_vreg_map.end()) {
|
||||
value_vreg_map[call] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 分配函数参数栈空间(入口块的 arguments)
|
||||
// 分配函数参数栈空间
|
||||
auto entry_block = func->getEntryBlock();
|
||||
auto args = entry_block->getArguments();
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (i >= 8) { // 超过 8 个参数需要栈空间
|
||||
if (i >= 8) {
|
||||
if (result.stack_map.find(args[i]) == result.stack_map.end()) {
|
||||
result.stack_map[args[i]] = stack_offset;
|
||||
value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++);
|
||||
stack_offset += 4;
|
||||
}
|
||||
} else {
|
||||
value_vreg_map[args[i]] = "v" + std::to_string(vreg_counter++);
|
||||
}
|
||||
}
|
||||
|
||||
// 分配中间结果栈空间(如 BinaryInst 和 CallInst)
|
||||
// 图着色寄存器分配
|
||||
auto live_sets = liveness_analysis(func);
|
||||
auto interference_graph = build_interference_graph(live_sets);
|
||||
color_graph(result.vreg_to_preg, interference_graph);
|
||||
|
||||
// 分配溢出栈空间
|
||||
for (const auto& bb : func->getBasicBlocks()) {
|
||||
for (const auto& inst : bb->getInstructions()) {
|
||||
if (auto bin = dynamic_cast<BinaryInst*>(inst.get())) {
|
||||
if (result.stack_map.find(bin) == result.stack_map.end() && allocated.find(bin) == allocated.end()) {
|
||||
std::string vreg = value_vreg_map[bin];
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[bin] = stack_offset;
|
||||
stack_offset += 4;
|
||||
allocated.insert(bin);
|
||||
}
|
||||
} else if (auto call = dynamic_cast<CallInst*>(inst.get())) {
|
||||
if (result.stack_map.find(call) == result.stack_map.end() && allocated.find(call) == allocated.end()) {
|
||||
result.stack_map[call] = stack_offset;
|
||||
if (call->getType()->isInt() || call->getType()->isFloat()) {
|
||||
std::string vreg = value_vreg_map[call];
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[call] = stack_offset;
|
||||
stack_offset += 4;
|
||||
}
|
||||
}
|
||||
} else if (auto load = dynamic_cast<LoadInst*>(inst.get())) {
|
||||
std::string vreg = value_vreg_map[load];
|
||||
if (result.vreg_to_preg.find(vreg) == result.vreg_to_preg.end()) {
|
||||
result.stack_map[load] = stack_offset;
|
||||
stack_offset += 4;
|
||||
allocated.insert(call);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要保存 ra 和 s0
|
||||
// 保存 ra 和 s0
|
||||
bool needs_caller_saved = false;
|
||||
for (const auto& bb : func->getBasicBlocks()) {
|
||||
for (const auto& inst : bb->getInstructions()) {
|
||||
@@ -353,20 +564,7 @@ RISCv32CodeGen::RegAllocResult RISCv32CodeGen::register_allocation(Function* fun
|
||||
}
|
||||
|
||||
// Phi elimination. Currently a placeholder that does nothing.
void RISCv32CodeGen::eliminate_phi(Function* func) {
    // TODO: insert move instructions to handle phi nodes
    // (phi elimination requires inserting moves at predecessor blocks).
    (void)func;
}
|
||||
|
||||
} // namespace sysy
|
||||
@@ -6,61 +6,57 @@
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <memory>
|
||||
|
||||
namespace sysy {
|
||||
|
||||
class RISCv32CodeGen {
|
||||
public:
|
||||
explicit RISCv32CodeGen(Module* mod) : module(mod) {}
|
||||
std::string code_gen(); // 生成模块的汇编代码
|
||||
|
||||
private:
|
||||
Module* module;
|
||||
|
||||
// 物理寄存器
|
||||
enum class PhysicalReg {
|
||||
S0, // x8, 帧指针
|
||||
T0, T1, T2, T3, T4, T5, T6, // x5-x7, x28-x31
|
||||
A0, A1, A2, A3, A4, A5, A6, A7 // x10-x17
|
||||
};
|
||||
static const std::vector<PhysicalReg> allocable_regs;
|
||||
|
||||
// 操作数
|
||||
struct Operand {
|
||||
enum class Kind { Reg, Imm, Label };
|
||||
Kind kind;
|
||||
Value* value; // 用于寄存器
|
||||
std::string label; // 用于标签或立即数
|
||||
Operand(Kind k, Value* v) : kind(k), value(v), label("") {}
|
||||
Operand(Kind k, const std::string& l) : kind(k), value(nullptr), label(l) {}
|
||||
S0, T0, T1, T2, T3, T4, T5, T6,
|
||||
A0, A1, A2, A3, A4, A5, A6, A7
|
||||
};
|
||||
|
||||
// RISC-V 指令
|
||||
struct RISCv32Inst {
|
||||
std::string opcode;
|
||||
std::vector<Operand> operands;
|
||||
RISCv32Inst(const std::string& op, const std::vector<Operand>& ops)
|
||||
: opcode(op), operands(ops) {}
|
||||
// Move DAGNode and RegAllocResult to public section
|
||||
struct DAGNode {
|
||||
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN };
|
||||
NodeKind kind;
|
||||
Value* value = nullptr;
|
||||
std::string inst;
|
||||
std::string result_reg;
|
||||
std::vector<DAGNode*> operands;
|
||||
std::vector<DAGNode*> users;
|
||||
DAGNode(NodeKind k) : kind(k) {}
|
||||
};
|
||||
|
||||
// 寄存器分配结果
|
||||
struct RegAllocResult {
|
||||
std::map<Value*, PhysicalReg> reg_map; // 虚拟寄存器到物理寄存器的映射
|
||||
std::map<Value*, int> stack_map; // 虚拟寄存器到堆栈槽的映射
|
||||
int stack_size; // 堆栈帧大小
|
||||
std::map<std::string, PhysicalReg> vreg_to_preg;
|
||||
std::map<Value*, int> stack_map;
|
||||
int stack_size = 0;
|
||||
};
|
||||
|
||||
// 后端方法
|
||||
RISCv32CodeGen(Module* mod) : module(mod) {}
|
||||
|
||||
std::string code_gen();
|
||||
std::string module_gen();
|
||||
std::string function_gen(Function* func);
|
||||
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc);
|
||||
std::vector<std::string> instruction_gen(Instruction* inst, const RegAllocResult& alloc);
|
||||
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
|
||||
void select_instructions(DAGNode* node, const RegAllocResult& alloc); // Use const
|
||||
void emit_instructions(DAGNode* node, std::vector<std::string>& insts, const RegAllocResult& alloc); // Add alloc
|
||||
std::map<Instruction*, std::set<std::string>> liveness_analysis(Function* func);
|
||||
std::map<std::string, std::set<std::string>> build_interference_graph(
|
||||
const std::map<Instruction*, std::set<std::string>>& live_sets);
|
||||
void color_graph(std::map<std::string, PhysicalReg>& vreg_to_preg,
|
||||
const std::map<std::string, std::set<std::string>>& interference_graph);
|
||||
RegAllocResult register_allocation(Function* func);
|
||||
void eliminate_phi(Function* func);
|
||||
std::map<Instruction*, std::set<Value*>> liveness_analysis(Function* func);
|
||||
std::map<Value*, std::set<Value*>> build_interference_graph(
|
||||
const std::map<Instruction*, std::set<Value*>>& live_sets);
|
||||
std::string reg_to_string(PhysicalReg reg);
|
||||
|
||||
private:
|
||||
static const std::vector<PhysicalReg> allocable_regs;
|
||||
std::map<Value*, std::string> value_vreg_map;
|
||||
Module* module;
|
||||
};
|
||||
|
||||
} // namespace sysy
|
||||
|
||||
Reference in New Issue
Block a user