// File: nudt-compiler-cpp/src/ir/passes/Mem2Reg.cpp
// Mem2Reg pass: promotes scalar i32/float allocas to SSA values.

#include "ir/PassManager.h"
#include <iostream>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <stack>
#include <algorithm>
#include <queue>
#include <functional>
namespace ir {
// Forward declaration of the CFG-rebuild helper. Its definition is not in this
// part of the file, and RunMem2Reg below does not call it.
// NOTE(review): presumably defined in another translation unit — confirm, or
// drop the declaration if it is unused.
void RebuildCFG(Function* func);
/// Promotes scalar stack slots to SSA values (the classic "mem2reg" transform).
///
/// An alloca is promoted when its type is pointer-to-i32 or pointer-to-float
/// and every use is either a load from it or a store *to* it. For each such
/// alloca the pass:
///   1. inserts Phi instructions on the iterated dominance frontier of the
///      blocks that store to it,
///   2. walks the dominator tree from the entry block, replacing every load
///      with the reaching definition and recording every store as a new
///      definition,
///   3. erases the loads, stores and the alloca itself.
///
/// A read with no reaching store yields the constant 0 (or 0.0f for floats).
///
/// @param func Function to transform; modified in place.
/// @param ctx  Context used to create constants and fresh temporary names.
/// @return true if at least one alloca was promoted, false if nothing changed.
bool RunMem2Reg(Function* func, Context& ctx) {
  // 1. Build the dominator tree (also provides dominance frontiers).
  DominatorTree dom_tree(func);
  dom_tree.Run();

  // 2. Identify promotable allocas.
  // An alloca qualifies only if it is a scalar slot and its address never
  // escapes: every user must be a load, or a store whose *pointer* operand is
  // the alloca (storing the address itself somewhere else disqualifies it).
  auto is_promotable = [](AllocaInst* alloca) -> bool {
    if (!alloca->GetType()->IsPtrInt32() && !alloca->GetType()->IsPtrFloat()) {
      return false;
    }
    for (const auto& use : alloca->GetUses()) {
      auto* user = dynamic_cast<Instruction*>(use.GetUser());
      if (!user) {
        return false;
      }
      if (user->GetOpcode() == Opcode::Load) {
        continue;
      }
      if (user->GetOpcode() != Opcode::Store) {
        return false;
      }
      if (static_cast<StoreInst*>(user)->GetPtr() != alloca) {
        return false;
      }
    }
    return true;
  };

  std::vector<AllocaInst*> promotable_allocas;
  for (const auto& bbPtr : func->GetBlocks()) {
    for (const auto& instPtr : bbPtr->GetInstructions()) {
      if (instPtr->GetOpcode() != Opcode::Alloca) {
        continue;
      }
      auto* alloca = static_cast<AllocaInst*>(instPtr.get());
      if (is_promotable(alloca)) {
        promotable_allocas.push_back(alloca);
      }
    }
  }
  if (promotable_allocas.empty()) {
    return false;
  }

  // Hash set for O(1) "is this alloca being promoted?" queries during renaming.
  const std::unordered_set<AllocaInst*> promoted(promotable_allocas.begin(),
                                                 promotable_allocas.end());

  // Value observed by a read that no store reaches.
  auto get_default_value = [&](AllocaInst* alloca) -> Value* {
    if (alloca->GetType()->IsPtrFloat()) {
      return ctx.GetConstFloat(0.0f);
    }
    return ctx.GetConstInt(0);
  };

  // 3. Place Phi nodes on the iterated dominance frontier of each alloca's
  //    defining (storing) blocks.
  // phi_nodes[block][alloca] is the Phi inserted in `block` for `alloca`.
  std::unordered_map<BasicBlock*, std::unordered_map<AllocaInst*, PhiInst*>> phi_nodes;
  for (auto* alloca : promotable_allocas) {
    std::queue<BasicBlock*> worklist;
    // Blocks that have ever been queued. Kept separate from `has_phi`: a block
    // that stores to the variable can still be a join point that needs a Phi
    // (e.g. a loop header that reads the variable before updating it).
    std::unordered_set<BasicBlock*> enqueued;
    for (const auto& use : alloca->GetUses()) {
      auto* inst = dynamic_cast<Instruction*>(use.GetUser());
      if (inst && inst->GetOpcode() == Opcode::Store &&
          enqueued.insert(inst->GetParent()).second) {
        worklist.push(inst->GetParent());
      }
    }

    std::unordered_set<BasicBlock*> has_phi;
    while (!worklist.empty()) {
      auto* x = worklist.front();
      worklist.pop();
      for (auto* y : dom_tree.GetDominanceFrontier(x)) {
        if (!has_phi.insert(y).second) {
          continue;  // this alloca already has a Phi in y
        }
        std::shared_ptr<Type> ty =
            alloca->GetType()->IsPtrFloat() ? Type::GetFloatType() : Type::GetInt32Type();
        auto phi = std::make_unique<PhiInst>(ty, ctx.NextTemp());
        phi_nodes[y][alloca] = phi.get();
        y->InsertInstructionAtBegin(std::move(phi));
        // The Phi is itself a definition, so y's frontier must be processed too.
        if (enqueued.insert(y).second) {
          worklist.push(y);
        }
      }
    }
  }

  // 4. Rename: depth-first walk over the dominator tree, keeping a stack of
  //    reaching definitions per alloca.
  std::unordered_map<AllocaInst*, std::vector<Value*>> current_def;
  std::unordered_set<Instruction*> instructions_to_erase;
  std::unordered_set<BasicBlock*> visited;

  // Innermost definition of `alloca` at the current point of the walk.
  auto reaching_def = [&](AllocaInst* alloca) -> Value* {
    auto it = current_def.find(alloca);
    if (it == current_def.end() || it->second.empty()) {
      return get_default_value(alloca);
    }
    return it->second.back();
  };

  std::function<void(BasicBlock*)> rename_dfs = [&](BasicBlock* bb) {
    visited.insert(bb);
    // One entry per definition pushed while processing this block, so they can
    // be popped again before returning to the dominator-tree parent.
    std::vector<AllocaInst*> pushed;

    // Phis at the top of this block are the first definitions seen here.
    auto phi_it = phi_nodes.find(bb);
    if (phi_it != phi_nodes.end()) {
      for (const auto& entry : phi_it->second) {
        current_def[entry.first].push_back(entry.second);
        pushed.push_back(entry.first);
      }
    }

    // Loads take the reaching definition; stores become the new definition.
    for (const auto& instPtr : bb->GetInstructions()) {
      auto* inst = instPtr.get();
      if (inst->GetOpcode() == Opcode::Load) {
        auto* load = static_cast<LoadInst*>(inst);
        auto* alloca = dynamic_cast<AllocaInst*>(load->GetPtr());
        if (alloca && promoted.count(alloca)) {
          load->ReplaceAllUsesWith(reaching_def(alloca));
          instructions_to_erase.insert(load);
        }
      } else if (inst->GetOpcode() == Opcode::Store) {
        auto* store = static_cast<StoreInst*>(inst);
        auto* alloca = dynamic_cast<AllocaInst*>(store->GetPtr());
        if (alloca && promoted.count(alloca)) {
          current_def[alloca].push_back(store->GetValue());
          pushed.push_back(alloca);
          instructions_to_erase.insert(store);
        }
      }
    }

    // Tell each CFG successor's Phis which value flows in along this edge.
    for (auto* succ : bb->GetSuccessors()) {
      auto succ_phi_it = phi_nodes.find(succ);
      if (succ_phi_it == phi_nodes.end()) {
        continue;
      }
      for (const auto& entry : succ_phi_it->second) {
        entry.second->AddIncoming(reaching_def(entry.first), bb);
      }
    }

    // Definitions made here are visible in every block this one dominates.
    for (auto* child : dom_tree.GetDominatedBlocks(bb)) {
      rename_dfs(child);
    }

    // Restore the definition stacks for the caller.
    for (auto* alloca : pushed) {
      auto& defs = current_def[alloca];
      if (!defs.empty()) {
        defs.pop_back();
      }
    }
  };
  if (!func->GetBlocks().empty()) {
    rename_dfs(func->GetEntry());
  }

  // Blocks the dominator-tree walk never reached (unreachable from the entry)
  // may still reference promoted allocas. Rewrite them so that erasing the
  // alloca leaves no dangling use, and give Phis an incoming value for those
  // predecessors so every CFG edge into a Phi block has an entry.
  for (const auto& bbPtr : func->GetBlocks()) {
    BasicBlock* bb = &*bbPtr;
    if (visited.count(bb)) {
      continue;
    }
    for (const auto& instPtr : bb->GetInstructions()) {
      auto* inst = instPtr.get();
      if (inst->GetOpcode() == Opcode::Load) {
        auto* load = static_cast<LoadInst*>(inst);
        auto* alloca = dynamic_cast<AllocaInst*>(load->GetPtr());
        if (alloca && promoted.count(alloca)) {
          load->ReplaceAllUsesWith(get_default_value(alloca));
          instructions_to_erase.insert(load);
        }
      } else if (inst->GetOpcode() == Opcode::Store) {
        auto* store = static_cast<StoreInst*>(inst);
        auto* alloca = dynamic_cast<AllocaInst*>(store->GetPtr());
        if (alloca && promoted.count(alloca)) {
          instructions_to_erase.insert(store);
        }
      }
    }
    for (auto* succ : bb->GetSuccessors()) {
      auto succ_phi_it = phi_nodes.find(succ);
      if (succ_phi_it == phi_nodes.end()) {
        continue;
      }
      for (const auto& entry : succ_phi_it->second) {
        entry.second->AddIncoming(get_default_value(entry.first), bb);
      }
    }
  }

  // 5. Clean up the now-dead loads, stores and allocas.
  for (auto* alloca : promotable_allocas) {
    instructions_to_erase.insert(alloca);
  }
  for (const auto& bbPtr : func->GetBlocks()) {
    // Collect first: erasing while iterating would invalidate the traversal.
    std::vector<Instruction*> to_remove;
    for (const auto& instPtr : bbPtr->GetInstructions()) {
      if (instructions_to_erase.count(instPtr.get())) {
        to_remove.push_back(instPtr.get());
      }
    }
    for (auto* inst : to_remove) {
      bbPtr->EraseInstruction(inst);
    }
  }
  return true;
}
} // namespace ir