Compare commits

..

24 Commits

Author SHA1 Message Date
Lixuanwang
c2153b6fab [deploy]部署版本1 2025-07-20 00:10:24 +08:00
Lixuanwang
d7fb017550 Merge branch 'backend-llir' into backend 2025-07-19 18:00:42 +08:00
Lixuanwang
c4b18a70db [backend]准备合并backend-llir 2025-07-19 17:59:45 +08:00
Lixuanwang
9528335a04 [backend-llir]修复了许多重构的bug 2025-07-19 17:50:14 +08:00
rain2133
0d5748e9c5 [IR]修复初始化数组指令的逻辑,更新IR常量定义。 2025-07-19 16:18:05 +08:00
Lixuanwang
d4a6996d74 [backend]重构了后端 2025-07-19 16:06:35 +08:00
rain2133
36cfd2f64d 先将SCCP中重构IR的部分移植到backend 2025-07-19 15:00:04 +08:00
Lixuanwang
75e61bf274 [backend-llir]引入了LLIR定义 2025-07-19 14:29:57 +08:00
Lixuanwang
c8308047df [backend]引入了Memset指令在后端的展开 2025-07-19 13:52:09 +08:00
Lixuanwang
86d1de6696 [backend]向脚本添加了打印不通过测例的功能 2025-07-19 12:00:02 +08:00
Lixuanwang
69d27f058d [backend]将testdata/下的测例替换为了赛方测试用例,更新了测试脚本 2025-07-19 01:44:37 +08:00
Lixuanwang
6335abe806 [backend]修复了引入常量重质化后全局常量加载指令的缺失问题 2025-07-19 00:46:46 +08:00
Lixuanwang
6ed5965b29 [backend]数组访存问题基本修复 2025-07-19 00:32:47 +08:00
Lixuanwang
0f26be3586 [backend]添加了对主函数中新引入的调试信息打印的控制,修改了测试脚本为云平台测试的参数,调整了73_int_io测例的输入文件的格式 2025-07-18 22:50:06 +08:00
Lixuanwang
d38ec13cbd [backend]修复了函数调用参数为常数时,参数传递有误的bug 2025-07-18 21:54:24 +08:00
Lixuanwang
e8660120cc [backend]删除了后端对数组访存的地址展开,因为已经在IR中实现 2025-07-18 20:48:59 +08:00
Lixuanwang
3657c08644 [backend]引入新的活跃性分析 2025-07-18 20:24:47 +08:00
Lixuanwang
1bcb5eba2a [backend]去除了错误的寄存器分配机制 2025-07-18 18:50:21 +08:00
Lixuanwang
fc62446b40 Merge branch 'backend' of gitee.com:lixuanwang/mysysy into backend 2025-07-18 18:48:44 +08:00
Lixuanwang
fedb4b0a9f [backend]修复了栈分配空间大小不考虑数组的错误 2025-07-18 18:48:38 +08:00
Lixuanwang
4bf4c98787 [backend]修复了栈分配空间大小不考虑数组的错误 2025-07-18 18:44:13 +08:00
Lixuanwang
198c1974e3 [backend] 新pass ACE修改完毕 2025-07-18 13:04:02 +08:00
Lixuanwang
b90e4faa6a [backend] 删除了部分错误代码 2025-07-18 01:37:29 +08:00
Lixuanwang
be8ca144d0 [backend]引入了新的pass,负责消除数组复杂地址访问 2025-07-18 00:10:10 +08:00
354 changed files with 47362 additions and 5195 deletions

2
.gitignore vendored
View File

@@ -36,7 +36,7 @@
doxygen doxygen
!/testdata/functional/*.out !/testdata/functional/*.out
!/testdata/performance/*.out !/testdata/h_functional/*.out
build/ build/
.antlr .antlr
.vscode/ .vscode/

View File

@@ -0,0 +1,160 @@
#include "AddressCalculationExpansion.h"
#include <iostream>
#include <vector>
#include "IR.h"
#include "IRBuilder.h"
extern int DEBUG;
namespace sysy {
bool AddressCalculationExpansion::run() {
bool changed = false;
for (auto& funcPair : pModule->getFunctions()) {
Function* func = funcPair.second.get();
for (auto& bb_ptr : func->getBasicBlocks()) {
BasicBlock* bb = bb_ptr.get();
for (auto it = bb->getInstructions().begin(); it != bb->getInstructions().end(); ) {
Instruction* inst = it->get();
Value* basePointer = nullptr;
Value* valueToStore = nullptr;
size_t firstIndexOperandIdx = 0;
size_t numBaseOperands = 0;
if (inst->isLoad()) {
numBaseOperands = 1;
basePointer = inst->getOperand(0);
firstIndexOperandIdx = 1;
} else if (inst->isStore()) {
numBaseOperands = 2;
valueToStore = inst->getOperand(0);
basePointer = inst->getOperand(1);
firstIndexOperandIdx = 2;
} else {
++it;
continue;
}
if (inst->getNumOperands() <= numBaseOperands) {
++it;
continue;
}
std::vector<int> dims;
if (AllocaInst* allocaInst = dynamic_cast<AllocaInst*>(basePointer)) {
for (const auto& use_ptr : allocaInst->getDims()) {
Value* dimValue = use_ptr->getValue();
if (ConstantValue* constVal = dynamic_cast<ConstantValue*>(dimValue)) {
dims.push_back(constVal->getInt());
} else {
std::cerr << "Warning: AllocaInst dimension is not a constant integer. Skipping GEP expansion for: ";
SysYPrinter::printValue(allocaInst);
std::cerr << "\n";
dims.clear();
break;
}
}
} else if (GlobalValue* globalValue = dynamic_cast<GlobalValue*>(basePointer)) {
std::cerr << "Warning: GlobalValue dimension handling needs explicit implementation for GEP expansion. Skipping GEP for: ";
SysYPrinter::printValue(globalValue);
std::cerr << "\n";
++it;
continue;
} else {
std::cerr << "Warning: Base pointer is not AllocaInst/GlobalValue or its array dimensions cannot be determined for GEP expansion. Skipping GEP for: ";
SysYPrinter::printValue(basePointer);
std::cerr << " in instruction ";
SysYPrinter::printInst(inst);
std::cerr << "\n";
++it;
continue;
}
if (dims.empty() && (inst->getNumOperands() > numBaseOperands)) {
if (DEBUG) {
std::cerr << "ACE Warning: Could not get valid array dimensions for ";
SysYPrinter::printValue(basePointer);
std::cerr << " in instruction ";
SysYPrinter::printInst(inst);
std::cerr << " (expected dimensions for indices, but got none).\n";
}
++it;
continue;
}
std::vector<Value*> indexOperands;
for (size_t i = firstIndexOperandIdx; i < inst->getNumOperands(); ++i) {
indexOperands.push_back(inst->getOperand(i));
}
if (AllocaInst* allocaInst = dynamic_cast<AllocaInst*>(basePointer)) {
if (allocaInst->getNumDims() != indexOperands.size()) {
if (DEBUG) {
std::cerr << "ACE Warning: Index count (" << indexOperands.size() << ") does not match AllocaInst dimensions (" << allocaInst->getNumDims() << ") for instruction ";
SysYPrinter::printInst(inst);
std::cerr << "\n";
}
++it;
continue;
}
}
Value* totalOffset = ConstantInteger::get(0);
pBuilder->setPosition(bb, it);
for (size_t i = 0; i < indexOperands.size(); ++i) {
Value* index = indexOperands[i];
int stride = calculateStride(dims, i);
Value* strideConst = ConstantInteger::get(stride);
Type* intType = Type::getIntType();
BinaryInst* currentDimOffsetInst = pBuilder->createBinaryInst(Instruction::kMul, intType, index, strideConst);
BinaryInst* newTotalOffsetInst = pBuilder->createBinaryInst(Instruction::kAdd, intType, totalOffset, currentDimOffsetInst);
totalOffset = newTotalOffsetInst;
}
// 计算有效地址effective_address = basePointer + totalOffset
Value* effective_address = pBuilder->createBinaryInst(Instruction::kAdd, basePointer->getType(), basePointer, totalOffset);
// 创建新的 LoadInst 或 StoreInstindices 为空
Instruction* newInst = nullptr;
if (inst->isLoad()) {
newInst = pBuilder->createLoadInst(effective_address, {});
inst->replaceAllUsesWith(newInst);
} else { // StoreInst
newInst = pBuilder->createStoreInst(valueToStore, effective_address, {});
}
Instruction* oldInst = it->get();
++it;
for (size_t i = 0; i < oldInst->getNumOperands(); ++i) {
Value* operandValue = oldInst->getOperand(i);
if (operandValue) {
for (auto use_it = operandValue->getUses().begin(); use_it != operandValue->getUses().end(); ++use_it) {
if ((*use_it)->getUser() == oldInst && (*use_it)->getIndex() == i) {
operandValue->removeUse(*use_it);
break;
}
}
}
}
bb->getInstructions().erase(std::prev(it));
changed = true;
if (DEBUG) {
std::cerr << "ACE: Computed effective address:\n";
SysYPrinter::printInst(dynamic_cast<Instruction*>(effective_address));
std::cerr << "ACE: New Load/Store instruction:\n";
SysYPrinter::printInst(newInst);
std::cerr << "--------------------------------\n";
}
}
}
}
return changed;
}
} // namespace sysy

View File

@@ -23,10 +23,14 @@ add_executable(sysyc
SysYIRPrinter.cpp SysYIRPrinter.cpp
SysYIROptPre.cpp SysYIROptPre.cpp
SysYIRAnalyser.cpp SysYIRAnalyser.cpp
DeadCodeElimination.cpp # DeadCodeElimination.cpp
Mem2Reg.cpp AddressCalculationExpansion.cpp
Reg2Mem.cpp # Mem2Reg.cpp
# Reg2Mem.cpp
RISCv64Backend.cpp RISCv64Backend.cpp
RISCv64ISel.cpp
RISCv64RegAlloc.cpp
RISCv64AsmPrinter.cpp
) )
# 设置 include 路径,包含 ANTLR 运行时库和项目头文件 # 设置 include 路径,包含 ANTLR 运行时库和项目头文件

View File

@@ -1,276 +0,0 @@
#include "DeadCodeElimination.h"
#include <iostream>
extern int DEBUG;
namespace sysy {
void DeadCodeElimination::runDCEPipeline() {
const auto& functions = pModule->getFunctions();
for (const auto& function : functions) {
const auto& func = function.second;
bool changed = true;
while (changed) {
changed = false;
eliminateDeadStores(func.get(), changed);
eliminateDeadLoads(func.get(), changed);
eliminateDeadAllocas(func.get(), changed);
eliminateDeadRedundantLoadStore(func.get(), changed);
eliminateDeadGlobals(changed);
}
}
}
// 消除无用存储 消除条件:
// 存储的目标指针pointer不是全局变量!isGlobal(pointer))。
// 存储的目标指针不是数组参数(!isArr(pointer) 或不在函数参数列表里)。
// 该指针的所有使用者uses仅限 alloca 或 store即没有 load 或其他指令使用它)。
void DeadCodeElimination::eliminateDeadStores(Function* func, bool& changed) {
for (const auto& block : func->getBasicBlocks()) {
auto& instrs = block->getInstructions();
for (auto iter = instrs.begin(); iter != instrs.end();) {
auto inst = iter->get();
if (!inst->isStore()) {
++iter;
continue;
}
auto storeInst = dynamic_cast<StoreInst*>(inst);
auto pointer = storeInst->getPointer();
// 如果是全局变量或者是函数的数组参数
if (isGlobal(pointer) || (isArr(pointer) &&
std::find(func->getEntryBlock()->getArguments().begin(),
func->getEntryBlock()->getArguments().end(),
pointer) != func->getEntryBlock()->getArguments().end())) {
++iter;
continue;
}
bool changetag = true;
for (auto& use : pointer->getUses()) {
// 依次判断store的指针是否被其他指令使用
auto user = use->getUser();
auto userInst = dynamic_cast<Instruction*>(user);
// 如果使用store的指针的指令不是Alloca或Store则不删除
if (userInst != nullptr && !userInst->isAlloca() && !userInst->isStore()) {
changetag = false;
break;
}
}
if (changetag) {
changed = true;
if(DEBUG){
std::cout << "=== Dead Store Found ===\n";
SysYPrinter::printInst(storeInst);
}
usedelete(storeInst);
iter = instrs.erase(iter);
} else {
++iter;
}
}
}
}
// 消除无用加载 消除条件:
// 该指令的结果未被使用inst->getUses().empty())。
void DeadCodeElimination::eliminateDeadLoads(Function* func, bool& changed) {
for (const auto& block : func->getBasicBlocks()) {
auto& instrs = block->getInstructions();
for (auto iter = instrs.begin(); iter != instrs.end();) {
auto inst = iter->get();
if (inst->isBinary() || inst->isUnary() || inst->isLoad()) {
if (inst->getUses().empty()) {
changed = true;
if(DEBUG){
std::cout << "=== Dead Load Binary Unary Found ===\n";
SysYPrinter::printInst(inst);
}
usedelete(inst);
iter = instrs.erase(iter);
continue;
}
}
++iter;
}
}
}
// 消除无用加载 消除条件:
// 该 alloca 未被任何指令使用allocaInst->getUses().empty())。
// 该 alloca 不是函数的参数(不在 entry 块的参数列表里)。
void DeadCodeElimination::eliminateDeadAllocas(Function* func, bool& changed) {
for (const auto& block : func->getBasicBlocks()) {
auto& instrs = block->getInstructions();
for (auto iter = instrs.begin(); iter != instrs.end();) {
auto inst = iter->get();
if (inst->isAlloca()) {
auto allocaInst = dynamic_cast<AllocaInst*>(inst);
if (allocaInst->getUses().empty() &&
std::find(func->getEntryBlock()->getArguments().begin(),
func->getEntryBlock()->getArguments().end(),
allocaInst) == func->getEntryBlock()->getArguments().end()) {
changed = true;
if(DEBUG){
std::cout << "=== Dead Alloca Found ===\n";
SysYPrinter::printInst(inst);
}
usedelete(inst);
iter = instrs.erase(iter);
continue;
}
}
++iter;
}
}
}
void DeadCodeElimination::eliminateDeadIndirectiveAllocas(Function* func, bool& changed) {
// 删除mem2reg时引入的且现在已经没有value使用了的隐式alloca
FunctionAnalysisInfo* funcInfo = pCFA->getFunctionAnalysisInfo(func);
for (auto it = funcInfo->getIndirectAllocas().begin(); it != funcInfo->getIndirectAllocas().end();) {
auto &allocaInst = *it;
if (allocaInst->getUses().empty()) {
changed = true;
if(DEBUG){
std::cout << "=== Dead Indirect Alloca Found ===\n";
SysYPrinter::printInst(allocaInst.get());
}
it = funcInfo->getIndirectAllocas().erase(it);
} else {
++it;
}
}
}
// 该全局变量未被任何指令使用global->getUses().empty())。
void DeadCodeElimination::eliminateDeadGlobals(bool& changed) {
auto& globals = pModule->getGlobals();
for (auto it = globals.begin(); it != globals.end();) {
auto& global = *it;
if (global->getUses().empty()) {
changed = true;
if(DEBUG){
std::cout << "=== Dead Global Found ===\n";
SysYPrinter::printValue(global.get());
}
it = globals.erase(it);
} else {
++it;
}
}
}
// 消除冗余加载和存储 消除条件:
// phi 指令的目标指针仅被该 phi 使用(无其他 store/load 使用)。
// memset 指令的目标指针未被使用pointer->getUses().empty()
// store -> load -> store 模式
void DeadCodeElimination::eliminateDeadRedundantLoadStore(Function* func, bool& changed) {
for (const auto& block : func->getBasicBlocks()) {
auto& instrs = block->getInstructions();
for (auto iter = instrs.begin(); iter != instrs.end();) {
auto inst = iter->get();
if (inst->isPhi()) {
auto phiInst = dynamic_cast<PhiInst*>(inst);
auto pointer = phiInst->getPointer();
bool tag = true;
for (const auto& use : pointer->getUses()) {
auto user = use->getUser();
if (user != inst) {
tag = false;
break;
}
}
/// 如果 pointer 仅被该 phi 使用,可以删除 ph
if (tag) {
changed = true;
usedelete(inst);
iter = instrs.erase(iter);
continue;
}
// 数组指令还不完善不保证memset优化效果
} else if (inst->isMemset()) {
auto memsetInst = dynamic_cast<MemsetInst*>(inst);
auto pointer = memsetInst->getPointer();
if (pointer->getUses().empty()) {
changed = true;
usedelete(inst);
iter = instrs.erase(iter);
continue;
}
}else if(inst->isLoad()) {
if (iter != instrs.begin()) {
auto loadInst = dynamic_cast<LoadInst*>(inst);
auto loadPointer = loadInst->getPointer();
// TODO:store -> load -> store 模式
auto prevIter = std::prev(iter);
auto prevInst = prevIter->get();
if (prevInst->isStore()) {
auto prevStore = dynamic_cast<StoreInst*>(prevInst);
auto prevStorePointer = prevStore->getPointer();
auto prevStoreValue = prevStore->getOperand(0);
// 确保前一个 store 不是数组操作
if (prevStore->getIndices().empty()) {
// 检查后一条指令是否是 store 同一个值
auto nextIter = std::next(iter);
if (nextIter != instrs.end()) {
auto nextInst = nextIter->get();
if (nextInst->isStore()) {
auto nextStore = dynamic_cast<StoreInst*>(nextInst);
auto nextStorePointer = nextStore->getPointer();
auto nextStoreValue = nextStore->getOperand(0);
// 确保后一个 store 不是数组操作
if (nextStore->getIndices().empty()) {
// 判断优化条件:
// 1. prevStore 的指针操作数 == load 的指针操作数
// 2. nextStore 的值操作数 == load 指令本身
if (prevStorePointer == loadPointer &&
nextStoreValue == loadInst) {
// 可以优化直接把prevStorePointer的值存到nextStorePointer
changed = true;
nextStore->setOperand(0, prevStoreValue);
if(DEBUG){
std::cout << "=== Dead Store Load Store Found(now only del Load) ===\n";
SysYPrinter::printInst(prevStore);
SysYPrinter::printInst(loadInst);
SysYPrinter::printInst(nextStore);
}
usedelete(loadInst);
iter = instrs.erase(iter);
// 删除 prevStore 这里是不是可以留给删除无用store处理
// if (prevStore->getUses().empty()) {
// usedelete(prevStore);
// instrs.erase(prevIter); // 删除 prevStore
// }
continue; // 跳过 ++iter因为已经移动迭代器
}
}
}
}
}
}
}
}
++iter;
}
}
}
bool DeadCodeElimination::isGlobal(Value *val){
auto gval = dynamic_cast<GlobalValue *>(val);
return gval != nullptr;
}
bool DeadCodeElimination::isArr(Value *val){
auto aval = dynamic_cast<AllocaInst *>(val);
return aval != nullptr && aval->getNumDims() != 0;
}
void DeadCodeElimination::usedelete(Instruction *instr){
for (auto &use1 : instr->getOperands()) {
auto val1 = use1->getValue();
val1->removeUse(use1);
}
}
} // namespace sysy

View File

@@ -102,30 +102,54 @@ void Value::replaceAllUsesWith(Value *value) {
uses.clear(); uses.clear();
} }
ConstantValue* ConstantValue::get(int value) {
static std::map<int, std::unique_ptr<ConstantValue>> intConstants; // Implementations for static members
auto iter = intConstants.find(value);
if (iter != intConstants.end()) { std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValueHash, ConstantValueEqual> ConstantValue::mConstantPool;
return iter->second.get(); std::unordered_map<Type*, UndefinedValue*> UndefinedValue::UndefValues;
ConstantValue* ConstantValue::get(Type* type, ConstantValVariant val) {
ConstantValueKey key = {type, val};
auto it = mConstantPool.find(key);
if (it != mConstantPool.end()) {
return it->second;
} }
auto inst = new ConstantValue(value);
assert(inst); ConstantValue* newConstant = nullptr;
auto result = intConstants.emplace(value, inst); if (std::holds_alternative<int>(val)) {
return result.first->second.get(); newConstant = new ConstantInteger(type, std::get<int>(val));
} else if (std::holds_alternative<float>(val)) {
newConstant = new ConstantFloating(type, std::get<float>(val));
} else {
assert(false && "Unsupported ConstantValVariant type");
}
mConstantPool[key] = newConstant;
return newConstant;
} }
ConstantValue* ConstantValue::get(float value) { ConstantInteger* ConstantInteger::get(Type* type, int val) {
static std::map<float, std::unique_ptr<ConstantValue>> floatConstants; return dynamic_cast<ConstantInteger*>(ConstantValue::get(type, val));
auto iter = floatConstants.find(value);
if (iter != floatConstants.end()) {
return iter->second.get();
}
auto inst = new ConstantValue(value);
assert(inst);
auto result = floatConstants.emplace(value, inst);
return result.first->second.get();
} }
ConstantFloating* ConstantFloating::get(Type* type, float val) {
return dynamic_cast<ConstantFloating*>(ConstantValue::get(type, val));
}
UndefinedValue* UndefinedValue::get(Type* type) {
assert(!type->isVoid() && "Cannot get UndefinedValue of void type!");
auto it = UndefValues.find(type);
if (it != UndefValues.end()) {
return it->second;
}
UndefinedValue* newUndef = new UndefinedValue(type);
UndefValues[type] = newUndef;
return newUndef;
}
auto Function::getCalleesWithNoExternalAndSelf() -> std::set<Function *> { auto Function::getCalleesWithNoExternalAndSelf() -> std::set<Function *> {
std::set<Function *> result; std::set<Function *> result;
for (auto callee : callees) { for (auto callee : callees) {
@@ -545,6 +569,83 @@ void User::replaceOperand(unsigned index, Value *value) {
value->addUse(use); value->addUse(use);
} }
/**
* phi相关函数
*/
Value* PhiInst::getvalfromBlk(BasicBlock* blk){
refreshB2VMap();
if( blk2val.find(blk) != blk2val.end()) {
return blk2val.at(blk);
}
return nullptr;
}
BasicBlock* PhiInst::getBlkfromVal(Value* val){
// 返回第一个值对应的基本块
for(unsigned i = 0; i < vsize; i++) {
if(getValue(i) == val) {
return getBlock(i);
}
}
return nullptr;
}
void PhiInst::delValue(Value* val){
//根据value删除对应的基本块和值
unsigned i = 0;
BasicBlock* blk = getBlkfromVal(val);
for(i = 0; i < vsize; i++) {
if(getValue(i) == val) {
break;
}
}
removeOperand(2 * i + 1); // 删除blk
removeOperand(2 * i); // 删除val
vsize--;
blk2val.erase(blk); // 删除blk2val映射
}
void PhiInst::delBlk(BasicBlock* blk){
//根据Blk删除对应的基本块和值
unsigned i = 0;
Value* val = getvalfromBlk(blk);
for(i = 0; i < vsize; i++) {
if(getBlock(i) == blk) {
break;
}
}
removeOperand(2 * i + 1); // 删除blk
removeOperand(2 * i); // 删除val
vsize--;
blk2val.erase(blk); // 删除blk2val映射
}
void PhiInst::replaceBlk(BasicBlock* newBlk, unsigned k){
refreshB2VMap();
Value* val = blk2val.at(getBlock(k));
// 替换基本块
setOperand(2 * k + 1, newBlk);
// 替换blk2val映射
blk2val.erase(getBlock(k));
blk2val.emplace(newBlk, val);
}
void PhiInst::replaceold2new(BasicBlock* oldBlk, BasicBlock* newBlk){
refreshB2VMap();
Value* val = blk2val.at(oldBlk);
// 替换基本块
delBlk(oldBlk);
addIncoming(val, newBlk);
}
void PhiInst::refreshB2VMap(){
blk2val.clear();
for(unsigned i = 0; i < vsize; i++) {
blk2val.emplace(getBlock(i), getValue(i));
}
}
CallInst::CallInst(Function *callee, const std::vector<Value *> &args, BasicBlock *parent, const std::string &name) CallInst::CallInst(Function *callee, const std::vector<Value *> &args, BasicBlock *parent, const std::string &name)
: Instruction(kCall, callee->getReturnType(), parent, name) { : Instruction(kCall, callee->getReturnType(), parent, name) {
addOperand(callee); addOperand(callee);

View File

@@ -1,801 +0,0 @@
#include "Mem2Reg.h"
#include <algorithm>
#include <cassert>
#include <iterator>
#include <memory>
#include <queue>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>
#include "IR.h"
#include "SysYIRAnalyser.h"
#include "SysYIRPrinter.h"
namespace sysy {
// 计算给定变量的定义块集合的迭代支配边界
// TODO优化Semi-Naive IDF
std::unordered_set<BasicBlock *> Mem2Reg::computeIterDf(const std::unordered_set<BasicBlock *> &blocks) {
std::unordered_set<BasicBlock *> workList;
std::unordered_set<BasicBlock *> ret_list;
workList.insert(blocks.begin(), blocks.end());
while (!workList.empty()) {
auto n = workList.begin();
BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(*n);
auto DFs = blockInfo->getDomFrontiers();
for (auto c : DFs) {
// 如果c不在ret_list中则将其加入ret_list和workList
// 这里的c是n的支配边界
// 也就是n的支配边界中的块
// 需要注意的是,支配边界是一个集合,所以可能会有重复
if (ret_list.count(c) == 0U) {
ret_list.emplace(c);
workList.emplace(c);
}
}
workList.erase(n);
}
return ret_list;
}
/**
* 计算value2Blocks的映射包括value2AllocBlocks、value2DefBlocks以及value2UseBlocks
* 其中value2DefBlocks可用于计算迭代支配边界来插入相应变量的phi结点
* 这里的value2AllocBlocks、value2DefBlocks和value2UseBlocks改变了函数级别的分析信息
*/
auto Mem2Reg::computeValue2Blocks() -> void {
SysYPrinter printer(pModule); // 初始化打印机
// std::cout << "===== Start computeValue2Blocks =====" << std::endl;
auto &functions = pModule->getFunctions();
for (const auto &function : functions) {
auto func = function.second.get();
// std::cout << "\nProcessing function: " << func->getName() << std::endl;
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
if (!funcInfo) {
std::cerr << "ERROR: No analysis info for function " << func->getName() << std::endl;
continue;
}
auto basicBlocks = func->getBasicBlocks();
// std::cout << "BasicBlocks count: " << basicBlocks.size() << std::endl;
for (auto &it : basicBlocks) {
auto basicBlock = it.get();
// std::cout << "\nProcessing BB: " << basicBlock->getName() << std::endl;
// printer.printBlock(basicBlock); // 打印基本块内容
auto &instrs = basicBlock->getInstructions();
for (auto &instr : instrs) {
// std::cout << " Analyzing instruction: ";
// printer.printInst(instr.get());
// std::cout << std::endl;
if (instr->isAlloca()) {
if (!(isArr(instr.get()) || isGlobal(instr.get()))) {
// std::cout << " Found alloca: ";
// printer.printInst(instr.get());
// std::cout << " -> Adding to allocBlocks" << std::endl;
funcInfo->addValue2AllocBlocks(instr.get(), basicBlock);
} else {
// std::cout << " Skip array/global alloca: ";
// printer.printInst(instr.get());
// std::cout << std::endl;
}
}
else if (instr->isStore()) {
auto val = instr->getOperand(1);
// std::cout << " Store target: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
if (!(isArr(val) || isGlobal(val))) {
// std::cout << " Adding store to defBlocks for value: ";
// printer.printInst(dynamic_cast<Instruction *>(instr.get()));
// std::cout << std::endl;
// 将store的目标值添加到defBlocks中
funcInfo->addValue2DefBlocks(val, basicBlock);
} else {
// std::cout << " Skip array/global store" << std::endl;
}
}
else if (instr->isLoad()) {
auto val = instr->getOperand(0);
// std::cout << " Load source: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << std::endl;
if (!(isArr(val) || isGlobal(val))) {
// std::cout << " Adding load to useBlocks for value: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << std::endl;
funcInfo->addValue2UseBlocks(val, basicBlock);
} else {
// std::cout << " Skip array/global load" << std::endl;
}
}
}
}
// 打印分析结果
// std::cout << "\nAnalysis results for function " << func->getName() << ":" << std::endl;
// auto &allocMap = funcInfo->getValue2AllocBlocks();
// std::cout << "AllocBlocks (" << allocMap.size() << "):" << std::endl;
// for (auto &[val, bb] : allocMap) {
// std::cout << " ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << " in BB: " << bb->getName() << std::endl;
// }
// auto &defMap = funcInfo->getValue2DefBlocks();
// std::cout << "DefBlocks (" << defMap.size() << "):" << std::endl;
// for (auto &[val, bbs] : defMap) {
// std::cout << " ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// for (const auto &[bb, count] : bbs) {
// std::cout << " in BB: " << bb->getName() << " (count: " << count << ")";
// }
// }
// auto &useMap = funcInfo->getValue2UseBlocks();
// std::cout << "UseBlocks (" << useMap.size() << "):" << std::endl;
// for (auto &[val, bbs] : useMap) {
// std::cout << " ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// for (const auto &[bb, count] : bbs) {
// std::cout << " in BB: " << bb->getName() << " (count: " << count << ")";
// }
// }
}
// std::cout << "===== End computeValue2Blocks =====" << std::endl;
}
/**
* @brief 级联关系的顺带消除用于llvm mem2reg类预优化1
*
* 采用队列进行模拟从某种程度上来看其实可以看作是UD链的反向操作
*
* @param [in] instr store指令使用的指令
* @param [in] changed 不动点法的判断标准,地址传递
* @param [in] func 指令所在函数
* @param [in] block 指令所在基本块
* @param [in] instrs 基本块所在指令集合,地址传递
* @return 无返回值,但满足条件的情况下会对指令进行删除
*/
auto Mem2Reg::cascade(Instruction *instr, bool &changed, Function *func, BasicBlock *block,
std::list<std::unique_ptr<Instruction>> &instrs) -> void {
if (instr != nullptr) {
if (instr->isUnary() || instr->isBinary() || instr->isLoad()) {
std::queue<Instruction *> toRemove;
toRemove.push(instr);
while (!toRemove.empty()) {
auto top = toRemove.front();
toRemove.pop();
auto operands = top->getOperands();
for (const auto &operand : operands) {
auto elem = dynamic_cast<Instruction *>(operand->getValue());
if (elem != nullptr) {
if ((elem->isUnary() || elem->isBinary() || elem->isLoad()) && elem->getUses().size() == 1 &&
elem->getUses().front()->getUser() == top) {
toRemove.push(elem);
} else if (elem->isAlloca()) {
// value2UseBlock中该block对应次数-1如果该变量的该useblock中count减为0了则意味着
// 该block其他地方也没用到该alloc了故从value2UseBlock中删除
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
auto res = funcInfo->removeValue2UseBlock(elem, block);
// 只要有一次返回了true就说明有变化
if (res) {
changed = true;
}
}
}
}
auto tofind =
std::find_if(instrs.begin(), instrs.end(), [&top](const auto &instr) { return instr.get() == top; });
assert(tofind != instrs.end());
usedelete(tofind->get());
instrs.erase(tofind);
}
}
}
}
/**
* llvm mem2reg预优化1: 删除不含load的alloc和store
*
* 1. 删除不含load的alloc和store
* 2. 删除store指令之前的用于作store指令第0个操作数的那些级联指令就冗余了也要删除
* 3. 删除之后可能有些变量的load使用恰好又没有了因此再次从第一步开始循环这里使用不动点法
*
* 由于删除了级联关系,所以这里的方法有点儿激进;
* 同时也考虑了级联关系时如果调用了函数可能会有side effect所以没有删除调用函数的级联关系
* 而且关于函数参数的alloca不会在指令中删除也不会在value2Alloca中删除;
* 同样地我们不考虑数组和global不过这里的代码是基于value2blocks的在value2blocks中已经考虑了所以不用显式指明
*=
*/
auto Mem2Reg::preOptimize1() -> void {
SysYPrinter printer(pModule); // 初始化打印机
auto &functions = pModule->getFunctions();
// std::cout << "===== Start preOptimize1 =====" << std::endl;
for (const auto &function : functions) {
auto func = function.second.get();
// std::cout << "\nProcessing function: " << func->getName() << std::endl;
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
if (!funcInfo) {
// std::cerr << "ERROR: No analysis info for function " << func->getName() << std::endl;
continue;
}
auto &vToDefB = funcInfo->getValue2DefBlocks();
auto &vToUseB = funcInfo->getValue2UseBlocks();
auto &vToAllocB = funcInfo->getValue2AllocBlocks();
// 打印初始状态
// std::cout << "Initial allocas: " << vToAllocB.size() << std::endl;
// for (auto &[val, bb] : vToAllocB) {
// std::cout << " Alloca: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << " in BB: " << bb->getName() << std::endl;
// }
// 阶段1删除无store的alloca
// std::cout << "\nPhase 1: Remove unused allocas" << std::endl;
for (auto iter = vToAllocB.begin(); iter != vToAllocB.end();) {
auto val = iter->first;
auto bb = iter->second;
// std::cout << "Checking alloca: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << " in BB: " << bb->getName() << std::endl;
// 如果该alloca没有对应的store指令且不在函数参数中
// 这里的vToDefB是value2DefBlocksvToUseB是value2UseBlocks
// 打印vToDefB
// std::cout << "DefBlocks (" << vToDefB.size() << "):" << std::endl;
// for (auto &[val, bbs] : vToDefB) {
// std::cout << " ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// for (const auto &[bb, count] : bbs) {
// std::cout << " in BB: " << bb->getName() << " (count: " << count << ")" << std::endl;
// }
// }
// std::cout << vToDefB.count(val) << std::endl;
if (vToDefB.count(val) == 0U &&
std::find(func->getEntryBlock()->getArguments().begin(),
func->getEntryBlock()->getArguments().end(),
val) == func->getEntryBlock()->getArguments().end()) {
// std::cout << " Removing unused alloca: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << std::endl;
auto tofind = std::find_if(bb->getInstructions().begin(),
bb->getInstructions().end(),
[val](const auto &instr) {
return instr.get() == val;
});
if (tofind == bb->getInstructions().end()) {
// std::cerr << "ERROR: Alloca not found in BB!" << std::endl;
++iter;
continue;
}
usedelete(tofind->get());
bb->getInstructions().erase(tofind);
iter = vToAllocB.erase(iter);
} else {
++iter;
}
}
// 阶段2删除无load的store
// std::cout << "\nPhase 2: Remove dead stores" << std::endl;
bool changed = true;
int iteration = 0;
while (changed) {
changed = false;
iteration++;
// std::cout << "\nIteration " << iteration << std::endl;
for (auto iter = vToDefB.begin(); iter != vToDefB.end();) {
auto val = iter->first;
// std::cout << "Checking value: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << std::endl;
if (vToUseB.count(val) == 0U) {
// std::cout << " Found dead store for value: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << std::endl;
auto blocks = funcInfo->getDefBlocksByValue(val);
for (auto block : blocks) {
// std::cout << " Processing BB: " << block->getName() << std::endl;
// printer.printBlock(block); // 打印基本块内容
auto &instrs = block->getInstructions();
for (auto it = instrs.begin(); it != instrs.end();) {
if ((*it)->isStore() && (*it)->getOperand(1) == val) {
// std::cout << " Removing store: ";
// printer.printInst(it->get());
std::cout << std::endl;
auto valUsedByStore = dynamic_cast<Instruction *>((*it)->getOperand(0));
usedelete(it->get());
if (valUsedByStore != nullptr &&
valUsedByStore->getUses().size() == 1 &&
valUsedByStore->getUses().front()->getUser() == (*it).get()) {
// std::cout << " Cascade deleting: ";
// printer.printInst(valUsedByStore);
// std::cout << std::endl;
cascade(valUsedByStore, changed, func, block, instrs);
}
it = instrs.erase(it);
changed = true;
} else {
++it;
}
}
}
// 删除对应的alloca
if (std::find(func->getEntryBlock()->getArguments().begin(),
func->getEntryBlock()->getArguments().end(),
val) == func->getEntryBlock()->getArguments().end()) {
auto bb = funcInfo->getAllocBlockByValue(val);
if (bb != nullptr) {
// std::cout << " Removing alloca: ";
// printer.printInst(dynamic_cast<Instruction *>(val));
// std::cout << " in BB: " << bb->getName() << std::endl;
funcInfo->removeValue2AllocBlock(val);
auto tofind = std::find_if(bb->getInstructions().begin(),
bb->getInstructions().end(),
[val](const auto &instr) {
return instr.get() == val;
});
if (tofind != bb->getInstructions().end()) {
usedelete(tofind->get());
bb->getInstructions().erase(tofind);
} else {
std::cerr << "ERROR: Alloca not found in BB!" << std::endl;
}
}
}
iter = vToDefB.erase(iter);
} else {
++iter;
}
}
}
}
// std::cout << "===== End preOptimize1 =====" << std::endl;
}
/**
* llvm mem2reg预优化2: 针对某个变量的Defblocks只有一个块的情况
*
* 1. 该基本块最后一次对该变量的store指令后的所有对该变量的load指令都可以替换为该基本块最后一次store指令的第0个操作数
* 2. 以该基本块为必经结点的结点集合中的对该变量的load指令都可以替换为该基本块最后一次对该变量的store指令的第0个操作数
* 3.
* 如果对该变量的所有load均替换掉了删除该基本块中最后一次store指令如果这个store指令是唯一的define那么再删除alloca指令不删除参数的alloca
* 4.
* 如果对该value的所有load都替换掉了对于该变量剩下还有store的话就转换成了preOptimize1的情况再调用preOptimize1进行删除
*
* 同样不考虑数组和全局变量因为这些变量不会被mem2reg优化在value2blocks中已经考虑了所以不用显式指明
* 替换的操作采用了UD链进行简化和效率的提升
*
*/
auto Mem2Reg::preOptimize2() -> void {
auto &functions = pModule->getFunctions();
for (const auto &function : functions) {
auto func = function.second.get();
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
auto values = funcInfo->getValuesOfDefBlock();
for (auto val : values) {
auto blocks = funcInfo->getDefBlocksByValue(val);
// 该val只有一个defining block
if (blocks.size() == 1) {
auto block = *blocks.begin();
auto &instrs = block->getInstructions();
auto rit = std::find_if(instrs.rbegin(), instrs.rend(),
[val](const auto &instr) { return instr->isStore() && instr->getOperand(1) == val; });
// 注意reverse_iterator求base后是指向下一个指令因此要减一才是原来的指令
assert(rit != instrs.rend());
auto it = --rit.base();
auto propogationVal = (*it)->getOperand(0);
// 其实该块中it后对该val的load指令也可以替换掉了
for (auto curit = std::next(it); curit != instrs.end();) {
if ((*curit)->isLoad() && (*curit)->getOperand(0) == val) {
curit->get()->replaceAllUsesWith(propogationVal);
usedelete(curit->get());
curit = instrs.erase(curit);
funcInfo->removeValue2UseBlock(val, block);
} else {
++curit;
}
}
// 在支配树后继结点中替换load指令的操作数
BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(block);
std::vector<BasicBlock *> blkchildren;
// 获取该块的支配树后继结点
std::queue<BasicBlock *> q;
auto sdoms = blockInfo->getSdoms();
for (auto sdom : sdoms) {
q.push(sdom);
blkchildren.push_back(sdom);
}
while (!q.empty()) {
auto blk = q.front();
q.pop();
BlockAnalysisInfo* blkInfo = controlFlowAnalysis->getBlockAnalysisInfo(blk);
for (auto sdom : blkInfo->getSdoms()) {
q.push(sdom);
blkchildren.push_back(sdom);
}
}
for (auto child : blkchildren) {
auto &childInstrs = child->getInstructions();
for (auto childIter = childInstrs.begin(); childIter != childInstrs.end();) {
if ((*childIter)->isLoad() && (*childIter)->getOperand(0) == val) {
childIter->get()->replaceAllUsesWith(propogationVal);
usedelete(childIter->get());
childIter = childInstrs.erase(childIter);
funcInfo->removeValue2UseBlock(val, child);
} else {
++childIter;
}
}
}
// 如果对该val的所有load均替换掉了那么对于该val的defining block中的最后一个define也可以删除了
// 同时该块中前面对于该val的define也变成死代码了可调用preOptimize1进行删除
if (funcInfo->getUseBlocksByValue(val).empty()) {
usedelete(it->get());
instrs.erase(it);
auto change = funcInfo->removeValue2DefBlock(val, block);
if (change) {
// 如果define是唯一的且不是函数参数的alloca直接删alloca
if (std::find(func->getEntryBlock()->getArguments().begin(), func->getEntryBlock()->getArguments().end(),
val) == func->getEntryBlock()->getArguments().end()) {
auto bb = funcInfo->getAllocBlockByValue(val);
assert(bb != nullptr);
auto tofind = std::find_if(bb->getInstructions().begin(), bb->getInstructions().end(),
[val](const auto &instr) { return instr.get() == val; });
usedelete(tofind->get());
bb->getInstructions().erase(tofind);
funcInfo->removeValue2AllocBlock(val);
}
} else {
// 如果该变量还有其他的define那么前面的define也变成死代码了
assert(!funcInfo->getDefBlocksByValue(val).empty());
assert(funcInfo->getUseBlocksByValue(val).empty());
preOptimize1();
}
}
}
}
}
}
/**
* @brief llvm mem2reg类预优化3针对某个变量的所有读写都在同一个块中的情况
*
* 1. 将每一个load替换成前一个store的值并删除该load
* 2. 如果在load前没有对该变量的store则不删除该load
* 3. 如果一个store后没有任何对改变量的load则删除该store
*
* @note 额外说明第二点不用显式处理因为我们的方法是从找到第一个store开始
* 第三点其实可以更激进一步地理解即每次替换了load之后它对应地那个store也可以删除了同时注意这里不要使用preoptimize1进行处理因为他们的级联关系是有用的即用来求load的替换值
* 同样地我们这里不考虑数组和全局变量因为这些变量不会被mem2reg优化不过这里在计算value2DefBlocks时已经跳过了所以不需要再显式处理了
* 替换的操作采用了UD链进行简化和效率的提升
*
* @param [in] void
* @return 无返回值,但满足条件的情况下会对指令的操作数进行替换以及对指令进行删除
*/
auto Mem2Reg::preOptimize3() -> void {
auto &functions = pModule->getFunctions();
for (const auto &function : functions) {
auto func = function.second.get();
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
auto values = funcInfo->getValuesOfDefBlock();
for (auto val : values) {
auto sblocks = funcInfo->getDefBlocksByValue(val);
auto lblocks = funcInfo->getUseBlocksByValue(val);
if (sblocks.size() == 1 && lblocks.size() == 1 && *sblocks.begin() == *lblocks.begin()) {
auto block = *sblocks.begin();
auto &instrs = block->getInstructions();
auto it = std::find_if(instrs.begin(), instrs.end(),
[val](const auto &instr) { return instr->isStore() && instr->getOperand(1) == val; });
while (it != instrs.end()) {
auto propogationVal = (*it)->getOperand(0);
auto last = std::find_if(std::next(it), instrs.end(), [val](const auto &instr) {
return instr->isStore() && instr->getOperand(1) == val;
});
for (auto curit = std::next(it); curit != last;) {
if ((*curit)->isLoad() && (*curit)->getOperand(0) == val) {
curit->get()->replaceAllUsesWith(propogationVal);
usedelete(curit->get());
curit = instrs.erase(curit);
funcInfo->removeValue2UseBlock(val, block);
} else {
++curit;
}
}
// 替换了load之后它对应地那个store也可以删除了
if (!(std::find_if(func->getEntryBlock()->getArguments().begin(), func->getEntryBlock()->getArguments().end(),
[val](const auto &instr) { return instr == val; }) !=
func->getEntryBlock()->getArguments().end()) &&
last == instrs.end()) {
usedelete(it->get());
it = instrs.erase(it);
if (funcInfo->removeValue2DefBlock(val, block)) {
auto bb = funcInfo->getAllocBlockByValue(val);
if (bb != nullptr) {
auto tofind = std::find_if(bb->getInstructions().begin(), bb->getInstructions().end(),
[val](const auto &instr) { return instr.get() == val; });
usedelete(tofind->get());
bb->getInstructions().erase(tofind);
funcInfo->removeValue2AllocBlock(val);
}
}
}
it = last;
}
}
}
}
}
/**
* 为所有变量的定义块集合的迭代支配边界插入phi结点
*
* insertPhi是mem2reg的核心之一这里是对所有变量的迭代支配边界的phi结点插入无参数也无返回值
* 同样跳过对数组和全局变量的处理因为这些变量不会被mem2reg优化刚好这里在计算value2DefBlocks时已经跳过了所以不需要再显式处理了
* 同时我们进行了剪枝处理只有在基本块入口活跃的变量才插入phi函数
*
*/
auto Mem2Reg::insertPhi() -> void {
auto &functions = pModule->getFunctions();
for (const auto &function : functions) {
auto func = function.second.get();
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
const auto &vToDefB = funcInfo->getValue2DefBlocks();
for (const auto &map_pair : vToDefB) {
// 首先为每个变量找到迭代支配边界
auto val = map_pair.first;
auto blocks = funcInfo->getDefBlocksByValue(val);
auto itDFs = computeIterDf(blocks);
// 然后在每个变量相应的迭代支配边界上插入phi结点
for (auto basicBlock : itDFs) {
const auto &actiTable = activeVarAnalysis->getActiveTable();
auto dval = dynamic_cast<User *>(val);
// 只有在基本块入口活跃的变量才插入phi函数
if (actiTable.at(basicBlock).front().count(dval) != 0U) {
pBuilder->createPhiInst(val->getType(), val, basicBlock);
}
}
}
}
}
/**
* 重命名
*
* 重命名是mem2reg的核心之二这里是对单个块的重命名递归实现
* 同样跳过对数组和全局变量的处理因为这些变量不会被mem2reg优化
*
*/
auto Mem2Reg::rename(BasicBlock *block, std::unordered_map<Value *, int> &count,
std::unordered_map<Value *, std::stack<Instruction *>> &stacks) -> void {
auto &instrs = block->getInstructions();
std::unordered_map<Value *, int> valPop;
// 第一大步:对块中的所有指令遍历处理
for (auto iter = instrs.begin(); iter != instrs.end();) {
auto instr = iter->get();
// 对于load指令变量用最新的那个
if (instr->isLoad()) {
auto val = instr->getOperand(0);
if (!(isArr(val) || isGlobal(val))) {
if (!stacks[val].empty()) {
instr->replaceOperand(0, stacks[val].top());
}
}
}
// 然后对于define的情况看alloca、store和phi指令
if (instr->isDefine()) {
if (instr->isAlloca()) {
// alloca指令名字不改了命名就按xx_1x_2...来就行
auto val = instr;
if (!(isArr(val) || isGlobal(val))) {
++valPop[val];
stacks[val].push(val);
++count[val];
}
} else if (instr->isPhi()) {
// Phi指令也是一条特殊的define指令
auto val = dynamic_cast<PhiInst *>(instr)->getMapVal();
if (!(isArr(val) || isGlobal(val))) {
auto i = count[val];
if (i == 0) {
// 对还未alloca就有phi的指令的处理直接删除
usedelete(iter->get());
iter = instrs.erase(iter);
continue;
}
auto newname = dynamic_cast<Instruction *>(val)->getName() + "_" + std::to_string(i);
auto newalloca = pBuilder->createAllocaInstWithoutInsert(val->getType(), {}, block, newname);
FunctionAnalysisInfo* ParentfuncInfo = controlFlowAnalysis->getFunctionAnalysisInfo(block->getParent());
ParentfuncInfo->addIndirectAlloca(newalloca);
instr->replaceOperand(0, newalloca);
++valPop[val];
stacks[val].push(newalloca);
++count[val];
}
} else {
// store指令看operand的名字我们的实现是规定变量在operand的第二位用一个新的alloca x_i代替
auto val = instr->getOperand(1);
if (!(isArr(val) || isGlobal(val))) {
auto i = count[val];
auto newname = dynamic_cast<Instruction *>(val)->getName() + "_" + std::to_string(i);
auto newalloca = pBuilder->createAllocaInstWithoutInsert(val->getType(), {}, block, newname);
FunctionAnalysisInfo* ParentfuncInfo = controlFlowAnalysis->getFunctionAnalysisInfo(block->getParent());
ParentfuncInfo->addIndirectAlloca(newalloca);
// block->getParent()->addIndirectAlloca(newalloca);
instr->replaceOperand(1, newalloca);
++valPop[val];
stacks[val].push(newalloca);
++count[val];
}
}
}
++iter;
}
// 第二大步把所有CFG中的该块的successor的phi指令的相应operand确定
for (auto succ : block->getSuccessors()) {
auto position = getPredIndex(block, succ);
for (auto &instr : succ->getInstructions()) {
if (instr->isPhi()) {
auto val = dynamic_cast<PhiInst *>(instr.get())->getMapVal();
if (!stacks[val].empty()) {
instr->replaceOperand(position + 1, stacks[val].top());
}
} else {
// phi指令是添加在块的最前面的因此过了之后就不会有phi了直接break
break;
}
}
}
// 第三大步递归支配树的后继支配树才能表示define-use关系
BlockAnalysisInfo* blockInfo = controlFlowAnalysis->getBlockAnalysisInfo(block);
for (auto sdom : blockInfo->getSdoms()) {
rename(sdom, count, stacks);
}
// 第四大步遍历块中的所有指令如果涉及到define就弹栈这一步是必要的可以从递归的整体性来思考原因
// 注意这里count没清理因为平级之间计数仍然是一直增加的但是stack要清理因为define-use关系来自直接
// 支配结点而不是平级之间,不清理栈会被污染
// 提前优化知道变量对应的要弹栈的次数就可以了没必要遍历所有instr.
for (auto val_pair : valPop) {
auto val = val_pair.first;
for (int i = 0; i < val_pair.second; ++i) {
stacks[val].pop();
}
}
}
/**
* 重命名所有块
*
* 调用rename自上而下实现所有rename
*
*/
auto Mem2Reg::renameAll() -> void {
auto &functions = pModule->getFunctions();
for (const auto &function : functions) {
auto func = function.second.get();
// 对于每个function都要SSA化所以count和stacks定义在这并初始化
std::unordered_map<Value *, int> count;
std::unordered_map<Value *, std::stack<Instruction *>> stacks;
FunctionAnalysisInfo* funcInfo = controlFlowAnalysis->getFunctionAnalysisInfo(func);
for (const auto &map_pair : funcInfo->getValue2DefBlocks()) {
auto val = map_pair.first;
count[val] = 0;
}
rename(func->getEntryBlock(), count, stacks);
}
}
/**
* mem2reg对外的接口
*
* 静态单一赋值 + mem2reg等pass的逻辑组合
*
*/
auto Mem2Reg::mem2regPipeline() -> void {
// 首先进行mem2reg的前置分析
controlFlowAnalysis->clear();
controlFlowAnalysis->runControlFlowAnalysis();
// 活跃变量分析
activeVarAnalysis->clear();
dataFlowAnalysisUtils.addBackwardAnalyzer(activeVarAnalysis);
dataFlowAnalysisUtils.backwardAnalyze(pModule);
// 计算所有valueToBlocks的定义映射
computeValue2Blocks();
// SysYPrinter printer(pModule);
// 参考llvm的mem2reg遍在插入phi结点之前先做些优化
preOptimize1();
// printer.printIR();
preOptimize2();
// printer.printIR();
// 优化三 可能会针对局部变量优化而删除整个块的alloca/store
preOptimize3();
//再进行活跃变量分析
// 报错?
// printer.printIR();
dataFlowAnalysisUtils.backwardAnalyze(pModule);
// 为所有变量插入phi结点
insertPhi();
// 重命名
renameAll();
}
/**
* 计算块n是块s的第几个前驱
*
* helperfunction没有返回值但是会将dom和other的交集赋值给dom
*
*/
auto Mem2Reg::getPredIndex(BasicBlock *n, BasicBlock *s) -> int {
int index = 0;
for (auto elem : s->getPredecessors()) {
if (elem == n) {
break;
}
++index;
}
assert(index < static_cast<int>(s->getPredecessors().size()) && "n is not a predecessor of s.");
return index;
}
/**
* 判断一个value是不是全局变量
*/
auto Mem2Reg::isGlobal(Value *val) -> bool {
auto gval = dynamic_cast<GlobalValue *>(val);
return gval != nullptr;
}
/**
* 判断一个value是不是数组
*/
auto Mem2Reg::isArr(Value *val) -> bool {
auto aval = dynamic_cast<AllocaInst *>(val);
return aval != nullptr && aval->getNumDims() != 0;
}
/**
* 删除一个指令的operand对应的value的该条use
*/
auto Mem2Reg::usedelete(Instruction *instr) -> void {
for (auto &use : instr->getOperands()) {
auto val = use->getValue();
val->removeUse(use);
}
}
} // namespace sysy

225
src/RISCv64AsmPrinter.cpp Normal file
View File

@@ -0,0 +1,225 @@
#include "RISCv64AsmPrinter.h"
#include "RISCv64ISel.h"
#include <stdexcept>
namespace sysy {
// 检查是否为内存加载/存储指令,以处理特殊的打印格式
bool isMemoryOp(RVOpcodes opcode) {
switch (opcode) {
case RVOpcodes::LB: case RVOpcodes::LH: case RVOpcodes::LW: case RVOpcodes::LD:
case RVOpcodes::LBU: case RVOpcodes::LHU: case RVOpcodes::LWU:
case RVOpcodes::SB: case RVOpcodes::SH: case RVOpcodes::SW: case RVOpcodes::SD:
return true;
default:
return false;
}
}
RISCv64AsmPrinter::RISCv64AsmPrinter(MachineFunction* mfunc) : MFunc(mfunc) {}
void RISCv64AsmPrinter::run(std::ostream& os) {
OS = &os;
*OS << ".globl " << MFunc->getName() << "\n";
*OS << MFunc->getName() << ":\n";
printPrologue();
for (auto& mbb : MFunc->getBlocks()) {
printBasicBlock(mbb.get());
}
}
void RISCv64AsmPrinter::printPrologue() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
// 序言需要为保存ra和s0预留16字节
int total_stack_size = frame_info.locals_size + frame_info.spill_size + 16;
int aligned_stack_size = (total_stack_size + 15) & ~15;
frame_info.total_size = aligned_stack_size;
if (aligned_stack_size > 0) {
*OS << " addi sp, sp, -" << aligned_stack_size << "\n";
*OS << " sd ra, " << (aligned_stack_size - 8) << "(sp)\n";
*OS << " sd s0, " << (aligned_stack_size - 16) << "(sp)\n";
*OS << " mv s0, sp\n";
}
// 忠实还原保存函数入口参数的逻辑
Function* F = MFunc->getFunc();
if (F && F->getEntryBlock()) {
int arg_idx = 0;
RISCv64ISel* isel = MFunc->getISel();
for (AllocaInst* alloca_for_param : F->getEntryBlock()->getArguments()) {
if (arg_idx >= 8) break;
unsigned vreg = isel->getVReg(alloca_for_param);
if (frame_info.alloca_offsets.count(vreg)) {
int offset = frame_info.alloca_offsets.at(vreg);
auto arg_reg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + arg_idx);
*OS << " sw " << regToString(arg_reg) << ", " << offset << "(s0)\n";
}
arg_idx++;
}
}
}
void RISCv64AsmPrinter::printEpilogue() {
int aligned_stack_size = MFunc->getFrameInfo().total_size;
if (aligned_stack_size > 0) {
*OS << " ld ra, " << (aligned_stack_size - 8) << "(sp)\n";
*OS << " ld s0, " << (aligned_stack_size - 16) << "(sp)\n";
*OS << " addi sp, sp, " << aligned_stack_size << "\n";
}
}
void RISCv64AsmPrinter::printBasicBlock(MachineBasicBlock* mbb) {
if (!mbb->getName().empty()) {
*OS << mbb->getName() << ":\n";
}
for (auto& instr : mbb->getInstructions()) {
printInstruction(instr.get());
}
}
void RISCv64AsmPrinter::printInstruction(MachineInstr* instr) {
auto opcode = instr->getOpcode();
if (opcode == RVOpcodes::RET) {
printEpilogue();
}
if (opcode != RVOpcodes::LABEL) {
*OS << " ";
}
switch (opcode) {
case RVOpcodes::ADD: *OS << "add "; break; case RVOpcodes::ADDI: *OS << "addi "; break;
case RVOpcodes::ADDW: *OS << "addw "; break; case RVOpcodes::ADDIW: *OS << "addiw "; break;
case RVOpcodes::SUB: *OS << "sub "; break; case RVOpcodes::SUBW: *OS << "subw "; break;
case RVOpcodes::MUL: *OS << "mul "; break; case RVOpcodes::MULW: *OS << "mulw "; break;
case RVOpcodes::DIV: *OS << "div "; break; case RVOpcodes::DIVW: *OS << "divw "; break;
case RVOpcodes::REM: *OS << "rem "; break; case RVOpcodes::REMW: *OS << "remw "; break;
case RVOpcodes::XOR: *OS << "xor "; break; case RVOpcodes::XORI: *OS << "xori "; break;
case RVOpcodes::OR: *OS << "or "; break; case RVOpcodes::ORI: *OS << "ori "; break;
case RVOpcodes::AND: *OS << "and "; break; case RVOpcodes::ANDI: *OS << "andi "; break;
case RVOpcodes::SLL: *OS << "sll "; break; case RVOpcodes::SLLI: *OS << "slli "; break;
case RVOpcodes::SLLW: *OS << "sllw "; break; case RVOpcodes::SLLIW: *OS << "slliw "; break;
case RVOpcodes::SRL: *OS << "srl "; break; case RVOpcodes::SRLI: *OS << "srli "; break;
case RVOpcodes::SRLW: *OS << "srlw "; break; case RVOpcodes::SRLIW: *OS << "srliw "; break;
case RVOpcodes::SRA: *OS << "sra "; break; case RVOpcodes::SRAI: *OS << "srai "; break;
case RVOpcodes::SRAW: *OS << "sraw "; break; case RVOpcodes::SRAIW: *OS << "sraiw "; break;
case RVOpcodes::SLT: *OS << "slt "; break; case RVOpcodes::SLTI: *OS << "slti "; break;
case RVOpcodes::SLTU: *OS << "sltu "; break; case RVOpcodes::SLTIU: *OS << "sltiu "; break;
case RVOpcodes::LW: *OS << "lw "; break; case RVOpcodes::LH: *OS << "lh "; break;
case RVOpcodes::LB: *OS << "lb "; break; case RVOpcodes::LWU: *OS << "lwu "; break;
case RVOpcodes::LHU: *OS << "lhu "; break; case RVOpcodes::LBU: *OS << "lbu "; break;
case RVOpcodes::SW: *OS << "sw "; break; case RVOpcodes::SH: *OS << "sh "; break;
case RVOpcodes::SB: *OS << "sb "; break; case RVOpcodes::LD: *OS << "ld "; break;
case RVOpcodes::SD: *OS << "sd "; break;
case RVOpcodes::J: *OS << "j "; break; case RVOpcodes::JAL: *OS << "jal "; break;
case RVOpcodes::JALR: *OS << "jalr "; break; case RVOpcodes::RET: *OS << "ret"; break;
case RVOpcodes::BEQ: *OS << "beq "; break; case RVOpcodes::BNE: *OS << "bne "; break;
case RVOpcodes::BLT: *OS << "blt "; break; case RVOpcodes::BGE: *OS << "bge "; break;
case RVOpcodes::BLTU: *OS << "bltu "; break; case RVOpcodes::BGEU: *OS << "bgeu "; break;
case RVOpcodes::LI: *OS << "li "; break; case RVOpcodes::LA: *OS << "la "; break;
case RVOpcodes::MV: *OS << "mv "; break; case RVOpcodes::NEG: *OS << "neg "; break;
case RVOpcodes::NEGW: *OS << "negw "; break; case RVOpcodes::SEQZ: *OS << "seqz "; break;
case RVOpcodes::SNEZ: *OS << "snez "; break;
case RVOpcodes::CALL: *OS << "call "; break;
case RVOpcodes::LABEL:
printOperand(instr->getOperands()[0].get());
*OS << ":";
break;
case RVOpcodes::FRAME_LOAD:
case RVOpcodes::FRAME_STORE:
// These should have been eliminated by RegAlloc
throw std::runtime_error("FRAME pseudo-instruction not eliminated before AsmPrinter");
default:
throw std::runtime_error("Unknown opcode in AsmPrinter");
}
const auto& operands = instr->getOperands();
if (!operands.empty()) {
if (isMemoryOp(opcode)) {
printOperand(operands[0].get());
*OS << ", ";
printOperand(operands[1].get());
} else {
for (size_t i = 0; i < operands.size(); ++i) {
printOperand(operands[i].get());
if (i < operands.size() - 1) {
*OS << ", ";
}
}
}
}
*OS << "\n";
}
void RISCv64AsmPrinter::printOperand(MachineOperand* op) {
if (!op) return;
switch(op->getKind()) {
case MachineOperand::KIND_REG: {
auto reg_op = static_cast<RegOperand*>(op);
if (reg_op->isVirtual()) {
*OS << "%vreg" << reg_op->getVRegNum();
} else {
*OS << regToString(reg_op->getPReg());
}
break;
}
case MachineOperand::KIND_IMM:
*OS << static_cast<ImmOperand*>(op)->getValue();
break;
case MachineOperand::KIND_LABEL:
*OS << static_cast<LabelOperand*>(op)->getName();
break;
case MachineOperand::KIND_MEM: {
auto mem_op = static_cast<MemOperand*>(op);
printOperand(mem_op->getOffset());
*OS << "(";
printOperand(mem_op->getBase());
*OS << ")";
break;
}
}
}
std::string RISCv64AsmPrinter::regToString(PhysicalReg reg) {
switch (reg) {
case PhysicalReg::ZERO: return "x0"; case PhysicalReg::RA: return "ra";
case PhysicalReg::SP: return "sp"; case PhysicalReg::GP: return "gp";
case PhysicalReg::TP: return "tp"; case PhysicalReg::T0: return "t0";
case PhysicalReg::T1: return "t1"; case PhysicalReg::T2: return "t2";
case PhysicalReg::S0: return "s0"; case PhysicalReg::S1: return "s1";
case PhysicalReg::A0: return "a0"; case PhysicalReg::A1: return "a1";
case PhysicalReg::A2: return "a2"; case PhysicalReg::A3: return "a3";
case PhysicalReg::A4: return "a4"; case PhysicalReg::A5: return "a5";
case PhysicalReg::A6: return "a6"; case PhysicalReg::A7: return "a7";
case PhysicalReg::S2: return "s2"; case PhysicalReg::S3: return "s3";
case PhysicalReg::S4: return "s4"; case PhysicalReg::S5: return "s5";
case PhysicalReg::S6: return "s6"; case PhysicalReg::S7: return "s7";
case PhysicalReg::S8: return "s8"; case PhysicalReg::S9: return "s9";
case PhysicalReg::S10: return "s10"; case PhysicalReg::S11: return "s11";
case PhysicalReg::T3: return "t3"; case PhysicalReg::T4: return "t4";
case PhysicalReg::T5: return "t5"; case PhysicalReg::T6: return "t6";
case PhysicalReg::F0: return "f0"; case PhysicalReg::F1: return "f1";
case PhysicalReg::F2: return "f2"; case PhysicalReg::F3: return "f3";
case PhysicalReg::F4: return "f4"; case PhysicalReg::F5: return "f5";
case PhysicalReg::F6: return "f6"; case PhysicalReg::F7: return "f7";
case PhysicalReg::F8: return "f8"; case PhysicalReg::F9: return "f9";
case PhysicalReg::F10: return "f10"; case PhysicalReg::F11: return "f11";
case PhysicalReg::F12: return "f12"; case PhysicalReg::F13: return "f13";
case PhysicalReg::F14: return "f14"; case PhysicalReg::F15: return "f15";
case PhysicalReg::F16: return "f16"; case PhysicalReg::F17: return "f17";
case PhysicalReg::F18: return "f18"; case PhysicalReg::F19: return "f19";
case PhysicalReg::F20: return "f20"; case PhysicalReg::F21: return "f21";
case PhysicalReg::F22: return "f22"; case PhysicalReg::F23: return "f23";
case PhysicalReg::F24: return "f24"; case PhysicalReg::F25: return "f25";
case PhysicalReg::F26: return "f26"; case PhysicalReg::F27: return "f27";
case PhysicalReg::F28: return "f28"; case PhysicalReg::F29: return "f29";
case PhysicalReg::F30: return "f30"; case PhysicalReg::F31: return "f31";
default: return "UNKNOWN_REG";
}
}
} // namespace sysy

File diff suppressed because it is too large Load Diff

635
src/RISCv64ISel.cpp Normal file
View File

@@ -0,0 +1,635 @@
#include "RISCv64ISel.h"
#include <stdexcept>
#include <set>
#include <functional>
#include <cmath> // For std::fabs
#include <limits> // For std::numeric_limits
namespace sysy {
// DAG节点定义 (内部实现)
struct RISCv64ISel::DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY, MEMSET };
NodeKind kind;
Value* value = nullptr;
std::vector<DAGNode*> operands;
std::vector<DAGNode*> users;
DAGNode(NodeKind k) : kind(k) {}
};
RISCv64ISel::RISCv64ISel() : vreg_counter(0), local_label_counter(0) {}
// 为一个IR Value获取或分配一个新的虚拟寄存器
unsigned RISCv64ISel::getVReg(Value* val) {
if (!val) {
throw std::runtime_error("Cannot get vreg for a null Value.");
}
if (vreg_map.find(val) == vreg_map.end()) {
if (vreg_counter == 0) {
vreg_counter = 1; // vreg 0 保留
}
vreg_map[val] = vreg_counter++;
}
return vreg_map.at(val);
}
// 主入口函数
std::unique_ptr<MachineFunction> RISCv64ISel::runOnFunction(Function* func) {
F = func;
if (!F) return nullptr;
MFunc = std::make_unique<MachineFunction>(F, this);
vreg_map.clear();
bb_map.clear();
vreg_counter = 0;
local_label_counter = 0;
select();
return std::move(MFunc);
}
// 指令选择主流程
void RISCv64ISel::select() {
for (const auto& bb_ptr : F->getBasicBlocks()) {
auto mbb = std::make_unique<MachineBasicBlock>(bb_ptr->getName(), MFunc.get());
bb_map[bb_ptr.get()] = mbb.get();
MFunc->addBlock(std::move(mbb));
}
if (F->getEntryBlock()) {
for (auto* arg_alloca : F->getEntryBlock()->getArguments()) {
getVReg(arg_alloca);
}
}
for (const auto& bb_ptr : F->getBasicBlocks()) {
selectBasicBlock(bb_ptr.get());
}
for (const auto& bb_ptr : F->getBasicBlocks()) {
CurMBB = bb_map.at(bb_ptr.get());
for (auto succ : bb_ptr->getSuccessors()) {
CurMBB->successors.push_back(bb_map.at(succ));
}
for (auto pred : bb_ptr->getPredecessors()) {
CurMBB->predecessors.push_back(bb_map.at(pred));
}
}
}
// 处理单个基本块
void RISCv64ISel::selectBasicBlock(BasicBlock* bb) {
CurMBB = bb_map.at(bb);
auto dag = build_dag(bb);
std::map<Value*, DAGNode*> value_to_node;
for(const auto& node : dag) {
if (node->value) {
value_to_node[node->value] = node.get();
}
}
std::set<DAGNode*> selected_nodes;
std::function<void(DAGNode*)> select_recursive =
[&](DAGNode* node) {
if (!node || selected_nodes.count(node)) return;
for (auto operand : node->operands) {
select_recursive(operand);
}
selectNode(node);
selected_nodes.insert(node);
};
for (const auto& inst_ptr : bb->getInstructions()) {
DAGNode* node_to_select = nullptr;
if (value_to_node.count(inst_ptr.get())) {
node_to_select = value_to_node.at(inst_ptr.get());
} else {
for(const auto& node : dag) {
if(node->value == inst_ptr.get()) {
node_to_select = node.get();
break;
}
}
}
if(node_to_select) {
select_recursive(node_to_select);
}
}
}
// 核心函数为DAG节点选择并生成MachineInstr (忠实移植版)
void RISCv64ISel::selectNode(DAGNode* node) {
switch (node->kind) {
case DAGNode::CONSTANT:
case DAGNode::ALLOCA_ADDR:
if (node->value) getVReg(node->value);
break;
case DAGNode::LOAD: {
auto dest_vreg = getVReg(node->value);
Value* ptr_val = node->operands[0]->value;
if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FRAME_LOAD);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(getVReg(alloca)));
CurMBB->addInstruction(std::move(instr));
} else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
auto addr_vreg = getNewVReg();
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
CurMBB->addInstruction(std::move(la));
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(lw));
} else {
auto ptr_vreg = getVReg(ptr_val);
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(ptr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(lw));
}
break;
}
case DAGNode::STORE: {
Value* val_to_store = node->operands[0]->value;
Value* ptr_val = node->operands[1]->value;
if (auto val_const = dynamic_cast<ConstantValue*>(val_to_store)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(val_const)));
li->addOperand(std::make_unique<ImmOperand>(val_const->getInt()));
CurMBB->addInstruction(std::move(li));
}
auto val_vreg = getVReg(val_to_store);
if (auto alloca = dynamic_cast<AllocaInst*>(ptr_val)) {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::FRAME_STORE);
instr->addOperand(std::make_unique<RegOperand>(val_vreg));
instr->addOperand(std::make_unique<RegOperand>(getVReg(alloca)));
CurMBB->addInstruction(std::move(instr));
} else if (auto global = dynamic_cast<GlobalValue*>(ptr_val)) {
auto addr_vreg = getNewVReg();
auto la = std::make_unique<MachineInstr>(RVOpcodes::LA);
la->addOperand(std::make_unique<RegOperand>(addr_vreg));
la->addOperand(std::make_unique<LabelOperand>(global->getName()));
CurMBB->addInstruction(std::move(la));
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(val_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(sw));
} else {
auto ptr_vreg = getVReg(ptr_val);
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(val_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(ptr_vreg),
std::make_unique<ImmOperand>(0)
));
CurMBB->addInstruction(std::move(sw));
}
break;
}
case DAGNode::BINARY: {
auto bin = dynamic_cast<BinaryInst*>(node->value);
Value* lhs = bin->getLhs();
Value* rhs = bin->getRhs();
auto load_val_if_const = [&](Value* val) {
if (auto c = dynamic_cast<ConstantValue*>(val)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(getVReg(c)));
li->addOperand(std::make_unique<ImmOperand>(c->getInt()));
CurMBB->addInstruction(std::move(li));
}
};
load_val_if_const(lhs);
load_val_if_const(rhs);
auto dest_vreg = getVReg(bin);
auto lhs_vreg = getVReg(lhs);
auto rhs_vreg = getVReg(rhs);
if (bin->getKind() == BinaryInst::kAdd) {
if (auto rhs_const = dynamic_cast<ConstantValue*>(rhs)) {
if (rhs_const->getInt() >= -2048 && rhs_const->getInt() < 2048) {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::ADDIW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<ImmOperand>(rhs_const->getInt()));
CurMBB->addInstruction(std::move(instr));
return;
}
}
}
switch (bin->getKind()) {
case BinaryInst::kAdd: {
RVOpcodes opcode = (lhs->getType()->isPointer() || rhs->getType()->isPointer()) ? RVOpcodes::ADD : RVOpcodes::ADDW;
auto instr = std::make_unique<MachineInstr>(opcode);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kSub: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kMul: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::MULW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kDiv: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::DIVW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case Instruction::kRem: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::REMW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpEQ: {
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
sub->addOperand(std::make_unique<RegOperand>(lhs_vreg));
sub->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(sub));
auto seqz = std::make_unique<MachineInstr>(RVOpcodes::SEQZ);
seqz->addOperand(std::make_unique<RegOperand>(dest_vreg));
seqz->addOperand(std::make_unique<RegOperand>(dest_vreg));
CurMBB->addInstruction(std::move(seqz));
break;
}
case BinaryInst::kICmpNE: {
auto sub = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
sub->addOperand(std::make_unique<RegOperand>(dest_vreg));
sub->addOperand(std::make_unique<RegOperand>(lhs_vreg));
sub->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(sub));
auto snez = std::make_unique<MachineInstr>(RVOpcodes::SNEZ);
snez->addOperand(std::make_unique<RegOperand>(dest_vreg));
snez->addOperand(std::make_unique<RegOperand>(dest_vreg));
CurMBB->addInstruction(std::move(snez));
break;
}
case BinaryInst::kICmpLT: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SLT);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpGT: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SLT);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(rhs_vreg));
instr->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case BinaryInst::kICmpLE: {
auto slt = std::make_unique<MachineInstr>(RVOpcodes::SLT);
slt->addOperand(std::make_unique<RegOperand>(dest_vreg));
slt->addOperand(std::make_unique<RegOperand>(rhs_vreg));
slt->addOperand(std::make_unique<RegOperand>(lhs_vreg));
CurMBB->addInstruction(std::move(slt));
auto xori = std::make_unique<MachineInstr>(RVOpcodes::XORI);
xori->addOperand(std::make_unique<RegOperand>(dest_vreg));
xori->addOperand(std::make_unique<RegOperand>(dest_vreg));
xori->addOperand(std::make_unique<ImmOperand>(1));
CurMBB->addInstruction(std::move(xori));
break;
}
case BinaryInst::kICmpGE: {
auto slt = std::make_unique<MachineInstr>(RVOpcodes::SLT);
slt->addOperand(std::make_unique<RegOperand>(dest_vreg));
slt->addOperand(std::make_unique<RegOperand>(lhs_vreg));
slt->addOperand(std::make_unique<RegOperand>(rhs_vreg));
CurMBB->addInstruction(std::move(slt));
auto xori = std::make_unique<MachineInstr>(RVOpcodes::XORI);
xori->addOperand(std::make_unique<RegOperand>(dest_vreg));
xori->addOperand(std::make_unique<RegOperand>(dest_vreg));
xori->addOperand(std::make_unique<ImmOperand>(1));
CurMBB->addInstruction(std::move(xori));
break;
}
default:
throw std::runtime_error("Unsupported binary instruction in ISel");
}
break;
}
case DAGNode::UNARY: {
auto unary = dynamic_cast<UnaryInst*>(node->value);
auto dest_vreg = getVReg(unary);
auto src_vreg = getVReg(unary->getOperand());
switch (unary->getKind()) {
case UnaryInst::kNeg: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SUBW);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
instr->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
case UnaryInst::kNot: {
auto instr = std::make_unique<MachineInstr>(RVOpcodes::SEQZ);
instr->addOperand(std::make_unique<RegOperand>(dest_vreg));
instr->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(instr));
break;
}
default:
throw std::runtime_error("Unsupported unary instruction in ISel");
}
break;
}
case DAGNode::CALL: {
auto call = dynamic_cast<CallInst*>(node->value);
for (size_t i = 0; i < node->operands.size() && i < 8; ++i) {
DAGNode* arg_node = node->operands[i];
auto arg_preg = static_cast<PhysicalReg>(static_cast<int>(PhysicalReg::A0) + i);
if (arg_node->kind == DAGNode::CONSTANT) {
if (auto const_val = dynamic_cast<ConstantValue*>(arg_node->value)) {
auto li = std::make_unique<MachineInstr>(RVOpcodes::LI);
li->addOperand(std::make_unique<RegOperand>(arg_preg));
li->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li));
}
} else {
auto src_vreg = getVReg(arg_node->value);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(arg_preg));
mv->addOperand(std::make_unique<RegOperand>(src_vreg));
CurMBB->addInstruction(std::move(mv));
}
}
auto call_instr = std::make_unique<MachineInstr>(RVOpcodes::CALL);
call_instr->addOperand(std::make_unique<LabelOperand>(call->getCallee()->getName()));
CurMBB->addInstruction(std::move(call_instr));
if (!call->getType()->isVoid()) {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(call)));
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
CurMBB->addInstruction(std::move(mv_instr));
}
break;
}
case DAGNode::RETURN: {
auto ret_inst_ir = dynamic_cast<ReturnInst*>(node->value);
if (ret_inst_ir && ret_inst_ir->hasReturnValue()) {
Value* ret_val = ret_inst_ir->getReturnValue();
if (auto const_val = dynamic_cast<ConstantValue*>(ret_val)) {
auto li_instr = std::make_unique<MachineInstr>(RVOpcodes::LI);
li_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
li_instr->addOperand(std::make_unique<ImmOperand>(const_val->getInt()));
CurMBB->addInstruction(std::move(li_instr));
} else {
auto mv_instr = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::A0));
mv_instr->addOperand(std::make_unique<RegOperand>(getVReg(ret_val)));
CurMBB->addInstruction(std::move(mv_instr));
}
}
auto ret_mi = std::make_unique<MachineInstr>(RVOpcodes::RET);
CurMBB->addInstruction(std::move(ret_mi));
break;
}
case DAGNode::BRANCH: {
if (auto cond_br = dynamic_cast<CondBrInst*>(node->value)) {
auto br_instr = std::make_unique<MachineInstr>(RVOpcodes::BNE);
br_instr->addOperand(std::make_unique<RegOperand>(getVReg(cond_br->getCondition())));
br_instr->addOperand(std::make_unique<RegOperand>(PhysicalReg::ZERO));
br_instr->addOperand(std::make_unique<LabelOperand>(cond_br->getThenBlock()->getName()));
CurMBB->addInstruction(std::move(br_instr));
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(node->value)) {
auto j_instr = std::make_unique<MachineInstr>(RVOpcodes::J);
j_instr->addOperand(std::make_unique<LabelOperand>(uncond_br->getBlock()->getName()));
CurMBB->addInstruction(std::move(j_instr));
}
break;
}
case DAGNode::MEMSET: {
auto memset = dynamic_cast<MemsetInst*>(node->value);
auto r_dest_addr = getVReg(memset->getPointer());
auto r_num_bytes = getVReg(memset->getSize());
auto r_value_byte = getVReg(memset->getValue());
auto r_counter = getNewVReg();
auto r_end_addr = getNewVReg();
auto r_current_addr = getNewVReg();
auto r_temp_val = getNewVReg();
auto add_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, unsigned rs2) {
auto i = std::make_unique<MachineInstr>(op);
i->addOperand(std::make_unique<RegOperand>(rd));
i->addOperand(std::make_unique<RegOperand>(rs1));
i->addOperand(std::make_unique<RegOperand>(rs2));
CurMBB->addInstruction(std::move(i));
};
auto addi_instr = [&](RVOpcodes op, unsigned rd, unsigned rs1, int64_t imm) {
auto i = std::make_unique<MachineInstr>(op);
i->addOperand(std::make_unique<RegOperand>(rd));
i->addOperand(std::make_unique<RegOperand>(rs1));
i->addOperand(std::make_unique<ImmOperand>(imm));
CurMBB->addInstruction(std::move(i));
};
auto store_instr = [&](RVOpcodes op, unsigned src, unsigned base, int64_t off) {
auto i = std::make_unique<MachineInstr>(op);
i->addOperand(std::make_unique<RegOperand>(src));
i->addOperand(std::make_unique<MemOperand>(std::make_unique<RegOperand>(base), std::make_unique<ImmOperand>(off)));
CurMBB->addInstruction(std::move(i));
};
auto branch_instr = [&](RVOpcodes op, unsigned rs1, unsigned rs2, const std::string& label) {
auto i = std::make_unique<MachineInstr>(op);
i->addOperand(std::make_unique<RegOperand>(rs1));
i->addOperand(std::make_unique<RegOperand>(rs2));
i->addOperand(std::make_unique<LabelOperand>(label));
CurMBB->addInstruction(std::move(i));
};
auto jump_instr = [&](const std::string& label) {
auto i = std::make_unique<MachineInstr>(RVOpcodes::J);
i->addOperand(std::make_unique<LabelOperand>(label));
CurMBB->addInstruction(std::move(i));
};
auto label_instr = [&](const std::string& name) {
auto i = std::make_unique<MachineInstr>(RVOpcodes::LABEL);
i->addOperand(std::make_unique<LabelOperand>(name));
CurMBB->addInstruction(std::move(i));
};
int unique_id = this->local_label_counter++;
std::string loop_start_label = MFunc->getName() + "_memset_loop_start_" + std::to_string(unique_id);
std::string loop_end_label = MFunc->getName() + "_memset_loop_end_" + std::to_string(unique_id);
std::string remainder_label = MFunc->getName() + "_memset_remainder_" + std::to_string(unique_id);
std::string done_label = MFunc->getName() + "_memset_done_" + std::to_string(unique_id);
addi_instr(RVOpcodes::ANDI, r_temp_val, r_value_byte, 255);
addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 8);
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 16);
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
addi_instr(RVOpcodes::SLLI, r_value_byte, r_temp_val, 32);
add_instr(RVOpcodes::OR, r_temp_val, r_temp_val, r_value_byte);
add_instr(RVOpcodes::ADD, r_end_addr, r_dest_addr, r_num_bytes);
auto mv = std::make_unique<MachineInstr>(RVOpcodes::MV);
mv->addOperand(std::make_unique<RegOperand>(r_current_addr));
mv->addOperand(std::make_unique<RegOperand>(r_dest_addr));
CurMBB->addInstruction(std::move(mv));
addi_instr(RVOpcodes::ANDI, r_counter, r_num_bytes, -8);
add_instr(RVOpcodes::ADD, r_counter, r_dest_addr, r_counter);
label_instr(loop_start_label);
branch_instr(RVOpcodes::BGEU, r_current_addr, r_counter, loop_end_label);
store_instr(RVOpcodes::SD, r_temp_val, r_current_addr, 0);
addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 8);
jump_instr(loop_start_label);
label_instr(loop_end_label);
label_instr(remainder_label);
branch_instr(RVOpcodes::BGEU, r_current_addr, r_end_addr, done_label);
store_instr(RVOpcodes::SB, r_temp_val, r_current_addr, 0);
addi_instr(RVOpcodes::ADDI, r_current_addr, r_current_addr, 1);
jump_instr(remainder_label);
label_instr(done_label);
break;
}
default:
throw std::runtime_error("Unsupported DAGNode kind in ISel");
}
}
// 以下是忠实移植的DAG构建函数
RISCv64ISel::DAGNode* RISCv64ISel::create_node(int kind_int, Value* val, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage) {
auto kind = static_cast<DAGNode::NodeKind>(kind_int);
if (val && value_to_node.count(val) && kind != DAGNode::STORE && kind != DAGNode::RETURN && kind != DAGNode::BRANCH && kind != DAGNode::MEMSET) {
return value_to_node[val];
}
auto node = std::make_unique<DAGNode>(kind);
node->value = val;
DAGNode* raw_node_ptr = node.get();
nodes_storage.push_back(std::move(node));
if (val && !val->getType()->isVoid() && (dynamic_cast<Instruction*>(val) || dynamic_cast<GlobalValue*>(val))) {
value_to_node[val] = raw_node_ptr;
}
return raw_node_ptr;
}
RISCv64ISel::DAGNode* RISCv64ISel::get_operand_node(Value* val_ir, std::map<Value*, DAGNode*>& value_to_node, std::vector<std::unique_ptr<DAGNode>>& nodes_storage) {
if (value_to_node.count(val_ir)) {
return value_to_node[val_ir];
} else if (dynamic_cast<ConstantValue*>(val_ir)) {
return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage);
} else if (dynamic_cast<GlobalValue*>(val_ir)) {
return create_node(DAGNode::CONSTANT, val_ir, value_to_node, nodes_storage);
} else if (dynamic_cast<AllocaInst*>(val_ir)) {
return create_node(DAGNode::ALLOCA_ADDR, val_ir, value_to_node, nodes_storage);
}
return create_node(DAGNode::LOAD, val_ir, value_to_node, nodes_storage);
}
std::vector<std::unique_ptr<RISCv64ISel::DAGNode>> RISCv64ISel::build_dag(BasicBlock* bb) {
std::vector<std::unique_ptr<DAGNode>> nodes_storage;
std::map<Value*, DAGNode*> value_to_node;
for (const auto& inst_ptr : bb->getInstructions()) {
Instruction* inst = inst_ptr.get();
if (auto alloca = dynamic_cast<AllocaInst*>(inst)) {
create_node(DAGNode::ALLOCA_ADDR, alloca, value_to_node, nodes_storage);
} else if (auto store = dynamic_cast<StoreInst*>(inst)) {
auto store_node = create_node(DAGNode::STORE, store, value_to_node, nodes_storage);
store_node->operands.push_back(get_operand_node(store->getValue(), value_to_node, nodes_storage));
store_node->operands.push_back(get_operand_node(store->getPointer(), value_to_node, nodes_storage));
} else if (auto memset = dynamic_cast<MemsetInst*>(inst)) {
auto memset_node = create_node(DAGNode::MEMSET, memset, value_to_node, nodes_storage);
memset_node->operands.push_back(get_operand_node(memset->getPointer(), value_to_node, nodes_storage));
memset_node->operands.push_back(get_operand_node(memset->getBegin(), value_to_node, nodes_storage));
memset_node->operands.push_back(get_operand_node(memset->getSize(), value_to_node, nodes_storage));
memset_node->operands.push_back(get_operand_node(memset->getValue(), value_to_node, nodes_storage));
} else if (auto load = dynamic_cast<LoadInst*>(inst)) {
auto load_node = create_node(DAGNode::LOAD, load, value_to_node, nodes_storage);
load_node->operands.push_back(get_operand_node(load->getPointer(), value_to_node, nodes_storage));
} else if (auto bin = dynamic_cast<BinaryInst*>(inst)) {
if(value_to_node.count(bin)) continue;
if (bin->getKind() == BinaryInst::kSub) {
if (auto const_lhs = dynamic_cast<ConstantValue*>(bin->getLhs())) {
if (const_lhs->getInt() == 0) {
auto unary_node = create_node(DAGNode::UNARY, bin, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
continue;
}
}
}
auto bin_node = create_node(DAGNode::BINARY, bin, value_to_node, nodes_storage);
bin_node->operands.push_back(get_operand_node(bin->getLhs(), value_to_node, nodes_storage));
bin_node->operands.push_back(get_operand_node(bin->getRhs(), value_to_node, nodes_storage));
} else if (auto un = dynamic_cast<UnaryInst*>(inst)) {
if(value_to_node.count(un)) continue;
auto unary_node = create_node(DAGNode::UNARY, un, value_to_node, nodes_storage);
unary_node->operands.push_back(get_operand_node(un->getOperand(), value_to_node, nodes_storage));
} else if (auto call = dynamic_cast<CallInst*>(inst)) {
if(value_to_node.count(call)) continue;
auto call_node = create_node(DAGNode::CALL, call, value_to_node, nodes_storage);
for (auto arg : call->getArguments()) {
call_node->operands.push_back(get_operand_node(arg->getValue(), value_to_node, nodes_storage));
}
} else if (auto ret = dynamic_cast<ReturnInst*>(inst)) {
auto ret_node = create_node(DAGNode::RETURN, ret, value_to_node, nodes_storage);
if (ret->hasReturnValue()) {
ret_node->operands.push_back(get_operand_node(ret->getReturnValue(), value_to_node, nodes_storage));
}
} else if (auto cond_br = dynamic_cast<CondBrInst*>(inst)) {
auto br_node = create_node(DAGNode::BRANCH, cond_br, value_to_node, nodes_storage);
br_node->operands.push_back(get_operand_node(cond_br->getCondition(), value_to_node, nodes_storage));
} else if (auto uncond_br = dynamic_cast<UncondBrInst*>(inst)) {
create_node(DAGNode::BRANCH, uncond_br, value_to_node, nodes_storage);
}
}
return nodes_storage;
}
} // namespace sysy

8
src/RISCv64Passes.cpp Normal file
View File

@@ -0,0 +1,8 @@
// RISCv64Passes.cpp
#include "RISCv64Passes.h"
namespace sysy {
// 此处为未来优化Pass的实现
} // namespace sysy

322
src/RISCv64RegAlloc.cpp Normal file
View File

@@ -0,0 +1,322 @@
#include "RISCv64RegAlloc.h"
#include "RISCv64ISel.h"
#include <algorithm>
#include <vector>
namespace sysy {
RISCv64RegAlloc::RISCv64RegAlloc(MachineFunction* mfunc) : MFunc(mfunc) {
allocable_int_regs = {
PhysicalReg::T0, PhysicalReg::T1, PhysicalReg::T2, PhysicalReg::T3,
PhysicalReg::T4, PhysicalReg::T5, PhysicalReg::T6,
PhysicalReg::A0, PhysicalReg::A1, PhysicalReg::A2, PhysicalReg::A3,
PhysicalReg::A4, PhysicalReg::A5, PhysicalReg::A6, PhysicalReg::A7,
PhysicalReg::S0, PhysicalReg::S1, PhysicalReg::S2, PhysicalReg::S3,
PhysicalReg::S4, PhysicalReg::S5, PhysicalReg::S6, PhysicalReg::S7,
PhysicalReg::S8, PhysicalReg::S9, PhysicalReg::S10, PhysicalReg::S11,
};
}
void RISCv64RegAlloc::run() {
eliminateFrameIndices();
analyzeLiveness();
buildInterferenceGraph();
colorGraph();
rewriteFunction();
}
void RISCv64RegAlloc::eliminateFrameIndices() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = 0;
Function* F = MFunc->getFunc();
RISCv64ISel* isel = MFunc->getISel();
for (auto& bb : F->getBasicBlocks()) {
for (auto& inst : bb->getInstructions()) {
if (auto alloca = dynamic_cast<AllocaInst*>(inst.get())) {
int size = 4;
if (!alloca->getDims().empty()) {
int num_elements = 1;
for (const auto& dim_use : alloca->getDims()) {
if (auto const_dim = dynamic_cast<ConstantValue*>(dim_use->getValue())) {
num_elements *= const_dim->getInt();
}
}
size *= num_elements;
}
current_offset += size;
unsigned alloca_vreg = isel->getVReg(alloca);
frame_info.alloca_offsets[alloca_vreg] = -current_offset;
}
}
}
frame_info.locals_size = current_offset;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
if (instr_ptr->getOpcode() == RVOpcodes::FRAME_LOAD) {
auto& operands = instr_ptr->getOperands();
unsigned dest_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto lw = std::make_unique<MachineInstr>(RVOpcodes::LW);
lw->addOperand(std::make_unique<RegOperand>(dest_vreg));
lw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(lw));
} else if (instr_ptr->getOpcode() == RVOpcodes::FRAME_STORE) {
auto& operands = instr_ptr->getOperands();
unsigned src_vreg = static_cast<RegOperand*>(operands[0].get())->getVRegNum();
unsigned alloca_vreg = static_cast<RegOperand*>(operands[1].get())->getVRegNum();
int offset = frame_info.alloca_offsets.at(alloca_vreg);
auto addr_vreg = isel->getNewVReg();
auto addi = std::make_unique<MachineInstr>(RVOpcodes::ADDI);
addi->addOperand(std::make_unique<RegOperand>(addr_vreg));
addi->addOperand(std::make_unique<RegOperand>(PhysicalReg::S0));
addi->addOperand(std::make_unique<ImmOperand>(offset));
new_instructions.push_back(std::move(addi));
auto sw = std::make_unique<MachineInstr>(RVOpcodes::SW);
sw->addOperand(std::make_unique<RegOperand>(src_vreg));
sw->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(addr_vreg),
std::make_unique<ImmOperand>(0)));
new_instructions.push_back(std::move(sw));
} else {
new_instructions.push_back(std::move(instr_ptr));
}
}
mbb->getInstructions() = std::move(new_instructions);
}
}
void RISCv64RegAlloc::getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def) {
bool is_def = true;
auto opcode = instr->getOpcode();
// 预定义def和use规则
if (opcode == RVOpcodes::SW || opcode == RVOpcodes::SD ||
opcode == RVOpcodes::BEQ || opcode == RVOpcodes::BNE ||
opcode == RVOpcodes::BLT || opcode == RVOpcodes::BGE ||
opcode == RVOpcodes::RET || opcode == RVOpcodes::J) {
is_def = false;
}
if (opcode == RVOpcodes::CALL) {
// CALL会杀死所有调用者保存寄存器这是一个简化处理
// 同时也使用了传入a0-a7的参数
}
for (const auto& op : instr->getOperands()) {
if (op->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op.get());
if (reg_op->isVirtual()) {
if (is_def) {
def.insert(reg_op->getVRegNum());
is_def = false;
} else {
use.insert(reg_op->getVRegNum());
}
}
} else if (op->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op.get());
if (mem_op->getBase()->isVirtual()) {
use.insert(mem_op->getBase()->getVRegNum());
}
}
}
}
void RISCv64RegAlloc::analyzeLiveness() {
bool changed = true;
while (changed) {
changed = false;
for (auto it = MFunc->getBlocks().rbegin(); it != MFunc->getBlocks().rend(); ++it) {
auto& mbb = *it;
LiveSet live_out;
for (auto succ : mbb->successors) {
if (!succ->getInstructions().empty()) {
auto first_instr = succ->getInstructions().front().get();
if (live_in_map.count(first_instr)) {
live_out.insert(live_in_map.at(first_instr).begin(), live_in_map.at(first_instr).end());
}
}
}
for (auto instr_it = mbb->getInstructions().rbegin(); instr_it != mbb->getInstructions().rend(); ++instr_it) {
MachineInstr* instr = instr_it->get();
LiveSet old_live_in = live_in_map[instr];
live_out_map[instr] = live_out;
LiveSet use, def;
getInstrUseDef(instr, use, def);
LiveSet live_in = use;
LiveSet diff = live_out;
for (auto vreg : def) {
diff.erase(vreg);
}
live_in.insert(diff.begin(), diff.end());
live_in_map[instr] = live_in;
live_out = live_in;
if (live_in_map[instr] != old_live_in) {
changed = true;
}
}
}
}
}
void RISCv64RegAlloc::buildInterferenceGraph() {
std::set<unsigned> all_vregs;
for (auto& mbb : MFunc->getBlocks()) {
for(auto& instr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr.get(), use, def);
for(auto u : use) all_vregs.insert(u);
for(auto d : def) all_vregs.insert(d);
}
}
for (auto vreg : all_vregs) { interference_graph[vreg] = {}; }
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr : mbb->getInstructions()) {
LiveSet def, use;
getInstrUseDef(instr.get(), use, def);
const LiveSet& live_out = live_out_map.at(instr.get());
for (unsigned d : def) {
for (unsigned l : live_out) {
if (d != l) {
interference_graph[d].insert(l);
interference_graph[l].insert(d);
}
}
}
}
}
}
void RISCv64RegAlloc::colorGraph() {
std::vector<unsigned> sorted_vregs;
for (auto const& [vreg, neighbors] : interference_graph) {
sorted_vregs.push_back(vreg);
}
std::sort(sorted_vregs.begin(), sorted_vregs.end(), [&](unsigned a, unsigned b) {
return interference_graph[a].size() > interference_graph[b].size();
});
for (unsigned vreg : sorted_vregs) {
std::set<PhysicalReg> used_colors;
for (unsigned neighbor : interference_graph.at(vreg)) {
if (color_map.count(neighbor)) {
used_colors.insert(color_map.at(neighbor));
}
}
bool colored = false;
for (PhysicalReg preg : allocable_int_regs) {
if (used_colors.find(preg) == used_colors.end()) {
color_map[vreg] = preg;
colored = true;
break;
}
}
if (!colored) {
spilled_vregs.insert(vreg);
}
}
}
void RISCv64RegAlloc::rewriteFunction() {
StackFrameInfo& frame_info = MFunc->getFrameInfo();
int current_offset = frame_info.locals_size;
for (unsigned vreg : spilled_vregs) {
current_offset += 4;
frame_info.spill_offsets[vreg] = -current_offset;
}
frame_info.spill_size = current_offset - frame_info.locals_size;
for (auto& mbb : MFunc->getBlocks()) {
std::vector<std::unique_ptr<MachineInstr>> new_instructions;
for (auto& instr_ptr : mbb->getInstructions()) {
LiveSet use, def;
getInstrUseDef(instr_ptr.get(), use, def);
for (unsigned vreg : use) {
if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_offsets.at(vreg);
auto load = std::make_unique<MachineInstr>(RVOpcodes::LW);
load->addOperand(std::make_unique<RegOperand>(vreg));
load->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(load));
}
}
new_instructions.push_back(std::move(instr_ptr));
for (unsigned vreg : def) {
if (spilled_vregs.count(vreg)) {
int offset = frame_info.spill_offsets.at(vreg);
auto store = std::make_unique<MachineInstr>(RVOpcodes::SW);
store->addOperand(std::make_unique<RegOperand>(vreg));
store->addOperand(std::make_unique<MemOperand>(
std::make_unique<RegOperand>(PhysicalReg::S0),
std::make_unique<ImmOperand>(offset)
));
new_instructions.push_back(std::move(store));
}
}
}
mbb->getInstructions() = std::move(new_instructions);
}
for (auto& mbb : MFunc->getBlocks()) {
for (auto& instr_ptr : mbb->getInstructions()) {
for (auto& op_ptr : instr_ptr->getOperands()) {
if(op_ptr->getKind() == MachineOperand::KIND_REG) {
auto reg_op = static_cast<RegOperand*>(op_ptr.get());
if (reg_op->isVirtual()) {
unsigned vreg = reg_op->getVRegNum();
if (color_map.count(vreg)) {
reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
reg_op->setPReg(PhysicalReg::T6); // 溢出统一用t6
}
}
} else if (op_ptr->getKind() == MachineOperand::KIND_MEM) {
auto mem_op = static_cast<MemOperand*>(op_ptr.get());
auto base_reg_op = mem_op->getBase();
if(base_reg_op->isVirtual()){
unsigned vreg = base_reg_op->getVRegNum();
if(color_map.count(vreg)) {
base_reg_op->setPReg(color_map.at(vreg));
} else if (spilled_vregs.count(vreg)) {
base_reg_op->setPReg(PhysicalReg::T6);
}
}
}
}
}
}
}
} // namespace sysy

View File

@@ -1,129 +0,0 @@
#include "Reg2Mem.h"
#include <cstddef>
#include <iostream>
#include <list>
#include <memory>
namespace sysy {
/**
* 删除phi节点
* 删除phi节点后可能会生成冗余存储代码
*/
void Reg2Mem::DeletePhiInst(){
auto &functions = pModule->getFunctions();
for (auto &function : functions) {
auto basicBlocks = function.second->getBasicBlocks();
for (auto &basicBlock : basicBlocks) {
for (auto iter = basicBlock->begin(); iter != basicBlock->end();) {
auto &instruction = *iter;
if (instruction->isPhi()) {
auto predBlocks = basicBlock->getPredecessors();
// 寻找源和目的
// 目的就是phi指令的第一个操作数
// 源就是phi指令的后续操作数
auto destination = instruction->getOperand(0);
int predBlockindex = 0;
for (auto &predBlock : predBlocks) {
++predBlockindex;
// 判断前驱块儿只有一个后继还是多个后继
// 如果有多个
auto source = instruction->getOperand(predBlockindex);
if (source == destination) {
continue;
}
// std::cout << predBlock->getNumSuccessors() << std::endl;
if (predBlock->getNumSuccessors() > 1) {
// 创建一个basicblock
auto newbasicBlock = function.second->addBasicBlock();
std::stringstream ss;
ss << " phidel.L" << pBuilder->getLabelIndex();
newbasicBlock->setName(ss.str());
ss.str("");
// // 修改前驱后继关系
basicBlock->replacePredecessor(predBlock, newbasicBlock);
// predBlock = newbasicBlock;
newbasicBlock->addPredecessor(predBlock);
newbasicBlock->addSuccessor(basicBlock.get());
predBlock->removeSuccessor(basicBlock.get());
predBlock->addSuccessor(newbasicBlock);
// std::cout << "the block name is " << basicBlock->getName() << std::endl;
// for (auto pb : basicBlock->getPredecessors()) {
// // newbasicBlock->addPredecessor(pb);
// std::cout << pb->getName() << std::endl;
// }
// sysy::BasicBlock::conectBlocks(newbasicBlock, static_cast<BasicBlock *>(basicBlock.get()));
// 若后为跳转指令,应该修改跳转指令所到达的位置
auto thelastinst = predBlock->end();
(--thelastinst);
if (thelastinst->get()->isConditional() || thelastinst->get()->isUnconditional()) { // 如果是跳转指令
auto opnum = thelastinst->get()->getNumOperands();
for (size_t i = 0; i < opnum; i++) {
if (thelastinst->get()->getOperand(i) == basicBlock.get()) {
thelastinst->get()->replaceOperand(i, newbasicBlock);
}
}
}
// 在新块中插入store指令
pBuilder->setPosition(newbasicBlock, newbasicBlock->end());
// pBuilder->createStoreInst(source, destination);
if (source->isInt() || source->isFloat()) {
pBuilder->createStoreInst(source, destination);
} else {
auto loadInst = pBuilder->createLoadInst(source);
pBuilder->createStoreInst(loadInst, destination);
}
// pBuilder->createMoveInst(Instruction::kMove, destination->getType(), destination, source,
// newbasicBlock);
pBuilder->setPosition(newbasicBlock, newbasicBlock->end());
pBuilder->createUncondBrInst(basicBlock.get(), {});
} else {
// 如果前驱块只有一个后继
auto thelastinst = predBlock->end();
(--thelastinst);
// std::cout << predBlock->getName() << std::endl;
// std::cout << thelastinst->get() << std::endl;
// std::cout << "First point 11 " << std::endl;
if (thelastinst->get()->isConditional() || thelastinst->get()->isUnconditional()) {
// 在跳转语句前insert st指令
pBuilder->setPosition(predBlock, thelastinst);
} else {
pBuilder->setPosition(predBlock, predBlock->end());
}
if (source->isInt() || source->isFloat()) {
pBuilder->createStoreInst(source, destination);
} else {
auto loadInst = pBuilder->createLoadInst(source);
pBuilder->createStoreInst(loadInst, destination);
}
}
}
// 删除phi指令
auto &instructions = basicBlock->getInstructions();
usedelete(iter->get());
iter = instructions.erase(iter);
if (basicBlock->getNumInstructions() == 0) {
if (basicBlock->getNumSuccessors() == 1) {
pBuilder->setPosition(basicBlock.get(), basicBlock->end());
pBuilder->createUncondBrInst(basicBlock->getSuccessors()[0], {});
}
}
} else {
break;
}
}
}
}
}
void Reg2Mem::usedelete(Instruction *instr) {
for (auto &use : instr->getOperands()) {
auto val = use->getValue();
val->removeUse(use);
}
}
} // namespace sysy

View File

@@ -10,10 +10,9 @@
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
using namespace std;
#include "SysYIRGenerator.h" #include "SysYIRGenerator.h"
using namespace std;
namespace sysy { namespace sysy {
/* /*
@@ -130,30 +129,111 @@ std::any SysYIRGenerator::visitVarDecl(SysYParser::VarDeclContext *ctx) {
delete root; delete root;
if (dims.empty()) { if (dims.empty()) {
builder.createStoreInst(values.getValue(0), alloca); builder.createStoreInst(values.getValue(0), alloca);
} else { } else{
// 对于多维数组使用memset初始化 // **数组变量初始化**
// 计算每个维度的大小 const std::vector<sysy::Value *> &counterValues = values.getValues();
// 这里的values.getNumbers()返回的是每个维度的大小
// 这里的values.getValues()返回的是每个维度对应的值 // 计算数组的**总元素数量**和**总字节大小**
// 例如对于一个二维数组values.getNumbers()可能是[3, 4]表示3行4列 int numElements = 1;
// values.getValues()可能是[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12] // 存储每个维度的实际整数大小,用于索引计算
// 对于每个维度使用memset将对应的值填充到数组中 std::vector<int> dimSizes;
// 这里的alloca是一个指向数组的指针 for (Value *dimVal : dims) {
const std::vector<unsigned int> & counterNumbers = values.getNumbers(); if (ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(dimVal)) {
const std::vector<sysy::Value *> & counterValues = values.getValues(); int dimSize = constInt->getInt();
unsigned begin = 0; numElements *= dimSize;
for (size_t i = 0; i < counterNumbers.size(); i++) { dimSizes.push_back(dimSize);
}
// TODO else 错误处理:数组维度必须是常量(对于静态分配)
}
unsigned int elementSizeInBytes = type->getSize(); // 获取单个元素的大小(字节)
unsigned int totalSizeInBytes = numElements * elementSizeInBytes;
// **判断是否可以进行全零初始化优化**
bool allValuesAreZero = false;
if (counterValues.empty()) { // 例如 int arr[3] = {}; 或 int arr[3][4] = {};
allValuesAreZero = true;
}
else {
allValuesAreZero = true;
for (Value *val : counterValues){
if (ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(val)){
if (constInt->getInt() != 0){
allValuesAreZero = false;
break;
}
}
else{
// 如果值不是常量,我们通常不能确定它是否为零,所以不进行 memset 优化
allValuesAreZero = false;
break;
}
}
}
if (allValuesAreZero) {
// 如果所有初始化值都是零(或没有明确初始化但语法允许),使用 memset 优化
builder.createMemsetInst( builder.createMemsetInst(
alloca, ConstantValue::get(static_cast<int>(begin)), alloca, // 目标数组的起始地址
ConstantValue::get(static_cast<int>(counterNumbers[i])), ConstantInteger::get(0), // 偏移量通常为0后续删除
counterValues[i]); ConstantInteger::get(totalSizeInBytes),
begin += counterNumbers[i]; ConstantInteger::get(0)); // 填充的总字节数
}
else {
// **逐元素存储:遍历所有初始值,并为每个值生成一个 store 指令**
for (size_t k = 0; k < counterValues.size(); ++k) {
// 用于存储当前元素的索引列表
std::vector<Value *> currentIndices;
int tempLinearIndex = k; // 临时线性索引,用于计算多维索引
// **将线性索引转换为多维索引**
// 这个循环从最内层维度开始倒推,计算每个维度的索引
// 假设是行主序row-major order这是 C/C++ 数组的标准存储方式
for (int dimIdx = dimSizes.size() - 1; dimIdx >= 0; --dimIdx)
{
// 计算当前维度的索引,并插入到列表的最前面
currentIndices.insert(currentIndices.begin(),
ConstantInteger::get(static_cast<int>(tempLinearIndex % dimSizes[dimIdx])));
// 更新线性索引,用于计算下一个更高维度的索引
tempLinearIndex /= dimSizes[dimIdx];
}
// **生成 store 指令,传入值、基指针和计算出的索引列表**
// 你的 builder.createStoreInst 签名需要能够接受这些参数
// 假设你的 builder.createStoreInst(Value *val, Value *ptr, const std::vector<Value *> &indices, ...)
builder.createStoreInst(counterValues[k], alloca, currentIndices);
}
} }
} }
} }
else
{ // **如果没有显式初始化值,默认对数组进行零初始化**
if (!dims.empty())
{ // 只有数组才需要默认的零初始化
int numElements = 1;
for (Value *dimVal : dims)
{
if (ConstantInteger *constInt = dynamic_cast<ConstantInteger *>(dimVal))
{
numElements *= constInt->getInt();
}
}
unsigned int elementSizeInBytes = type->getSize();
unsigned int totalSizeInBytes = numElements * elementSizeInBytes;
// 使用 memset 将整个数组清零
builder.createMemsetInst(
alloca,
ConstantInteger::get(0),
ConstantInteger::get(totalSizeInBytes),
ConstantInteger::get(0)
); // 填充的总字节数
}
// 标量变量如果没有初始化值,通常不生成额外的初始化指令,因为其内存已分配但未赋值。
}
module->addVariable(name, alloca); module->addVariable(name, alloca);
} }
return std::any(); return std::any();
} }
@@ -218,7 +298,7 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
paramNames.push_back(param->Ident()->getText()); paramNames.push_back(param->Ident()->getText());
std::vector<Value *> dims = {}; std::vector<Value *> dims = {};
if (!param->LBRACK().empty()) { if (!param->LBRACK().empty()) {
dims.push_back(ConstantValue::get(-1)); // 第一个维度不确定 dims.push_back(ConstantInteger::get(-1)); // 第一个维度不确定
for (const auto &exp : param->exp()) { for (const auto &exp : param->exp()) {
dims.push_back(std::any_cast<Value *>(visitExp(exp))); dims.push_back(std::any_cast<Value *>(visitExp(exp)));
} }
@@ -247,9 +327,9 @@ std::any SysYIRGenerator::visitFuncDef(SysYParser::FuncDefContext *ctx){
if(HasReturnInst == false) { if(HasReturnInst == false) {
// 如果没有return语句则默认返回0 // 如果没有return语句则默认返回0
if (returnType != Type::getVoidType()) { if (returnType != Type::getVoidType()) {
Value* returnValue = ConstantValue::get(0); Value* returnValue = ConstantInteger::get(0);
if (returnType == Type::getFloatType()) { if (returnType == Type::getFloatType()) {
returnValue = ConstantValue::get(0.0f); returnValue = ConstantFloating::get(0.0f);
} }
builder.createReturnInst(returnValue); builder.createReturnInst(returnValue);
} else { } else {
@@ -286,9 +366,9 @@ std::any SysYIRGenerator::visitAssignStmt(SysYParser::AssignStmtContext *ctx) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(value); ConstantValue * constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) { if (constValue != nullptr) {
if (variableType == Type::getFloatType()) { if (variableType == Type::getFloatType()) {
value = ConstantValue::get(static_cast<float>(constValue->getInt())); value = ConstantInteger::get(static_cast<float>(constValue->getInt()));
} else { } else {
value = ConstantValue::get(static_cast<int>(constValue->getFloat())); value = ConstantFloating::get(static_cast<int>(constValue->getFloat()));
} }
} else { } else {
if (variableType == Type::getFloatType()) { if (variableType == Type::getFloatType()) {
@@ -478,9 +558,9 @@ std::any SysYIRGenerator::visitReturnStmt(SysYParser::ReturnStmtContext *ctx) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(returnValue); ConstantValue * constValue = dynamic_cast<ConstantValue *>(returnValue);
if (constValue != nullptr) { if (constValue != nullptr) {
if (funcType == Type::getFloatType()) { if (funcType == Type::getFloatType()) {
returnValue = ConstantValue::get(static_cast<float>(constValue->getInt())); returnValue = ConstantInteger::get(static_cast<float>(constValue->getInt()));
} else { } else {
returnValue = ConstantValue::get(static_cast<int>(constValue->getFloat())); returnValue = ConstantFloating::get(static_cast<int>(constValue->getFloat()));
} }
} else { } else {
if (funcType == Type::getFloatType()) { if (funcType == Type::getFloatType()) {
@@ -560,10 +640,10 @@ std::any SysYIRGenerator::visitPrimaryExp(SysYParser::PrimaryExpContext *ctx) {
std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) { std::any SysYIRGenerator::visitNumber(SysYParser::NumberContext *ctx) {
if (ctx->ILITERAL() != nullptr) { if (ctx->ILITERAL() != nullptr) {
int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0); int value = std::stol(ctx->ILITERAL()->getText(), nullptr, 0);
return static_cast<Value *>(ConstantValue::get(value)); return static_cast<Value *>(ConstantInteger::get(value));
} else if (ctx->FLITERAL() != nullptr) { } else if (ctx->FLITERAL() != nullptr) {
float value = std::stof(ctx->FLITERAL()->getText()); float value = std::stof(ctx->FLITERAL()->getText());
return static_cast<Value *>(ConstantValue::get(value)); return static_cast<Value *>(ConstantFloating::get(value));
} }
throw std::runtime_error("Unknown number type."); throw std::runtime_error("Unknown number type.");
return std::any(); // 不会到达这里 return std::any(); // 不会到达这里
@@ -599,9 +679,9 @@ std::any SysYIRGenerator::visitCall(SysYParser::CallContext *ctx) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(args[i]); ConstantValue * constValue = dynamic_cast<ConstantValue *>(args[i]);
if (constValue != nullptr) { if (constValue != nullptr) {
if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) {
args[i] = ConstantValue::get(static_cast<float>(constValue->getInt())); args[i] = ConstantInteger::get(static_cast<float>(constValue->getInt()));
} else { } else {
args[i] = ConstantValue::get(static_cast<int>(constValue->getFloat())); args[i] = ConstantFloating::get(static_cast<int>(constValue->getFloat()));
} }
} else { } else {
if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) { if (params[i]->getType() == Type::getPointerType(Type::getFloatType())) {
@@ -629,9 +709,9 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(value); ConstantValue * constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) { if (constValue != nullptr) {
if (constValue->isFloat()) { if (constValue->isFloat()) {
result = ConstantValue::get(-constValue->getFloat()); result = ConstantFloating::get(-constValue->getFloat());
} else { } else {
result = ConstantValue::get(-constValue->getInt()); result = ConstantInteger::get(-constValue->getInt());
} }
} else if (value != nullptr) { } else if (value != nullptr) {
if (value->getType() == Type::getIntType()) { if (value->getType() == Type::getIntType()) {
@@ -648,9 +728,9 @@ std::any SysYIRGenerator::visitUnaryExp(SysYParser::UnaryExpContext *ctx) {
if (constValue != nullptr) { if (constValue != nullptr) {
if (constValue->isFloat()) { if (constValue->isFloat()) {
result = result =
ConstantValue::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0)); ConstantFloating::get(1 - (constValue->getFloat() != 0.0F ? 1 : 0));
} else { } else {
result = ConstantValue::get(1 - (constValue->getInt() != 0 ? 1 : 0)); result = ConstantInteger::get(1 - (constValue->getInt() != 0 ? 1 : 0));
} }
} else if (value != nullptr) { } else if (value != nullptr) {
if (value->getType() == Type::getIntType()) { if (value->getType() == Type::getIntType()) {
@@ -692,13 +772,13 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
if (operandType != floatType) { if (operandType != floatType) {
ConstantValue * constValue = dynamic_cast<ConstantValue *>(operand); ConstantValue * constValue = dynamic_cast<ConstantValue *>(operand);
if (constValue != nullptr) if (constValue != nullptr)
operand = ConstantValue::get(static_cast<float>(constValue->getInt())); operand = ConstantFloating::get(static_cast<float>(constValue->getInt()));
else else
operand = builder.createIToFInst(operand); operand = builder.createIToFInst(operand);
} else if (resultType != floatType) { } else if (resultType != floatType) {
ConstantValue* constResult = dynamic_cast<ConstantValue *>(result); ConstantValue* constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr) if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt())); result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
else else
result = builder.createIToFInst(result); result = builder.createIToFInst(result);
} }
@@ -707,14 +787,14 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
ConstantValue* constOperand = dynamic_cast<ConstantValue *>(operand); ConstantValue* constOperand = dynamic_cast<ConstantValue *>(operand);
if (opType == SysYParser::MUL) { if (opType == SysYParser::MUL) {
if ((constOperand != nullptr) && (constResult != nullptr)) { if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantValue::get(constResult->getFloat() * result = ConstantFloating::get(constResult->getFloat() *
constOperand->getFloat()); constOperand->getFloat());
} else { } else {
result = builder.createFMulInst(result, operand); result = builder.createFMulInst(result, operand);
} }
} else if (opType == SysYParser::DIV) { } else if (opType == SysYParser::DIV) {
if ((constOperand != nullptr) && (constResult != nullptr)) { if ((constOperand != nullptr) && (constResult != nullptr)) {
result = ConstantValue::get(constResult->getFloat() / result = ConstantFloating::get(constResult->getFloat() /
constOperand->getFloat()); constOperand->getFloat());
} else { } else {
result = builder.createFDivInst(result, operand); result = builder.createFDivInst(result, operand);
@@ -729,17 +809,17 @@ std::any SysYIRGenerator::visitMulExp(SysYParser::MulExpContext *ctx) {
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand); ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (opType == SysYParser::MUL) { if (opType == SysYParser::MUL) {
if ((constOperand != nullptr) && (constResult != nullptr)) if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() * constOperand->getInt()); result = ConstantInteger::get(constResult->getInt() * constOperand->getInt());
else else
result = builder.createMulInst(result, operand); result = builder.createMulInst(result, operand);
} else if (opType == SysYParser::DIV) { } else if (opType == SysYParser::DIV) {
if ((constOperand != nullptr) && (constResult != nullptr)) if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() / constOperand->getInt()); result = ConstantInteger::get(constResult->getInt() / constOperand->getInt());
else else
result = builder.createDivInst(result, operand); result = builder.createDivInst(result, operand);
} else { } else {
if ((constOperand != nullptr) && (constResult != nullptr)) if ((constOperand != nullptr) && (constResult != nullptr))
result = ConstantValue::get(constResult->getInt() % constOperand->getInt()); result = ConstantInteger::get(constResult->getInt() % constOperand->getInt());
else else
result = builder.createRemInst(result, operand); result = builder.createRemInst(result, operand);
} }
@@ -767,13 +847,13 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
if (operandType != floatType) { if (operandType != floatType) {
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand); ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (constOperand != nullptr) if (constOperand != nullptr)
operand = ConstantValue::get(static_cast<float>(constOperand->getInt())); operand = ConstantFloating::get(static_cast<float>(constOperand->getInt()));
else else
operand = builder.createIToFInst(operand); operand = builder.createIToFInst(operand);
} else if (resultType != floatType) { } else if (resultType != floatType) {
ConstantValue * constResult = dynamic_cast<ConstantValue *>(result); ConstantValue * constResult = dynamic_cast<ConstantValue *>(result);
if (constResult != nullptr) if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt())); result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
else else
result = builder.createIToFInst(result); result = builder.createIToFInst(result);
} }
@@ -782,12 +862,12 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand); ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (opType == SysYParser::ADD) { if (opType == SysYParser::ADD) {
if ((constResult != nullptr) && (constOperand != nullptr)) if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getFloat() + constOperand->getFloat()); result = ConstantFloating::get(constResult->getFloat() + constOperand->getFloat());
else else
result = builder.createFAddInst(result, operand); result = builder.createFAddInst(result, operand);
} else { } else {
if ((constResult != nullptr) && (constOperand != nullptr)) if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getFloat() - constOperand->getFloat()); result = ConstantFloating::get(constResult->getFloat() - constOperand->getFloat());
else else
result = builder.createFSubInst(result, operand); result = builder.createFSubInst(result, operand);
} }
@@ -796,12 +876,12 @@ std::any SysYIRGenerator::visitAddExp(SysYParser::AddExpContext *ctx) {
ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand); ConstantValue * constOperand = dynamic_cast<ConstantValue *>(operand);
if (opType == SysYParser::ADD) { if (opType == SysYParser::ADD) {
if ((constResult != nullptr) && (constOperand != nullptr)) if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getInt() + constOperand->getInt()); result = ConstantInteger::get(constResult->getInt() + constOperand->getInt());
else else
result = builder.createAddInst(result, operand); result = builder.createAddInst(result, operand);
} else { } else {
if ((constResult != nullptr) && (constOperand != nullptr)) if ((constResult != nullptr) && (constOperand != nullptr))
result = ConstantValue::get(constResult->getInt() - constOperand->getInt()); result = ConstantInteger::get(constResult->getInt() - constOperand->getInt());
else else
result = builder.createSubInst(result, operand); result = builder.createSubInst(result, operand);
} }
@@ -833,10 +913,10 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
auto operand2 = constOperand->isFloat() ? constOperand->getFloat() auto operand2 = constOperand->isFloat() ? constOperand->getFloat()
: constOperand->getInt(); : constOperand->getInt();
if (opType == SysYParser::LT) result = ConstantValue::get(operand1 < operand2 ? 1 : 0); if (opType == SysYParser::LT) result = ConstantInteger::get(operand1 < operand2 ? 1 : 0);
else if (opType == SysYParser::GT) result = ConstantValue::get(operand1 > operand2 ? 1 : 0); else if (opType == SysYParser::GT) result = ConstantInteger::get(operand1 > operand2 ? 1 : 0);
else if (opType == SysYParser::LE) result = ConstantValue::get(operand1 <= operand2 ? 1 : 0); else if (opType == SysYParser::LE) result = ConstantInteger::get(operand1 <= operand2 ? 1 : 0);
else if (opType == SysYParser::GE) result = ConstantValue::get(operand1 >= operand2 ? 1 : 0); else if (opType == SysYParser::GE) result = ConstantInteger::get(operand1 >= operand2 ? 1 : 0);
else assert(false); else assert(false);
} else { } else {
@@ -848,14 +928,14 @@ std::any SysYIRGenerator::visitRelExp(SysYParser::RelExpContext *ctx) {
if (resultType == floatType || operandType == floatType) { if (resultType == floatType || operandType == floatType) {
if (resultType != floatType) { if (resultType != floatType) {
if (constResult != nullptr) if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt())); result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
else else
result = builder.createIToFInst(result); result = builder.createIToFInst(result);
} }
if (operandType != floatType) { if (operandType != floatType) {
if (constOperand != nullptr) if (constOperand != nullptr)
operand = ConstantValue::get(static_cast<float>(constOperand->getInt())); operand = ConstantFloating::get(static_cast<float>(constOperand->getInt()));
else else
operand = builder.createIToFInst(operand); operand = builder.createIToFInst(operand);
@@ -901,8 +981,8 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
auto operand2 = constOperand->isFloat() ? constOperand->getFloat() auto operand2 = constOperand->isFloat() ? constOperand->getFloat()
: constOperand->getInt(); : constOperand->getInt();
if (opType == SysYParser::EQ) result = ConstantValue::get(operand1 == operand2 ? 1 : 0); if (opType == SysYParser::EQ) result = ConstantInteger::get(operand1 == operand2 ? 1 : 0);
else if (opType == SysYParser::NE) result = ConstantValue::get(operand1 != operand2 ? 1 : 0); else if (opType == SysYParser::NE) result = ConstantInteger::get(operand1 != operand2 ? 1 : 0);
else assert(false); else assert(false);
} else { } else {
@@ -913,13 +993,13 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
if (resultType == floatType || operandType == floatType) { if (resultType == floatType || operandType == floatType) {
if (resultType != floatType) { if (resultType != floatType) {
if (constResult != nullptr) if (constResult != nullptr)
result = ConstantValue::get(static_cast<float>(constResult->getInt())); result = ConstantFloating::get(static_cast<float>(constResult->getInt()));
else else
result = builder.createIToFInst(result); result = builder.createIToFInst(result);
} }
if (operandType != floatType) { if (operandType != floatType) {
if (constOperand != nullptr) if (constOperand != nullptr)
operand = ConstantValue::get(static_cast<float>(constOperand->getInt())); operand = ConstantFloating::get(static_cast<float>(constOperand->getInt()));
else else
operand = builder.createIToFInst(operand); operand = builder.createIToFInst(operand);
} }
@@ -943,9 +1023,9 @@ std::any SysYIRGenerator::visitEqExp(SysYParser::EqExpContext *ctx) {
// 如果只有一个关系表达式则将结果转换为0或1 // 如果只有一个关系表达式则将结果转换为0或1
if (constResult != nullptr) { if (constResult != nullptr) {
if (constResult->isFloat()) if (constResult->isFloat())
result = ConstantValue::get(constResult->getFloat() != 0.0F ? 1 : 0); result = ConstantInteger::get(constResult->getFloat() != 0.0F ? 1 : 0);
else else
result = ConstantValue::get(constResult->getInt() != 0 ? 1 : 0); result = ConstantInteger::get(constResult->getInt() != 0 ? 1 : 0);
} }
} }
@@ -1013,6 +1093,7 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root,
ValueCounter &result, IRBuilder *builder) { ValueCounter &result, IRBuilder *builder) {
Value* value = root->getValue(); Value* value = root->getValue();
auto &children = root->getChildren(); auto &children = root->getChildren();
// 类型转换
if (value != nullptr) { if (value != nullptr) {
if (type == value->getType()) { if (type == value->getType()) {
result.push_back(value); result.push_back(value);
@@ -1020,14 +1101,14 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root,
if (type == Type::getFloatType()) { if (type == Type::getFloatType()) {
ConstantValue* constValue = dynamic_cast<ConstantValue *>(value); ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) if (constValue != nullptr)
result.push_back(ConstantValue::get(static_cast<float>(constValue->getInt()))); result.push_back(ConstantFloating::get(static_cast<float>(constValue->getInt())));
else else
result.push_back(builder->createIToFInst(value)); result.push_back(builder->createIToFInst(value));
} else { } else {
ConstantValue* constValue = dynamic_cast<ConstantValue *>(value); ConstantValue* constValue = dynamic_cast<ConstantValue *>(value);
if (constValue != nullptr) if (constValue != nullptr)
result.push_back(ConstantValue::get(static_cast<int>(constValue->getFloat()))); result.push_back(ConstantInteger::get(static_cast<int>(constValue->getFloat())));
else else
result.push_back(builder->createFtoIInst(value)); result.push_back(builder->createFtoIInst(value));
@@ -1061,9 +1142,9 @@ void Utils::tree2Array(Type *type, ArrayValueTree *root,
int num = blockSize - afterSize + beforeSize; int num = blockSize - afterSize + beforeSize;
if (num > 0) { if (num > 0) {
if (type == Type::getFloatType()) if (type == Type::getFloatType())
result.push_back(ConstantValue::get(0.0F), num); result.push_back(ConstantFloating::get(0.0F), num);
else else
result.push_back(ConstantValue::get(0), num); result.push_back(ConstantInteger::get(0), num);
} }
} }
@@ -1101,7 +1182,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
funcName, pModule, pBuilder); funcName, pModule, pBuilder);
paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getIntType());
paramNames.emplace_back("x"); paramNames.emplace_back("x");
paramDims.push_back(std::vector<Value *>{ConstantValue::get(-1)}); paramDims.push_back(std::vector<Value *>{ConstantInteger::get(-1)});
funcName = "getarray"; funcName = "getarray";
Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType,
funcName, pModule, pBuilder); funcName, pModule, pBuilder);
@@ -1117,7 +1198,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
returnType = Type::getIntType(); returnType = Type::getIntType();
paramTypes.push_back(Type::getFloatType()); paramTypes.push_back(Type::getFloatType());
paramNames.emplace_back("x"); paramNames.emplace_back("x");
paramDims.push_back(std::vector<Value *>{ConstantValue::get(-1)}); paramDims.push_back(std::vector<Value *>{ConstantInteger::get(-1)});
funcName = "getfarray"; funcName = "getfarray";
Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType, Utils::createExternalFunction(paramTypes, paramNames, paramDims, returnType,
funcName, pModule, pBuilder); funcName, pModule, pBuilder);
@@ -1141,7 +1222,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
paramTypes.push_back(Type::getIntType()); paramTypes.push_back(Type::getIntType());
paramDims.clear(); paramDims.clear();
paramDims.emplace_back(); paramDims.emplace_back();
paramDims.push_back(std::vector<Value *>{ConstantValue::get(-1)}); paramDims.push_back(std::vector<Value *>{ConstantInteger::get(-1)});
paramNames.clear(); paramNames.clear();
paramNames.emplace_back("n"); paramNames.emplace_back("n");
paramNames.emplace_back("a"); paramNames.emplace_back("a");
@@ -1164,7 +1245,7 @@ void Utils::initExternalFunction(Module *pModule, IRBuilder *pBuilder) {
paramTypes.push_back(Type::getFloatType()); paramTypes.push_back(Type::getFloatType());
paramDims.clear(); paramDims.clear();
paramDims.emplace_back(); paramDims.emplace_back();
paramDims.push_back(std::vector<Value *>{ConstantValue::get(-1)}); paramDims.push_back(std::vector<Value *>{ConstantInteger::get(-1)});
paramNames.clear(); paramNames.clear();
paramNames.emplace_back("n"); paramNames.emplace_back("n");
paramNames.emplace_back("a"); paramNames.emplace_back("a");

View File

@@ -469,9 +469,9 @@ void SysYOptPre::SysYAddReturn() {
pBuilder->setPosition(block.get(), block->end()); pBuilder->setPosition(block.get(), block->end());
// TODO: 如果int float函数缺少返回值是否需要报错 // TODO: 如果int float函数缺少返回值是否需要报错
if (func->getReturnType()->isInt()) { if (func->getReturnType()->isInt()) {
pBuilder->createReturnInst(ConstantValue::get(0)); pBuilder->createReturnInst(ConstantInteger::get(0));
} else if (func->getReturnType()->isFloat()) { } else if (func->getReturnType()->isFloat()) {
pBuilder->createReturnInst(ConstantValue::get(0.0F)); pBuilder->createReturnInst(ConstantFloating::get(0.0F));
} else { } else {
pBuilder->createReturnInst(); pBuilder->createReturnInst();
} }

View File

@@ -0,0 +1,59 @@
#pragma once
#include "IR.h" // 假设IR.h包含了Module, Function, BasicBlock, Instruction, Value, IRBuilder, Type等定义
#include "IRBuilder.h" // 需要IRBuilder来创建新指令
#include "SysYIRPrinter.h" // 新增: 用于调试输出
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include <list> // 用于迭代和修改指令列表
#include <algorithm> // for std::reverse (if needed, although not used in final version)
#include <iostream> // MODIFICATION: 用于警告输出
namespace sysy {
/**
* @brief AddressCalculationExpansion Pass
*
* 这是一个IR优化Pass用于将LoadInst和StoreInst中包含的多维数组索引
* 显式地转换为IR中的BinaryInst乘法和加法序列并生成带有线性偏移量的
* LoadInst/StoreInst。
*
* 目的确保在寄存器分配之前所有中间地址计算的结果都有明确的IR指令和对应的虚拟寄存器
* 从而避免在后端DAG构建时临时创建值而导致寄存器分配缺失的问题。
*
* SysY语言特性
* - 无指针类型所有数组访问的基地址是alloca或global的AllocaType/ArrayType
* - 数据类型只有int和float且都占用4字节。
* - LoadInst和StoreInst直接接受多个索引作为额外操作数。
*/
class AddressCalculationExpansion {
private:
Module* pModule;
IRBuilder* pBuilder; // 用于在IR中插入新指令
// 数组元素的固定大小根据SysY特性int和float都是4字节
static const int ELEMENT_SIZE = 4;
// 辅助函数:根据数组的维度信息和当前索引的维度,计算该索引的步长(字节数)
// dims: 包含所有维度大小的vector例如 {2, 3, 4}
// currentDimIndex: 当前正在处理的索引在 dims 中的位置 (0, 1, 2...)
int calculateStride(const std::vector<int>& dims, size_t currentDimIndex) {
int stride = ELEMENT_SIZE; // 最内层元素大小 (4字节)
// 乘以当前维度之后的所有维度的大小
for (size_t i = currentDimIndex + 1; i < dims.size(); ++i) {
stride *= dims[i];
}
return stride;
}
public:
AddressCalculationExpansion(Module* module, IRBuilder* builder)
: pModule(module), pBuilder(builder) {}
// 运行此Pass
bool run();
};
} // namespace sysy

View File

@@ -268,6 +268,51 @@ class ValueCounter {
} ///< 清空ValueCounter } ///< 清空ValueCounter
}; };
// --- Refactored ConstantValue and related classes start here ---
using ConstantValVariant = std::variant<int, float>;
// Helper for hashing std::variant
struct VariantHash {
template <class T>
std::size_t operator()(const T& val) const {
return std::hash<T>{}(val);
}
std::size_t operator()(const ConstantValVariant& v) const {
return std::visit(*this, v);
}
};
struct ConstantValueKey {
Type* type;
ConstantValVariant val;
bool operator==(const ConstantValueKey& other) const {
// Assuming Type objects are canonicalized, or add Type::isSame()
// If Type::isSame() is not available and Type objects are not canonicalized,
// this comparison might not be robust enough for structural equivalence of types.
return type == other.type && val == other.val;
}
};
struct ConstantValueHash {
std::size_t operator()(const ConstantValueKey& key) const {
std::size_t typeHash = std::hash<Type*>{}(key.type);
std::size_t valHash = VariantHash{}(key.val);
// A simple way to combine hashes
return typeHash ^ (valHash << 1);
}
};
struct ConstantValueEqual {
bool operator()(const ConstantValueKey& lhs, const ConstantValueKey& rhs) const {
// Assuming Type objects are canonicalized (e.g., Type::getIntType() always returns same pointer)
// If not, and Type::isSame() is intended, it should be added to Type class.
return lhs.type == rhs.type && lhs.val == rhs.val;
}
};
/*! /*!
* Static constants known at compile time. * Static constants known at compile time.
* *
@@ -276,45 +321,135 @@ class ValueCounter {
* `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。 * `ConstantValue`并不由指令定义, 也不使用任何Value。它的类型为int/float。
*/ */
template<class T> struct always_false : std::false_type {};
template<class T> constexpr bool always_false_v = always_false<T>::value;
class ConstantValue : public Value { class ConstantValue : public Value {
protected: protected:
/// 定义字面量类型的聚合类型 static std::unordered_map<ConstantValueKey, ConstantValue*, ConstantValueHash, ConstantValueEqual> mConstantPool;
union {
int iScalar;
float fScalar;
};
protected: public:
explicit ConstantValue(int value, const std::string &name = "") : Value(Type::getIntType(), name), iScalar(value) {} explicit ConstantValue(Type* type, const std::string& name = "") : Value(type, name) {}
explicit ConstantValue(float value, const std::string &name = "") virtual ~ConstantValue() = default;
: Value(Type::getFloatType(), name), fScalar(value) {}
public: virtual size_t hash() const = 0;
static ConstantValue* get(int value); ///< 获取一个int类型的ConstValue *其值为value virtual ConstantValVariant getVal() const = 0;
static ConstantValue* get(float value); ///< 获取一个float类型的ConstValue *其值为value
public: // Static factory method to get a canonical ConstantValue from the pool
static ConstantValue* get(Type* type, ConstantValVariant val);
// Helper methods to access constant values with appropriate casting
int getInt() const { int getInt() const {
assert(isInt()); assert(getType()->isInt() && "Calling getInt() on non-integer type");
return iScalar; return std::get<int>(getVal());
} ///< 返回int类型的值 }
float getFloat() const { float getFloat() const {
assert(isFloat()); assert(getType()->isFloat() && "Calling getFloat() on non-float type");
return fScalar; return std::get<float>(getVal());
} ///< 返回float类型的值 }
template <typename T>
T getValue() const { template<typename T>
if (std::is_same<T, int>::value && isInt()) { T getVal() const {
if constexpr (std::is_same_v<T, int>) {
return getInt(); return getInt();
} } else if constexpr (std::is_same_v<T, float>) {
if (std::is_same<T, float>::value && isFloat()) {
return getFloat(); return getFloat();
} else {
// This ensures a compilation error if an unsupported type is used
static_assert(always_false_v<T>, "Unsupported type for ConstantValue::getValue()");
} }
throw std::bad_cast(); // 或者其他适当的异常处理 }
} ///< 返回值getInt和getFloat统一化整数返回整形浮点返回浮点型
virtual bool isZero() const = 0;
virtual bool isOne() const = 0;
}; };
class ConstantInteger : public ConstantValue {
int constVal;
public:
explicit ConstantInteger(Type* type, int val, const std::string& name = "")
: ConstantValue(type, name), constVal(val) {}
size_t hash() const override {
std::size_t typeHash = std::hash<Type*>{}(getType());
std::size_t valHash = std::hash<int>{}(constVal);
return typeHash ^ (valHash << 1);
}
int getInt() const { return constVal; }
ConstantValVariant getVal() const override { return constVal; }
static ConstantInteger* get(Type* type, int val);
static ConstantInteger* get(int val) { return get(Type::getIntType(), val); }
ConstantInteger* getNeg() const {
assert(getType()->isInt() && "Cannot negate non-integer constant");
return ConstantInteger::get(-constVal);
}
bool isZero() const override { return constVal == 0; }
bool isOne() const override { return constVal == 1; }
};
class ConstantFloating : public ConstantValue {
float constFVal;
public:
explicit ConstantFloating(Type* type, float val, const std::string& name = "")
: ConstantValue(type, name), constFVal(val) {}
size_t hash() const override {
std::size_t typeHash = std::hash<Type*>{}(getType());
std::size_t valHash = std::hash<float>{}(constFVal);
return typeHash ^ (valHash << 1);
}
float getFloat() const { return constFVal; }
ConstantValVariant getVal() const override { return constFVal; }
static ConstantFloating* get(Type* type, float val);
static ConstantFloating* get(float val) { return get(Type::getFloatType(), val); }
ConstantFloating* getNeg() const {
assert(getType()->isFloat() && "Cannot negate non-float constant");
return ConstantFloating::get(-constFVal);
}
bool isZero() const override { return constFVal == 0.0f; }
bool isOne() const override { return constFVal == 1.0f; }
};
class UndefinedValue : public ConstantValue {
private:
static std::unordered_map<Type*, UndefinedValue*> UndefValues;
protected:
explicit UndefinedValue(Type* type, const std::string& name = "")
: ConstantValue(type, name) {
assert(!type->isVoid() && "Cannot create UndefinedValue of void type!");
}
public:
static UndefinedValue* get(Type* type);
size_t hash() const override {
return std::hash<Type*>{}(getType());
}
ConstantValVariant getVal() const override {
if (getType()->isInt()) {
return 0; // Return 0 for undefined integer
} else if (getType()->isFloat()) {
return 0.0f; // Return 0.0f for undefined float
}
assert(false && "UndefinedValue has unexpected type for getValue()");
return 0; // Should not be reached
}
bool isZero() const override { return false; }
bool isOne() const override { return false; }
};
// --- End of refactored ConstantValue and related classes ---
class Instruction; class Instruction;
class Function; class Function;
class BasicBlock; class BasicBlock;
@@ -562,8 +697,8 @@ class Instruction : public User {
kLa = 0x1UL << 36, kLa = 0x1UL << 36,
kMemset = 0x1UL << 37, kMemset = 0x1UL << 37,
kGetSubArray = 0x1UL << 38, kGetSubArray = 0x1UL << 38,
// constant // Constant Kind removed as Constants are now Values, not Instructions.
kConstant = 0x1UL << 37, // kConstant = 0x1UL << 37, // Conflicts with kMemset if kept as is
// phi // phi
kPhi = 0x1UL << 39, kPhi = 0x1UL << 39,
kBitItoF = 0x1UL << 40, kBitItoF = 0x1UL << 40,
@@ -755,24 +890,51 @@ class LaInst : public Instruction {
class PhiInst : public Instruction { class PhiInst : public Instruction {
friend class IRBuilder; friend class IRBuilder;
friend class Function; friend class Function;
friend class SysySSA;
protected: protected:
Value *map_val; // Phi的旧值
PhiInst(Type *type, Value *lhs, const std::vector<Value *> &rhs, Value *mval, BasicBlock *parent, std::unordered_map<BasicBlock *, Value *> blk2val; ///< 存储每个基本块对应的值
unsigned vsize; ///< 存储值的数量
PhiInst(Type *type,
const std::vector<Value *> &rhs = {},
const std::vector<BasicBlock*> &Blocks = {},
BasicBlock *parent = nullptr,
const std::string &name = "") const std::string &name = "")
: Instruction(Kind::kPhi, type, parent, name) { : Instruction(Kind::kPhi, type, parent, name), vsize(rhs.size()) {
map_val = mval; assert(rhs.size() == Blocks.size() && "PhiInst: rhs and Blocks must have the same size");
addOperand(lhs); for(size_t i = 0; i < rhs.size(); ++i) {
addOperands(rhs); addOperand(rhs[i]);
blk2val[Blocks[i]] = rhs[i];
}
} }
public: public:
Value* getMapVal() { return map_val; } Value* getValue(unsigned k) const {return getOperand(2 * k);} ///< 获取位置为k的值
Value* getPointer() const { return getOperand(0); } BasicBlock* getBlock(unsigned k) const {return dynamic_cast<BasicBlock*>(getOperand(2 * k + 1));}
auto& getincomings() const {return blk2val;} ///< 获取所有的基本块和对应的值
Value* getvalfromBlk(BasicBlock* blk);
BasicBlock* getBlkfromVal(Value* val);
unsigned getNumIncomingValues() const { return vsize; } ///< 获取传入值的数量
void addIncoming(Value *value, BasicBlock *block) {
assert(value && block && "PhiInst: value and block must not be null");
addOperand(value);
addOperand(block);
blk2val[block] = value;
vsize++;
} ///< 添加传入值和对应的基本块
void delValue(Value* val);
void delBlk(BasicBlock* blk);
void replaceBlk(BasicBlock* newBlk, unsigned k);
void replaceold2new(BasicBlock* oldBlk, BasicBlock* newBlk);
void refreshB2VMap();
auto getValues() { return make_range(std::next(operand_begin()), operand_end()); } auto getValues() { return make_range(std::next(operand_begin()), operand_end()); }
Value* getValue(unsigned index) const { return getOperand(index + 1); }
}; };
@@ -884,7 +1046,7 @@ public:
} }
} ///< 根据指令类型进行二元计算eval template模板实现 } ///< 根据指令类型进行二元计算eval template模板实现
static BinaryInst* create(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") { static BinaryInst* create(Kind kind, Type *type, Value *lhs, Value *rhs, BasicBlock *parent, const std::string &name = "") {
// 后端处理数组访存操作时需要创建计算地址的指令,需要在外部构造 BinaryInst 对象所以写了个public的方法。 // 后端处理数组访存操作时需要创建计算地址的指令,需要在外部构造 BinaryInst 对象
return new BinaryInst(kind, type, lhs, rhs, parent, name); return new BinaryInst(kind, type, lhs, rhs, parent, name);
} }
}; // class BinaryInst }; // class BinaryInst
@@ -1230,12 +1392,15 @@ protected:
if (init.size() == 0) { if (init.size() == 0) {
unsigned num = 1; unsigned num = 1;
for (unsigned i = 0; i < numDims; i++) { for (unsigned i = 0; i < numDims; i++) {
num *= dynamic_cast<ConstantValue *>(dims[i])->getInt(); // Assume dims elements are ConstantInteger and cast appropriately
auto dim_val = dynamic_cast<ConstantInteger*>(dims[i]);
assert(dim_val && "GlobalValue dims must be constant integers");
num *= dim_val->getInt();
} }
if (dynamic_cast<PointerType *>(type)->getBaseType() == Type::getFloatType()) { if (dynamic_cast<PointerType *>(type)->getBaseType() == Type::getFloatType()) {
init.push_back(ConstantValue::get(0.0F), num); init.push_back(ConstantFloating::get(0.0F), num); // Use new constant factory
} else { } else {
init.push_back(ConstantValue::get(0), num); init.push_back(ConstantInteger::get(0), num); // Use new constant factory
} }
} }
initValues = init; initValues = init;
@@ -1261,8 +1426,11 @@ public:
Value* getByIndices(const std::vector<Value *> &indices) const { Value* getByIndices(const std::vector<Value *> &indices) const {
int index = 0; int index = 0;
for (size_t i = 0; i < indices.size(); i++) { for (size_t i = 0; i < indices.size(); i++) {
index = dynamic_cast<ConstantValue *>(getDim(i))->getInt() * index + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly
dynamic_cast<ConstantValue *>(indices[i])->getInt(); auto dim_val = dynamic_cast<ConstantInteger*>(getDim(i));
auto idx_val = dynamic_cast<ConstantInteger*>(indices[i]);
assert(dim_val && idx_val && "Dims and indices must be constant integers");
index = dim_val->getInt() * index + idx_val->getInt();
} }
return getByIndex(index); return getByIndex(index);
} ///< 通过多维索引indices获取初始值 } ///< 通过多维索引indices获取初始值
@@ -1303,8 +1471,11 @@ class ConstantVariable : public User, public LVal {
int index = 0; int index = 0;
// 计算偏移量 // 计算偏移量
for (size_t i = 0; i < indices.size(); i++) { for (size_t i = 0; i < indices.size(); i++) {
index = dynamic_cast<ConstantValue *>(getDim(i))->getInt() * index + // Ensure dims[i] and indices[i] are ConstantInteger and retrieve their values correctly
dynamic_cast<ConstantValue *>(indices[i])->getInt(); auto dim_val = dynamic_cast<ConstantInteger*>(getDim(i));
auto idx_val = dynamic_cast<ConstantInteger*>(indices[i]);
assert(dim_val && idx_val && "Dims and indices must be constant integers");
index = dim_val->getInt() * index + idx_val->getInt();
} }
return getByIndex(index); return getByIndex(index);

View File

@@ -333,15 +333,11 @@ class IRBuilder {
block->getInstructions().emplace(position, inst); block->getInstructions().emplace(position, inst);
return inst; return inst;
} ///< 创建store指令 } ///< 创建store指令
PhiInst * createPhiInst(Type *type, Value *lhs, BasicBlock *parent, const std::string &name = "") { PhiInst * createPhiInst(Type *type, const std::vector<Value*> &vals = {}, const std::vector<BasicBlock*> &blks = {}, const std::string &name = "") {
auto predNum = parent->getNumPredecessors(); auto predNum = block->getNumPredecessors();
std::vector<Value *> rhs; auto inst = new PhiInst(type, vals, blks, block, name);
for (size_t i = 0; i < predNum; i++) {
rhs.push_back(lhs);
}
auto inst = new PhiInst(type, lhs, rhs, lhs, parent, name);
assert(inst); assert(inst);
parent->getInstructions().emplace(parent->begin(), inst); block->getInstructions().emplace(block->begin(), inst);
return inst; return inst;
} ///< 创建Phi指令 } ///< 创建Phi指令
}; };

View File

@@ -0,0 +1,32 @@
#ifndef RISCV64_ASMPRINTER_H
#define RISCV64_ASMPRINTER_H
#include "RISCv64LLIR.h"
#include <iostream>
namespace sysy {
class RISCv64AsmPrinter {
public:
RISCv64AsmPrinter(MachineFunction* mfunc);
// 主入口
void run(std::ostream& os);
private:
// 打印各个部分
void printPrologue();
void printEpilogue();
void printBasicBlock(MachineBasicBlock* mbb);
void printInstruction(MachineInstr* instr);
// 辅助函数
std::string regToString(PhysicalReg reg);
void printOperand(MachineOperand* op);
MachineFunction* MFunc;
std::ostream* OS;
};
} // namespace sysy
#endif // RISCV64_ASMPRINTER_H

View File

@@ -3,118 +3,23 @@
#include "IR.h" #include "IR.h"
#include <string> #include <string>
#include <vector>
#include <map>
#include <set>
#include <memory>
#include <iostream>
#include <functional> // For std::function
extern int DEBUG;
extern int DEEPDEBUG;
namespace sysy { namespace sysy {
// RISCv64CodeGen 现在是一个高层驱动器
class RISCv64CodeGen { class RISCv64CodeGen {
public: public:
enum class PhysicalReg {
ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6,
F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31
};
// Move DAGNode and RegAllocResult to public section
struct DAGNode {
enum NodeKind { CONSTANT, LOAD, STORE, BINARY, CALL, RETURN, BRANCH, ALLOCA_ADDR, UNARY };
NodeKind kind;
Value* value = nullptr; // For IR Value
std::string inst; // Generated RISC-V instruction(s) for this node
std::string result_vreg; // Virtual register assigned to this node's result
std::vector<DAGNode*> operands;
std::vector<DAGNode*> users; // For debugging and potentially optimizations
DAGNode(NodeKind k) : kind(k) {}
// Debugging / helper
std::string getNodeKindString() const {
switch (kind) {
case CONSTANT: return "CONSTANT";
case LOAD: return "LOAD";
case STORE: return "STORE";
case BINARY: return "BINARY";
case CALL: return "CALL";
case RETURN: return "RETURN";
case BRANCH: return "BRANCH";
case ALLOCA_ADDR: return "ALLOCA_ADDR";
case UNARY: return "UNARY";
default: return "UNKNOWN";
}
}
};
struct RegAllocResult {
std::map<std::string, PhysicalReg> vreg_to_preg; // Virtual register to Physical Register mapping
std::map<Value*, int> stack_map; // Value (AllocaInst) to stack offset
int stack_size = 0; // Total stack frame size for locals and spills
};
RISCv64CodeGen(Module* mod) : module(mod) {} RISCv64CodeGen(Module* mod) : module(mod) {}
// 唯一的公共入口点
std::string code_gen(); std::string code_gen();
std::string module_gen();
std::string function_gen(Function* func);
// 修改 basicBlock_gen 的声明,添加 int block_idx 参数
std::string basicBlock_gen(BasicBlock* bb, const RegAllocResult& alloc, int block_idx);
// DAG related
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
void select_instructions(DAGNode* node, const RegAllocResult& alloc);
// 改变 emit_instructions 的参数,使其可以直接添加汇编指令到 main ss
void emit_instructions(DAGNode* node, std::stringstream& ss, const RegAllocResult& alloc, std::set<DAGNode*>& emitted_nodes);
// Register Allocation related
std::map<Instruction*, std::set<std::string>> liveness_analysis(Function* func);
std::map<std::string, std::set<std::string>> build_interference_graph(
const std::map<Instruction*, std::set<std::string>>& live_sets);
void color_graph(std::map<std::string, PhysicalReg>& vreg_to_preg,
const std::map<std::string, std::set<std::string>>& interference_graph);
RegAllocResult register_allocation(Function* func);
void eliminate_phi(Function* func); // Phi elimination is typically done before DAG building
// Utility
std::string reg_to_string(PhysicalReg reg);
void print_dag(const std::vector<std::unique_ptr<DAGNode>>& dag, const std::string& bb_name);
private: private:
static const std::vector<PhysicalReg> allocable_regs; // 模块级代码生成
std::map<Value*, std::string> value_vreg_map; // Maps IR Value* to its virtual register name std::string module_gen();
// 函数级代码生成 (实现新的流水线)
std::string function_gen(Function* func);
Module* module; Module* module;
int vreg_counter = 0; // Counter for unique virtual register names
int alloca_offset_counter = 0; // Counter for alloca offsets
// 新增一个成员变量来存储当前函数的所有 DAGNode以确保其生命周期贯穿整个函数代码生成
// 这样可以在多个 BasicBlock_gen 调用中访问到完整的 DAG 节点
std::vector<std::unique_ptr<DAGNode>> current_function_dag_nodes;
// 为空标签定义一个伪名称前缀,加上块索引以确保唯一性
const std::string ENTRY_BLOCK_PSEUDO_NAME = "entry_block_";
// !!! 修改get_operand_node 辅助函数现在需要传入 value_to_node 和 nodes_storage 的引用
// 因为它们是 build_dag 局部管理的
DAGNode* get_operand_node(
Value* val_ir,
std::map<Value*, DAGNode*>& value_to_node,
std::vector<std::unique_ptr<DAGNode>>& nodes_storage
);
// !!! 新增create_node 辅助函数也需要传入 value_to_node 和 nodes_storage 的引用
// 并且它应该不再是 lambda而是一个真正的成员函数
DAGNode* create_node(
DAGNode::NodeKind kind,
Value* val,
std::map<Value*, DAGNode*>& value_to_node,
std::vector<std::unique_ptr<DAGNode>>& nodes_storage
);
std::vector<std::unique_ptr<Instruction>> temp_instructions_storage; // 用于存储 build_dag 中创建的临时 BinaryInst
}; };
} // namespace sysy } // namespace sysy

49
src/include/RISCv64ISel.h Normal file
View File

@@ -0,0 +1,49 @@
#ifndef RISCV64_ISEL_H
#define RISCV64_ISEL_H
#include "RISCv64LLIR.h"
namespace sysy {
class RISCv64ISel {
public:
RISCv64ISel();
// 模块主入口将一个高层IR函数转换为底层LLIR函数
std::unique_ptr<MachineFunction> runOnFunction(Function* func);
// 公开接口以便后续模块如RegAlloc可以查询或创建vreg
unsigned getVReg(Value* val);
unsigned getNewVReg() { return vreg_counter++; }
private:
// DAG节点定义作为ISel的内部实现细节
struct DAGNode;
// 指令选择主流程
void select();
// 为单个基本块生成指令
void selectBasicBlock(BasicBlock* bb);
// 核心函数为DAG节点选择并生成MachineInstr
void selectNode(DAGNode* node);
// DAG 构建相关函数 (从原RISCv64Backend迁移)
std::vector<std::unique_ptr<DAGNode>> build_dag(BasicBlock* bb);
DAGNode* get_operand_node(Value* val_ir, std::map<Value*, DAGNode*>&, std::vector<std::unique_ptr<DAGNode>>&);
DAGNode* create_node(int kind, Value* val, std::map<Value*, DAGNode*>&, std::vector<std::unique_ptr<DAGNode>>&);
// 状态
Function* F; // 当前处理的高层IR函数
std::unique_ptr<MachineFunction> MFunc; // 正在构建的底层LLIR函数
MachineBasicBlock* CurMBB; // 当前正在处理的机器基本块
// 映射关系
std::map<Value*, unsigned> vreg_map;
std::map<const BasicBlock*, MachineBasicBlock*> bb_map;
unsigned vreg_counter;
int local_label_counter;
};
} // namespace sysy
#endif // RISCV64_ISEL_H

200
src/include/RISCv64LLIR.h Normal file
View File

@@ -0,0 +1,200 @@
#ifndef RISCV64_LLIR_H
#define RISCV64_LLIR_H
#include "IR.h" // 确保包含了您自己的IR头文件
#include <string>
#include <vector>
#include <memory>
#include <cstdint>
#include <map>
// 前向声明,避免循环引用
namespace sysy {
class Function;
class RISCv64ISel;
}
namespace sysy {
// 物理寄存器定义
enum class PhysicalReg {
ZERO, RA, SP, GP, TP, T0, T1, T2, S0, S1, A0, A1, A2, A3, A4, A5, A6, A7, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, T3, T4, T5, T6,
F0, F1, F2, F3, F4, F5, F6, F7, F8, F9, F10, F11, F12, F13, F14, F15,F16, F17, F18, F19, F20, F21, F22, F23, F24, F25, F26, F27, F28, F29, F30, F31
};
// RISC-V 指令操作码枚举
enum class RVOpcodes {
// 算术指令
ADD, ADDI, ADDW, ADDIW, SUB, SUBW, MUL, MULW, DIV, DIVW, REM, REMW,
// 逻辑指令
XOR, XORI, OR, ORI, AND, ANDI,
// 移位指令
SLL, SLLI, SLLW, SLLIW, SRL, SRLI, SRLW, SRLIW, SRA, SRAI, SRAW, SRAIW,
// 比较指令
SLT, SLTI, SLTU, SLTIU,
// 内存访问指令
LW, LH, LB, LWU, LHU, LBU, SW, SH, SB, LD, SD,
// 控制流指令
J, JAL, JALR, RET,
BEQ, BNE, BLT, BGE, BLTU, BGEU,
// 伪指令
LI, LA, MV, NEG, NEGW, SEQZ, SNEZ,
// 函数调用
CALL,
// 特殊标记,非指令
LABEL,
// 新增伪指令,用于解耦栈帧处理
FRAME_LOAD, // 从栈帧加载 (AllocaInst)
FRAME_STORE, // 保存到栈帧 (AllocaInst)
};
class MachineOperand;
class RegOperand;
class ImmOperand;
class LabelOperand;
class MemOperand;
class MachineInstr;
class MachineBasicBlock;
class MachineFunction;
// 操作数基类
class MachineOperand {
public:
enum OperandKind { KIND_REG, KIND_IMM, KIND_LABEL, KIND_MEM };
MachineOperand(OperandKind kind) : kind(kind) {}
virtual ~MachineOperand() = default;
OperandKind getKind() const { return kind; }
private:
OperandKind kind;
};
// 寄存器操作数
class RegOperand : public MachineOperand {
public:
// 构造虚拟寄存器
RegOperand(unsigned vreg_num)
: MachineOperand(KIND_REG), vreg_num(vreg_num), is_virtual(true) {}
// 构造物理寄存器
RegOperand(PhysicalReg preg)
: MachineOperand(KIND_REG), preg(preg), is_virtual(false) {}
bool isVirtual() const { return is_virtual; }
unsigned getVRegNum() const { return vreg_num; }
PhysicalReg getPReg() const { return preg; }
void setPReg(PhysicalReg new_preg) {
preg = new_preg;
is_virtual = false;
}
private:
unsigned vreg_num = 0;
PhysicalReg preg = PhysicalReg::ZERO;
bool is_virtual;
};
// 立即数操作数
class ImmOperand : public MachineOperand {
public:
ImmOperand(int64_t value) : MachineOperand(KIND_IMM), value(value) {}
int64_t getValue() const { return value; }
private:
int64_t value;
};
// 标签操作数
class LabelOperand : public MachineOperand {
public:
LabelOperand(const std::string& name) : MachineOperand(KIND_LABEL), name(name) {}
const std::string& getName() const { return name; }
private:
std::string name;
};
// 内存操作数, 表示 offset(base_reg)
class MemOperand : public MachineOperand {
public:
MemOperand(std::unique_ptr<RegOperand> base, std::unique_ptr<ImmOperand> offset)
: MachineOperand(KIND_MEM), base(std::move(base)), offset(std::move(offset)) {}
RegOperand* getBase() const { return base.get(); }
ImmOperand* getOffset() const { return offset.get(); }
private:
std::unique_ptr<RegOperand> base;
std::unique_ptr<ImmOperand> offset;
};
// 机器指令
class MachineInstr {
public:
MachineInstr(RVOpcodes opcode) : opcode(opcode) {}
RVOpcodes getOpcode() const { return opcode; }
const std::vector<std::unique_ptr<MachineOperand>>& getOperands() const { return operands; }
std::vector<std::unique_ptr<MachineOperand>>& getOperands() { return operands; }
void addOperand(std::unique_ptr<MachineOperand> operand) {
operands.push_back(std::move(operand));
}
private:
RVOpcodes opcode;
std::vector<std::unique_ptr<MachineOperand>> operands;
};
// 机器基本块
class MachineBasicBlock {
public:
MachineBasicBlock(const std::string& name, MachineFunction* parent)
: name(name), parent(parent) {}
const std::string& getName() const { return name; }
MachineFunction* getParent() const { return parent; }
const std::vector<std::unique_ptr<MachineInstr>>& getInstructions() const { return instructions; }
std::vector<std::unique_ptr<MachineInstr>>& getInstructions() { return instructions; }
void addInstruction(std::unique_ptr<MachineInstr> instr) {
instructions.push_back(std::move(instr));
}
std::vector<MachineBasicBlock*> successors;
std::vector<MachineBasicBlock*> predecessors;
private:
std::string name;
std::vector<std::unique_ptr<MachineInstr>> instructions;
MachineFunction* parent;
};
// 栈帧信息
struct StackFrameInfo {
int locals_size = 0; // 仅为AllocaInst分配的大小
int spill_size = 0; // 仅为溢出分配的大小
int total_size = 0; // 总大小
std::map<unsigned, int> alloca_offsets; // <AllocaInst的vreg, 栈偏移>
std::map<unsigned, int> spill_offsets; // <溢出vreg, 栈偏移>
};
// 机器函数
class MachineFunction {
public:
MachineFunction(Function* func, RISCv64ISel* isel) : F(func), name(func->getName()), isel(isel) {}
Function* getFunc() const { return F; }
RISCv64ISel* getISel() const { return isel; }
const std::string& getName() const { return name; }
StackFrameInfo& getFrameInfo() { return frame_info; }
const std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() const { return blocks; }
std::vector<std::unique_ptr<MachineBasicBlock>>& getBlocks() { return blocks; }
void addBlock(std::unique_ptr<MachineBasicBlock> block) {
blocks.push_back(std::move(block));
}
private:
Function* F;
RISCv64ISel* isel; // 指向创建它的ISel用于获取vreg映射等信息
std::string name;
std::vector<std::unique_ptr<MachineBasicBlock>> blocks;
StackFrameInfo frame_info;
};
} // namespace sysy
#endif // RISCV64_LLIR_H

View File

@@ -0,0 +1,18 @@
// RISCv64Passes.h
#ifndef RISCV64_PASSES_H
#define RISCV64_PASSES_H
#include "RISCv64LLIR.h"
namespace sysy {
// 此处为未来优化Pass的基类或独立类定义
// 例如:
// class PeepholeOptimizer {
// public:
// void runOnMachineFunction(MachineFunction* mfunc);
// };
} // namespace sysy
#endif // RISCV64_PASSES_H

View File

@@ -0,0 +1,56 @@
#ifndef RISCV64_REGALLOC_H
#define RISCV64_REGALLOC_H
#include "RISCv64LLIR.h"
namespace sysy {
class RISCv64RegAlloc {
public:
RISCv64RegAlloc(MachineFunction* mfunc);
// 模块主入口
void run();
private:
using LiveSet = std::set<unsigned>; // 活跃虚拟寄存器集合
using InterferenceGraph = std::map<unsigned, std::set<unsigned>>;
// 栈帧管理
void eliminateFrameIndices();
// 活跃性分析
void analyzeLiveness();
// 构建干扰图
void buildInterferenceGraph();
// 图着色分配寄存器
void colorGraph();
// 重写函数替换vreg并插入溢出代码
void rewriteFunction();
// 辅助函数获取指令的Use/Def集合
void getInstrUseDef(MachineInstr* instr, LiveSet& use, LiveSet& def);
MachineFunction* MFunc;
// 活跃性分析结果
std::map<const MachineInstr*, LiveSet> live_in_map;
std::map<const MachineInstr*, LiveSet> live_out_map;
// 干扰图
InterferenceGraph interference_graph;
// 图着色结果
std::map<unsigned, PhysicalReg> color_map; // vreg -> preg
std::set<unsigned> spilled_vregs; // 被溢出的vreg集合
// 可用的物理寄存器池
std::vector<PhysicalReg> allocable_int_regs;
};
} // namespace sysy
#endif // RISCV64_REGALLOC_H

View File

@@ -16,9 +16,10 @@ using namespace antlr4;
#include "SysYIROptPre.h" #include "SysYIROptPre.h"
#include "RISCv64Backend.h" #include "RISCv64Backend.h"
#include "SysYIRAnalyser.h" #include "SysYIRAnalyser.h"
#include "DeadCodeElimination.h" // #include "DeadCodeElimination.h"
#include "Mem2Reg.h" #include "AddressCalculationExpansion.h"
#include "Reg2Mem.h" // #include "Mem2Reg.h"
// #include "Reg2Mem.h"
using namespace sysy; using namespace sysy;
@@ -124,7 +125,7 @@ int main(int argc, char **argv) {
// 无论最终输出是 IR 还是 ASM只要不是停止在 AST 阶段,都会进入此优化流程。 // 无论最终输出是 IR 还是 ASM只要不是停止在 AST 阶段,都会进入此优化流程。
// optLevel = 0 时,执行默认优化。 // optLevel = 0 时,执行默认优化。
// optLevel >= 1 时,执行默认优化 + 额外的 -O1 优化。 // optLevel >= 1 时,执行默认优化 + 额外的 -O1 优化。
cout << "Applying middle-end optimizations (level -O" << optLevel << ")...\n"; if (DEBUG) cout << "Applying middle-end optimizations (level -O" << optLevel << ")...\n";
// 设置 DEBUG 模式(如果指定了 'ird' // 设置 DEBUG 模式(如果指定了 'ird'
if (argStopAfter == "ird") { if (argStopAfter == "ird") {
@@ -143,23 +144,22 @@ int main(int argc, char **argv) {
cout << "=== After CFA & AVA (Default) ===\n"; cout << "=== After CFA & AVA (Default) ===\n";
SysYPrinter(moduleIR).printIR(); // 临时打印器用于调试 SysYPrinter(moduleIR).printIR(); // 临时打印器用于调试
} }
AddressCalculationExpansion ace(moduleIR, builder);
DeadCodeElimination dce(moduleIR, &cfa, &ava); if (ace.run()) {
dce.runDCEPipeline(); if (DEBUG) cout << "AddressCalculationExpansion made changes.\n";
if (DEBUG) { // 如果 ACE 改变了IR并且 DEBUG 模式开启可以考虑打印IR
cout << "=== After 1st DCE (Default) ===\n"; if (DEBUG) {
SysYPrinter(moduleIR).printIR(); cout << "=== After AddressCalculationExpansion ===\n";
SysYPrinter(moduleIR).printIR();
}
} else {
if (DEBUG) cout << "AddressCalculationExpansion made no changes.\n";
} }
// 根据优化级别,执行额外的优化 pass // 根据优化级别,执行额外的优化 pass
if (optLevel >= 1) { if (optLevel >= 1) {
cout << "Applying additional -O" << optLevel << " optimizations...\n"; if (DEBUG) cout << "Applying additional -O" << optLevel << " optimizations...\n";
// 放置 -O1 及其以上级别要启用的额外优化 pass // 放置 -O1 及其以上级别要启用的额外优化 pass
// 例如: // 例如:
// MyNewOptimizationPass newOpt(moduleIR, builder); // MyNewOptimizationPass newOpt(moduleIR, builder);
@@ -174,28 +174,34 @@ int main(int argc, char **argv) {
// MyCustomOpt2 opt2_pass(moduleIR, builder, &cfa); // 假设需要CFA // MyCustomOpt2 opt2_pass(moduleIR, builder, &cfa); // 假设需要CFA
// opt2_pass.run(); // opt2_pass.run();
// ... 更多 -O1 特有的优化 // ... 更多 -O1 特有的优化
// DeadCodeElimination dce(moduleIR, &cfa, &ava);
Mem2Reg mem2reg(moduleIR, builder, &cfa, &ava); // dce.runDCEPipeline();
mem2reg.mem2regPipeline(); // if (DEBUG) {
if (DEBUG) { // cout << "=== After 1st DCE (Default) ===\n";
cout << "=== After Mem2Reg (Default) ===\n"; // SysYPrinter(moduleIR).printIR();
SysYPrinter(moduleIR).printIR(); // }
}
Reg2Mem reg2mem(moduleIR, builder); // Mem2Reg mem2reg(moduleIR, builder, &cfa, &ava);
reg2mem.DeletePhiInst(); // mem2reg.mem2regPipeline();
if (DEBUG) { // if (DEBUG) {
cout << "=== After Reg2Mem (Default) ===\n"; // cout << "=== After Mem2Reg (Default) ===\n";
SysYPrinter(moduleIR).printIR(); // SysYPrinter(moduleIR).printIR();
} // }
dce.runDCEPipeline(); // 第二次 DCE (默认) // Reg2Mem reg2mem(moduleIR, builder);
if (DEBUG) { // reg2mem.DeletePhiInst();
cout << "=== After 2nd DCE (Default) ===\n"; // if (DEBUG) {
SysYPrinter(moduleIR).printIR(); // cout << "=== After Reg2Mem (Default) ===\n";
} // SysYPrinter(moduleIR).printIR();
// }
// dce.runDCEPipeline(); // 第二次 DCE (默认)
// if (DEBUG) {
// cout << "=== After 2nd DCE (Default) ===\n";
// SysYPrinter(moduleIR).printIR();
// }
} else { } else {
cout << "No additional middle-end optimizations applied for -O" << optLevel << ".\n"; if (DEBUG) cout << "No additional middle-end optimizations applied for -O" << optLevel << ".\n";
} }
// 5. 根据 argStopAfter 决定后续操作 // 5. 根据 argStopAfter 决定后续操作
@@ -212,7 +218,7 @@ int main(int argc, char **argv) {
// 设置 DEBUG 模式(如果指定了 'asmd' // 设置 DEBUG 模式(如果指定了 'asmd'
if (argStopAfter == "asmd") { if (argStopAfter == "asmd") {
DEBUG = 1; DEBUG = 1;
DEEPDEBUG = 1; // DEEPDEBUG = 1;
} }
sysy::RISCv64CodeGen codegen(moduleIR); // 传入优化后的 moduleIR sysy::RISCv64CodeGen codegen(moduleIR); // 传入优化后的 moduleIR
string asmCode = codegen.code_gen(); string asmCode = codegen.code_gen();

View File

@@ -20,7 +20,12 @@ TESTDATA_DIR="${SCRIPT_DIR}/testdata"
# 定义编译器 (这里假设 gcc 在 VM 内部是可用的) # 定义编译器 (这里假设 gcc 在 VM 内部是可用的)
GCC_NATIVE="gcc" # VM 内部的 gcc GCC_NATIVE="gcc" # VM 内部的 gcc
# 不再需要 QEMU_RISCV64因为直接执行
# --- 新增功能: 初始化变量 ---
TIMEOUT_SECONDS=5 # 默认运行时超时时间为 5 秒
COMPILE_TIMEOUT_SECONDS=10 # 默认编译超时时间为 10 秒
TOTAL_CASES=0
PASSED_CASES=0
# 显示帮助信息的函数 # 显示帮助信息的函数
show_help() { show_help() {
@@ -29,31 +34,32 @@ show_help() {
echo "假设当前运行环境已经是 RISC-V 64 位架构,可以直接执行编译后的程序。" echo "假设当前运行环境已经是 RISC-V 64 位架构,可以直接执行编译后的程序。"
echo "" echo ""
echo "选项:" echo "选项:"
echo " -c, --clean 清理 'tmp' 目录下的所有生成文件。" echo " -c, --clean 清理 'tmp' 目录下的所有生成文件。"
echo " -h, --help 显示此帮助信息并退出。" echo " -t, --timeout N 设置每个测试用例的运行时超时为 N 秒 (默认: 5)。"
echo " -ct, --compile-timeout M 设置 gcc 编译的超时时间为 M 秒 (默认: 10)。"
echo " -h, --help 显示此帮助信息并退出。"
echo "" echo ""
echo "执行步骤:" echo "执行步骤:"
echo "1. 遍历 'tmp/' 目录下的所有 .s 汇编文件。" echo "1. 遍历 'tmp/' 目录下的所有 .s 汇编文件。"
echo "2. 使用 VM 内部的 gcc 将 .s 文件汇编并链接为可执行文件 (链接 -L./lib -lsysy_riscv -static)。" echo "2. 在指定的超时时间内使用 VM 内部的 gcc 将 .s 文件汇编并链接为可执行文件。"
echo "3. 直接运行编译后的可执行文件。" echo "3. 在指定的超时时间内运行编译后的可执行文件。"
echo "4. 根据对应的 testdata/*.out 文件内容(最后一行是否为整数)决定是进行返回值比较、标准输出比较,或两者都进行。" echo "4. 根据对应的 .out 文件内容进行返回值和/或标准输出比较。"
echo "5. 如果没有对应的 .in/.out 文件,则打印可执行文件的返回值。" echo "5. 输出比较时会忽略行尾多余的换行符。"
echo "6. 输出比较时会忽略行尾多余的换行符。" echo "6. 所有测试结束后,报告总通过率。"
} }
# 清理临时文件的函数 # 清理临时文件的函数
clean_tmp() { clean_tmp() {
echo "正在清理临时目录: ${TMP_DIR}" echo "正在清理临时目录: ${TMP_DIR}"
# 清理所有由本脚本和 runit.sh 生成的文件
rm -rf "${TMP_DIR}"/*.s \ rm -rf "${TMP_DIR}"/*.s \
"${TMP_DIR}"/*_sysyc_riscv64 \ "${TMP_DIR}"/*_sysyc_riscv64 \
"${TMP_DIR}"/*_sysyc_riscv64.actual_out \ "${TMP_DIR}"/*_sysyc_riscv64.actual_out \
"${TMP_DIR}"/*_sysyc_riscv64.expected_stdout \ "${TMP_DIR}"/*_sysyc_riscv64.expected_stdout \
"${TMP_DIR}"/*_sysyc_riscv64.o # 以防生成了 .o 文件 "${TMP_DIR}"/*_sysyc_riscv64.o
echo "清理完成。" echo "清理完成。"
} }
# 如果临时目录不存在,则创建它 (尽管 runit.sh 应该已经创建了) # 如果临时目录不存在,则创建它
mkdir -p "${TMP_DIR}" mkdir -p "${TMP_DIR}"
# 解析命令行参数 # 解析命令行参数
@@ -63,6 +69,24 @@ while [[ "$#" -gt 0 ]]; do
clean_tmp clean_tmp
exit 0 exit 0
;; ;;
-t|--timeout)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then
TIMEOUT_SECONDS="$2"
shift # 移过参数值
else
echo "错误: --timeout 需要一个正整数参数。" >&2
exit 1
fi
;;
-ct|--compile-timeout)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then
COMPILE_TIMEOUT_SECONDS="$2"
shift # 移过参数值
else
echo "错误: --compile-timeout 需要一个正整数参数。" >&2
exit 1
fi
;;
-h|--help) -h|--help)
show_help show_help
exit 0 exit 0
@@ -73,30 +97,33 @@ while [[ "$#" -gt 0 ]]; do
exit 1 exit 1
;; ;;
esac esac
shift # 移过参数名
done done
echo "SysY VM 内部测试运行器启动..." echo "SysY VM 内部测试运行器启动..."
echo "编译超时设置为: ${COMPILE_TIMEOUT_SECONDS}"
echo "运行时超时设置为: ${TIMEOUT_SECONDS}"
echo "汇编文件目录: ${TMP_DIR}" echo "汇编文件目录: ${TMP_DIR}"
echo "库文件目录: ${LIB_DIR}" echo "库文件目录: ${LIB_DIR}"
echo "测试数据目录: ${TESTDATA_DIR}" echo "测试数据目录: ${TESTDATA_DIR}"
echo "" echo ""
# 查找 tmp 目录下的所有 .s 汇编文件 # 查找 tmp 目录下的所有 .s 汇编文件
s_files=$(find "${TMP_DIR}" -maxdepth 1 -name "*.s")
TOTAL_CASES=$(echo "$s_files" | wc -w)
# 遍历找到的每个 .s 文件 # 遍历找到的每个 .s 文件
find "${TMP_DIR}" -maxdepth 1 -name "*.s" | while read s_file; do echo "$s_files" | while read s_file; do
# --- 新增功能: 初始化用例通过状态 ---
is_passed=1 # 1 表示通过, 0 表示失败
# 从 .s 文件名中提取原始的测试用例名称部分 # 从 .s 文件名中提取原始的测试用例名称部分
# 例如:从 functional_21_if_test2_sysyc_riscv64.s 提取 functional_21_if_test2
base_name_from_s_file=$(basename "$s_file" .s) base_name_from_s_file=$(basename "$s_file" .s)
# 这一步得到的是 'functional_21_if_test2' 或 'performance_2024-2D0-22'
original_test_name_underscored=$(echo "$base_name_from_s_file" | sed 's/_sysyc_riscv64$//') original_test_name_underscored=$(echo "$base_name_from_s_file" | sed 's/_sysyc_riscv64$//')
# 将 `original_test_name_underscored` 分割成类别和文件名 # 将 `original_test_name_underscored` 分割成类别和文件名
# 例如:'functional_21_if_test2' 分割为 'functional' 和 '21_if_test2'
category=$(echo "$original_test_name_underscored" | cut -d'_' -f1) category=$(echo "$original_test_name_underscored" | cut -d'_' -f1)
# cut -d'_' -f2- 会从第二个下划线开始获取所有剩余部分
test_file_base=$(echo "$original_test_name_underscored" | cut -d'_' -f2-) test_file_base=$(echo "$original_test_name_underscored" | cut -d'_' -f2-)
# 构建原始的相对路径,例如:'functional/21_if_test2'
original_relative_path="${category}/${test_file_base}" original_relative_path="${category}/${test_file_base}"
# 定义可执行文件、输入文件、参考输出文件和实际输出文件的路径 # 定义可执行文件、输入文件、参考输出文件和实际输出文件的路径
@@ -109,109 +136,112 @@ find "${TMP_DIR}" -maxdepth 1 -name "*.s" | while read s_file; do
echo " 对应的测试用例路径: ${original_relative_path}" echo " 对应的测试用例路径: ${original_relative_path}"
# 步骤 1: 使用 VM 内部的 gcc 编译 .s 到可执行文件 # 步骤 1: 使用 VM 内部的 gcc 编译 .s 到可执行文件
# 注意:这里假设 gcc 在 VM 环境中可用,且 ./lib 是相对于当前脚本运行目录 echo " 使用 gcc 汇编并链接 (超时 ${COMPILE_TIMEOUT_SECONDS}s)..."
echo " 使用 gcc 汇编并链接: ${GCC_NATIVE} \"${s_file}\" -o \"${executable_file}\" -L\"${LIB_DIR}\" -lsysy_riscv -static -g" # --- 修改点: 为 gcc 增加 timeout ---
"${GCC_NATIVE}" "${s_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static -g timeout ${COMPILE_TIMEOUT_SECONDS} "${GCC_NATIVE}" "${s_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static -g
if [ $? -ne 0 ]; then GCC_STATUS=$?
echo -e "\e[31m错误: GCC 汇编/链接 ${s_file} 失败\e[0m" if [ $GCC_STATUS -eq 124 ]; then
continue echo -e "\e[31m错误: GCC 编译/链接 ${s_file} 超时 (超过 ${COMPILE_TIMEOUT_SECONDS} 秒)\e[0m"
fi is_passed=0
echo " 生成的可执行文件: ${executable_file}" elif [ $GCC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: GCC 汇编/链接 ${s_file} 失败,退出码: ${GCC_STATUS}\e[0m"
is_passed=0
else
echo " 生成的可执行文件: ${executable_file}"
echo " 正在执行 (超时 ${TIMEOUT_SECONDS}s): \"${executable_file}\""
# 步骤 2: 执行编译后的文件并比较/报告结果 # 步骤 2: 执行编译后的文件并比较/报告结果
# 直接执行可执行文件,不再通过 qemu-riscv64 if [ -f "${output_reference_file}" ]; then
echo " 正在执行: \"${executable_file}\"" # 修改点:移除多余的 ./ LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
# 检查是否存在 .out 文件
if [ -f "${output_reference_file}" ]; then
# 尝试从 .out 文件中提取期望的返回码和期望的标准输出
# 获取 .out 文件的最后一行,去除空白字符
LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
# 检查最后一行是否为纯整数 (允许正负号)
if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
# 假设最后一行是期望的返回码
EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
# 创建一个只包含期望标准输出的临时文件 (所有行除了最后一行) if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
EXPECTED_STDOUT_FILE="${TMP_DIR}/${base_name_from_s_file}.expected_stdout" EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
# 使用 head -n -1 来获取除了最后一行之外的所有行。如果文件只有一行,则生成一个空文件。 EXPECTED_STDOUT_FILE="${TMP_DIR}/${base_name_from_s_file}.expected_stdout"
head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}" head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
echo " 检测到 .out 文件同时包含标准输出和期望的返回码。"
echo " 期望返回码: ${EXPECTED_RETURN_CODE}"
echo " 检测到 .out 文件同时包含标准输出和期望的返回码。" if [ -f "${input_file}" ]; then
echo " 期望返回码: ${EXPECTED_RETURN_CODE}" timeout ${TIMEOUT_SECONDS} "${executable_file}" < "${input_file}" > "${output_actual_file}"
if [ -s "${EXPECTED_STDOUT_FILE}" ]; then # -s 检查文件是否非空 else
echo " 期望标准输出文件: ${EXPECTED_STDOUT_FILE}" timeout ${TIMEOUT_SECONDS} "${executable_file}" > "${output_actual_file}"
else fi
echo " 期望标准输出为空。" ACTUAL_RETURN_CODE=$?
fi
# 执行程序,捕获实际返回码和实际标准输出 if [ "$ACTUAL_RETURN_CODE" -eq 124 ]; then
if [ -f "${input_file}" ]; then echo -e "\e[31m 执行超时: ${original_relative_path}.sy 运行超过 ${TIMEOUT_SECONDS} 秒\e[0m"
echo " 使用输入文件: ${input_file}" is_passed=0
"${executable_file}" < "${input_file}" > "${output_actual_file}" # 修改点:移除多余的 ./ else
else if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then
"${executable_file}" > "${output_actual_file}" # 修改点:移除多余的 ./ echo -e "\e[32m 返回码测试成功: (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m"
fi else
ACTUAL_RETURN_CODE=$? # 捕获执行状态 echo -e "\e[31m 返回码测试失败: 期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m"
is_passed=0
fi
# 比较实际返回码与期望返回码 if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then echo -e "\e[32m 标准输出测试成功\e[0m"
echo -e "\e[32m 返回码测试成功: ${original_relative_path}.sy 的返回码 (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m" else
echo -e "\e[31m 标准输出测试失败\e[0m"
echo " 差异:"
diff "${output_actual_file}" "${EXPECTED_STDOUT_FILE}"
is_passed=0
fi
fi
else else
echo -e "\e[31m 返回码测试失败: ${original_relative_path}.sy 的返回码不匹配。期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m" echo " 检测到 .out 文件为纯标准输出参考。"
fi if [ -f "${input_file}" ]; then
timeout ${TIMEOUT_SECONDS} "${executable_file}" < "${input_file}" > "${output_actual_file}"
else
timeout ${TIMEOUT_SECONDS} "${executable_file}" > "${output_actual_file}"
fi
EXEC_STATUS=$?
# 比较实际标准输出与期望标准输出,忽略文件末尾的换行符差异 if [ $EXEC_STATUS -eq 124 ]; then
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then echo -e "\e[31m 执行超时: ${original_relative_path}.sy 运行超过 ${TIMEOUT_SECONDS} 秒\e[0m"
echo -e "\e[32m 标准输出测试成功: 输出与 ${original_relative_path}.sy 的参考输出匹配 (忽略行尾换行符差异)\e[0m" is_passed=0
else else
echo -e "\e[31m 标准输出测试失败: ${original_relative_path}.sy 的输出不匹配\e[0m" if [ $EXEC_STATUS -ne 0 ]; then
echo " 差异 (可能包含行尾换行符差异):" echo -e "\e[33m警告: 程序以非零状态 ${EXEC_STATUS} 退出 (纯输出比较模式)。\e[0m"
diff "${output_actual_file}" "${EXPECTED_STDOUT_FILE}" # 显示原始差异以便调试 fi
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
else
echo -e "\e[31m 失败: 输出不匹配\e[0m"
echo " 差异:"
diff "${output_actual_file}" "${output_reference_file}"
is_passed=0
fi
fi
fi fi
else else
# 最后一行不是纯整数,将整个 .out 文件视为纯标准输出 echo " 未找到 .out 文件。正在运行并报告返回码。"
echo " 检测到 .out 文件为纯标准输出参考。正在与输出文件比较: ${output_reference_file}" timeout ${TIMEOUT_SECONDS} "${executable_file}"
EXEC_STATUS=$?
# 执行程序,并将输出重定向到临时文件 if [ $EXEC_STATUS -eq 124 ]; then
if [ -f "${input_file}" ]; then echo -e "\e[31m 执行超时: ${original_relative_path}.sy 运行超过 ${TIMEOUT_SECONDS} 秒\e[0m"
echo " 使用输入文件: ${input_file}" is_passed=0
"${executable_file}" < "${input_file}" > "${output_actual_file}" # 修改点:移除多余的 ./
else else
"${executable_file}" > "${output_actual_file}" # 修改点:移除多余的 ./ echo " ${original_relative_path}.sy 的返回码: ${EXEC_STATUS}"
fi
EXEC_STATUS=$? # 捕获执行状态
if [ $EXEC_STATUS -ne 0 ]; then
echo -e "\e[33m警告: 可执行文件 ${original_relative_path}.sy 以非零状态 ${EXEC_STATUS} 退出 (纯输出比较模式)。请检查程序逻辑或其是否应返回此状态。\e[0m"
fi
# 比较实际输出与参考输出,忽略文件末尾的换行符差异
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与 ${original_relative_path}.sy 的参考输出匹配 (忽略行尾换行符差异)\e[0m"
else
echo -e "\e[31m 失败: ${original_relative_path}.sy 的输出不匹配\e[0m"
echo " 差异 (可能包含行尾换行符差异):"
diff "${output_actual_file}" "${output_reference_file}" # 显示原始差异以便调试
fi fi
fi fi
elif [ -f "${input_file}" ]; then
# 只有 .in 文件存在,使用输入运行并报告退出码(无参考输出)
echo " 使用输入文件: ${input_file}"
echo " 没有 .out 文件进行比较。正在运行并报告返回码。"
"${executable_file}" < "${input_file}" # 修改点:移除多余的 ./
EXEC_STATUS=$?
echo " ${original_relative_path}.sy 的返回码: ${EXEC_STATUS}"
else
# .in 和 .out 文件都不存在,只运行并报告退出码
echo " 未找到 .in 或 .out 文件。正在运行并报告返回码。"
"${executable_file}" # 修改点:移除多余的 ./
EXEC_STATUS=$?
echo " ${original_relative_path}.sy 的返回码: ${EXEC_STATUS}"
fi fi
echo "" # 为测试用例之间添加一个空行,以提高可读性
# --- 新增功能: 更新通过用例计数 ---
if [ "$is_passed" -eq 1 ]; then
((PASSED_CASES++))
fi
echo "" # 为测试用例之间添加一个空行
done done
echo "脚本完成。" # --- 新增功能: 打印最终总结 ---
echo "========================================"
echo "测试完成"
echo "测试通过率: [${PASSED_CASES}/${TOTAL_CASES}]"
echo "========================================"
if [ "$PASSED_CASES" -eq "$TOTAL_CASES" ]; then
exit 0
else
exit 1
fi

View File

@@ -16,38 +16,33 @@ SYSYC="${BUILD_BIN_DIR}/sysyc"
GCC_RISCV64="riscv64-linux-gnu-gcc" GCC_RISCV64="riscv64-linux-gnu-gcc"
QEMU_RISCV64="qemu-riscv64" QEMU_RISCV64="qemu-riscv64"
# 标志,用于确定是否应该生成和运行可执行文件 # --- 新增功能: 初始化变量 ---
EXECUTE_MODE=false EXECUTE_MODE=false
SYSYC_TIMEOUT=10 # sysyc 编译超时 (秒)
GCC_TIMEOUT=10 # gcc 编译超时 (秒)
EXEC_TIMEOUT=5 # qemu 执行超时 (秒)
TOTAL_CASES=0
PASSED_CASES=0
FAILED_CASES_LIST="" # 用于存储未通过的测例列表
# 显示帮助信息的函数 # 显示帮助信息的函数
show_help() { show_help() {
echo "用法: $0 [选项]" echo "用法: $0 [选项]"
echo "此脚本用于编译 .sy 文件,并可选择性地运行它们进行测试。" echo "此脚本用于按文件名前缀数字升序编译和测试 .sy 文件。"
echo "" echo ""
echo "选项:" echo "选项:"
echo " -e, --executable 编译为可执行文件运行可执行文件,并比较输出(如果存在 .in/.out 文件)。" echo " -e, --executable 编译为可执行文件运行测试。"
echo " 如果 .out 文件的最后一行是整数,则将其视为期望的返回值进行比较,其余内容视为期望的标准输出。" echo " -c, --clean 清理 'tmp' 目录下的所有生成文件。"
echo " 如果 .out 文件的最后一行不是整数,则将整个 .out 文件视为期望的标准输出进行比较。" echo " -sct N 设置 sysyc 编译超时为 N 秒 (默认: 10)。"
echo " 输出比较时会忽略行尾多余的换行符。" echo " -gct N 设置 gcc 交叉编译超时为 N 秒 (默认: 10)。"
echo " 如果不存在 .in/.out 文件,则打印返回码。" echo " -et N 设置 qemu 执行超时为 N 秒 (默认: 5)。"
echo " -c, --clean 清理 'tmp' 目录下的所有生成文件。" echo " -h, --help 显示此帮助信息并退出。"
echo " -h, --help 显示此帮助信息并退出。"
echo ""
echo "编译步骤:"
echo "1. 调用 sysyc 将 .sy 编译为 .s (RISC-V 汇编)。"
echo "2. 调用 riscv64-linux-gnu-gcc 将 .s 编译为可执行文件,并链接 -L../lib/ -lsysy_riscv -static。"
echo "3. 调用 qemu-riscv64 执行编译后的文件。"
echo "4. 根据 .out 文件内容(最后一行是否为整数)决定是进行返回值比较、标准输出比较,或两者都进行。"
echo "5. 如果没有 .in/.out 文件,则打印可执行文件的返回值。"
} }
# 清理临时文件的函数 # 清理临时文件的函数
clean_tmp() { clean_tmp() {
echo "正在清理临时目录: ${TMP_DIR}" echo "正在清理临时目录: ${TMP_DIR}"
rm -rf "${TMP_DIR}"/* rm -rf "${TMP_DIR}"/*
# 如果需要,也可以根据 clean.sh 示例清理其他特定文件
# rm -rf "${SCRIPT_DIR}"/*.s "${SCRIPT_DIR}"/*.ll "${SCRIPT_DIR}"/*clang "${SCRIPT_DIR}"/*sysyc
# rm -rf "${SCRIPT_DIR}"/*_riscv64
} }
# 如果临时目录不存在,则创建它 # 如果临时目录不存在,则创建它
@@ -58,12 +53,20 @@ while [[ "$#" -gt 0 ]]; do
case "$1" in case "$1" in
-e|--executable) -e|--executable)
EXECUTE_MODE=true EXECUTE_MODE=true
shift
;; ;;
-c|--clean) -c|--clean)
clean_tmp clean_tmp
exit 0 exit 0
;; ;;
-sct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then SYSYC_TIMEOUT="$2"; shift; else echo "错误: -sct 需要一个正整数参数。" >&2; exit 1; fi
;;
-gct)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then GCC_TIMEOUT="$2"; shift; else echo "错误: -gct 需要一个正整数参数。" >&2; exit 1; fi
;;
-et)
if [[ -n "$2" && "$2" =~ ^[0-9]+$ ]]; then EXEC_TIMEOUT="$2"; shift; else echo "错误: -et 需要一个正整数参数。" >&2; exit 1; fi
;;
-h|--help) -h|--help)
show_help show_help
exit 0 exit 0
@@ -74,150 +77,175 @@ while [[ "$#" -gt 0 ]]; do
exit 1 exit 1
;; ;;
esac esac
shift
done done
echo "SysY 测试运行器启动..." echo "SysY 测试运行器启动..."
echo "输入目录: ${TESTDATA_DIR}" echo "输入目录: ${TESTDATA_DIR}"
echo "临时目录: ${TMP_DIR}" echo "临时目录: ${TMP_DIR}"
echo "执行模式已启用: ${EXECUTE_MODE}" echo "执行模式: ${EXECUTE_MODE}"
if ${EXECUTE_MODE}; then
echo "超时设置: sysyc=${SYSYC_TIMEOUT}s, gcc=${GCC_TIMEOUT}s, qemu=${EXEC_TIMEOUT}s"
fi
echo "" echo ""
# 查找 testdata 目录及其子目录中的所有 .sy 文件 # --- 修改点: 查找所有 .sy 文件并按文件名前缀数字排序 ---
# 遍历找到的每个 .sy 文件 sy_files=$(find "${TESTDATA_DIR}" -name "*.sy" | sort -V)
find "${TESTDATA_DIR}" -name "*.sy" | while read sy_file; do TOTAL_CASES=$(echo "$sy_files" | wc -w)
# 获取 .sy 文件的基本名称例如21_if_test2
# 这也处理了文件位于子目录中的情况例如functional/21_if_test2.sy # --- 本次修复: 使用 here-string (<<<) 代替管道 (|) 来避免子 shell 问题 ---
# 这样可以确保循环内的 PASSED_CASES 变量修改在循环结束后依然有效
while IFS= read -r sy_file; do
is_passed=1 # 1 表示通过, 0 表示失败
relative_path_no_ext=$(realpath --relative-to="${TESTDATA_DIR}" "${sy_file%.*}") relative_path_no_ext=$(realpath --relative-to="${TESTDATA_DIR}" "${sy_file%.*}")
# 将斜杠替换为下划线,用于输出文件名,以避免冲突并保持结构
output_base_name=$(echo "${relative_path_no_ext}" | tr '/' '_') output_base_name=$(echo "${relative_path_no_ext}" | tr '/' '_')
# 定义汇编文件、可执行文件、输入文件和输出文件的路径
assembly_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.s" assembly_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.s"
executable_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64" executable_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64"
input_file="${sy_file%.*}.in" input_file="${sy_file%.*}.in"
output_reference_file="${sy_file%.*}.out" output_reference_file="${sy_file%.*}.out"
output_actual_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.actual_out" output_actual_file="${TMP_DIR}/${output_base_name}_sysyc_riscv64.actual_out"
echo "正在处理: $(basename "$sy_file")" echo "正在处理: $(basename "$sy_file") (路径: ${relative_path_no_ext}.sy)"
echo " SY 文件: ${sy_file}"
# 步骤 1: 使用 sysyc 编译 .sy 到 .s # 步骤 1: 使用 sysyc 编译 .sy 到 .s
echo " 使用 sysyc 编译: ${SYSYC} -s asm \"${sy_file}\" > \"${assembly_file}\"" echo " 使用 sysyc 编译 (超时 ${SYSYC_TIMEOUT}s)..."
"${SYSYC}" -s asm "${sy_file}" > "${assembly_file}" timeout ${SYSYC_TIMEOUT} "${SYSYC}" -S "${sy_file}" -o "${assembly_file}"
if [ $? -ne 0 ]; then SYSYC_STATUS=$?
echo -e "\e[31m错误: SysY 编译 ${sy_file} 失败\e[0m" if [ $SYSYC_STATUS -eq 124 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} 超时\e[0m"
is_passed=0
elif [ $SYSYC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: SysY 编译 ${sy_file} 失败,退出码: ${SYSYC_STATUS}\e[0m"
is_passed=0
fi
# 只有当 EXECUTE_MODE 为 true 且上一步成功时才继续
if ${EXECUTE_MODE} && [ "$is_passed" -eq 1 ]; then
# 步骤 2: 使用 riscv64-linux-gnu-gcc 编译 .s 到可执行文件
echo " 使用 gcc 编译 (超时 ${GCC_TIMEOUT}s)..."
timeout ${GCC_TIMEOUT} "${GCC_RISCV64}" "${assembly_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static
GCC_STATUS=$?
if [ $GCC_STATUS -eq 124 ]; then
echo -e "\e[31m错误: GCC 编译 ${assembly_file} 超时\e[0m"
is_passed=0
elif [ $GCC_STATUS -ne 0 ]; then
echo -e "\e[31m错误: GCC 编译 ${assembly_file} 失败,退出码: ${GCC_STATUS}\e[0m"
is_passed=0
fi
elif ! ${EXECUTE_MODE}; then
echo " 跳过执行模式。仅生成汇编文件。"
# 如果只编译不执行,只要编译成功就算通过
if [ "$is_passed" -eq 1 ]; then
((PASSED_CASES++))
else
# --- 本次修改点 ---
FAILED_CASES_LIST+="${relative_path_no_ext}.sy\n"
fi
echo ""
continue continue
fi fi
echo " 生成的汇编文件: ${assembly_file}"
# 只有当 EXECUTE_MODE 为 true 时才继续生成和执行可执行文件 # 步骤 3, 4, 5: 只有当编译都成功时才执行
if ${EXECUTE_MODE}; then if [ "$is_passed" -eq 1 ]; then
# 步骤 2: 使用 riscv64-linux-gnu-gcc 编译 .s 到可执行文件 echo " 正在执行 (超时 ${EXEC_TIMEOUT}s)..."
echo " 使用 gcc 编译: ${GCC_RISCV64} \"${assembly_file}\" -o \"${executable_file}\" -L\"${LIB_DIR}\" -lsysy_riscv -static"
"${GCC_RISCV64}" "${assembly_file}" -o "${executable_file}" -L"${LIB_DIR}" -lsysy_riscv -static # 准备执行命令
if [ $? -ne 0 ]; then exec_cmd="${QEMU_RISCV64} \"${executable_file}\""
echo -e "\e[31m错误: GCC 编译 ${assembly_file} 失败\e[0m" if [ -f "${input_file}" ]; then
continue exec_cmd+=" < \"${input_file}\""
fi fi
echo " 生成的可执行文件: ${executable_file}" exec_cmd+=" > \"${output_actual_file}\""
# 步骤 3, 4, 5: 执行编译后的文件并比较/报告结果 # 执行并捕获返回码
echo " 正在执行: ${QEMU_RISCV664} \"${executable_file}\"" eval "timeout ${EXEC_TIMEOUT} ${exec_cmd}"
ACTUAL_RETURN_CODE=$?
# 检查是否存在 .out 文件 if [ "$ACTUAL_RETURN_CODE" -eq 124 ]; then
if [ -f "${output_reference_file}" ]; then echo -e "\e[31m 执行超时: ${sy_file} 运行超过 ${EXEC_TIMEOUT} 秒\e[0m"
# 尝试从 .out 文件中提取期望的返回码和期望的标准输出 is_passed=0
# 获取 .out 文件的最后一行,去除空白字符
LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
# 检查最后一行是否为纯整数 (允许正负号)
if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
# 假设最后一行是期望的返回码
EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
# 创建一个只包含期望标准输出的临时文件 (所有行除了最后一行)
EXPECTED_STDOUT_FILE="${TMP_DIR}/${output_base_name}_sysyc_riscv64.expected_stdout"
# 使用 head -n -1 来获取除了最后一行之外的所有行。如果文件只有一行,则生成一个空文件。
head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
echo " 检测到 .out 文件同时包含标准输出和期望的返回码。"
echo " 期望返回码: ${EXPECTED_RETURN_CODE}"
if [ -s "${EXPECTED_STDOUT_FILE}" ]; then # -s 检查文件是否非空
echo " 期望标准输出文件: ${EXPECTED_STDOUT_FILE}"
else
echo " 期望标准输出为空。"
fi
# 执行程序,捕获实际返回码和实际标准输出
if [ -f "${input_file}" ]; then
echo " 使用输入文件: ${input_file}"
"${QEMU_RISCV64}" "${executable_file}" < "${input_file}" > "${output_actual_file}"
else
"${QEMU_RISCV64}" "${executable_file}" > "${output_actual_file}"
fi
ACTUAL_RETURN_CODE=$? # 捕获执行状态
# 比较实际返回码与期望返回码
if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then
echo -e "\e[32m 返回码测试成功: ${sy_file} 的返回码 (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m"
else
echo -e "\e[31m 返回码测试失败: ${sy_file} 的返回码不匹配。期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m"
fi
# 比较实际标准输出与期望标准输出,忽略文件末尾的换行符差异
# 使用 sed 命令去除文件末尾的所有换行符,再通过 diff 进行比较
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
echo -e "\e[32m 标准输出测试成功: 输出与 ${sy_file} 的参考输出匹配 (忽略行尾换行符差异)\e[0m"
else
echo -e "\e[31m 标准输出测试失败: ${sy_file} 的输出不匹配\e[0m"
echo " 差异 (可能包含行尾换行符差异):"
diff "${output_actual_file}" "${EXPECTED_STDOUT_FILE}" # 显示原始差异以便调试
fi
else
# 最后一行不是纯整数,将整个 .out 文件视为纯标准输出
echo " 检测到 .out 文件为纯标准输出参考。正在与输出文件比较: ${output_reference_file}"
# 使用输入文件(如果存在)运行可执行文件,并将输出重定向到临时文件
if [ -f "${input_file}" ]; then
echo " 使用输入文件: ${input_file}"
"${QEMU_RISCV64}" "${executable_file}" < "${input_file}" > "${output_actual_file}"
else
"${QEMU_RISCV64}" "${executable_file}" > "${output_actual_file}"
fi
EXEC_STATUS=$? # 捕获执行状态
if [ $EXEC_STATUS -ne 0 ]; then
echo -e "\e[33m警告: 可执行文件 ${sy_file} 以非零状态 ${EXEC_STATUS} 退出 (纯输出比较模式)。请检查程序逻辑或其是否应返回此状态。\e[0m"
fi
# 比较实际输出与参考输出,忽略文件末尾的换行符差异
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与 ${sy_file} 的参考输出匹配 (忽略行尾换行符差异)\e[0m"
else
echo -e "\e[31m 失败: ${sy_file} 的输出不匹配\e[0m"
echo " 差异 (可能包含行尾换行符差异):"
diff "${output_actual_file}" "${output_reference_file}" # 显示原始差异以便调试
fi
fi
elif [ -f "${input_file}" ]; then
# 只有 .in 文件存在,使用输入运行并报告退出码(无参考输出)
echo " 使用输入文件: ${input_file}"
echo " 没有 .out 文件进行比较。正在运行并报告返回码。"
"${QEMU_RISCV64}" "${executable_file}" < "${input_file}"
EXEC_STATUS=$?
echo " ${sy_file} 的返回码: ${EXEC_STATUS}"
else else
# .in 和 .out 文件都不存在,只运行并报告退出码 # 检查是否存在 .out 文件以进行比较
echo " 未找到 .in 或 .out 文件。正在运行并报告返回码。" if [ -f "${output_reference_file}" ]; then
"${QEMU_RISCV64}" "${executable_file}" LAST_LINE_TRIMMED=$(tail -n 1 "${output_reference_file}" | tr -d '[:space:]')
EXEC_STATUS=$?
echo " ${sy_file} 的返回码: ${EXEC_STATUS}" if [[ "$LAST_LINE_TRIMMED" =~ ^[-+]?[0-9]+$ ]]; then
fi EXPECTED_RETURN_CODE="$LAST_LINE_TRIMMED"
else EXPECTED_STDOUT_FILE="${TMP_DIR}/${output_base_name}_sysyc_riscv64.expected_stdout"
echo " 跳过执行模式。仅生成汇编文件。" head -n -1 "${output_reference_file}" > "${EXPECTED_STDOUT_FILE}"
fi
echo "" # 为测试用例之间添加一个空行,以提高可读性
done
echo "脚本完成。" # 比较返回码
if [ "$ACTUAL_RETURN_CODE" -eq "$EXPECTED_RETURN_CODE" ]; then
echo -e "\e[32m 返回码测试成功: (${ACTUAL_RETURN_CODE}) 与期望值 (${EXPECTED_RETURN_CODE}) 匹配\e[0m"
else
echo -e "\e[31m 返回码测试失败: 期望: ${EXPECTED_RETURN_CODE}, 实际: ${ACTUAL_RETURN_CODE}\e[0m"
is_passed=0
fi
# 比较标准输出
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${EXPECTED_STDOUT_FILE}") >/dev/null 2>&1; then
echo -e "\e[32m 标准输出测试成功\e[0m"
else
echo -e "\e[31m 标准输出测试失败\e[0m"
is_passed=0
echo -e " \e[36m---------- 期望输出 ----------\e[0m"
cat "${EXPECTED_STDOUT_FILE}"
echo -e " \e[36m---------- 实际输出 ----------\e[0m"
cat "${output_actual_file}"
echo -e " \e[36m------------------------------\e[0m"
fi
else
# 纯标准输出比较
if [ $ACTUAL_RETURN_CODE -ne 0 ]; then
echo -e "\e[33m警告: 程序以非零状态 ${ACTUAL_RETURN_CODE} 退出 (纯输出比较模式)。\e[0m"
fi
if diff -q <(sed ':a;N;$!ba;s/\n*$//' "${output_actual_file}") <(sed ':a;N;$!ba;s/\n*$//' "${output_reference_file}") >/dev/null 2>&1; then
echo -e "\e[32m 成功: 输出与参考输出匹配\e[0m"
else
echo -e "\e[31m 失败: 输出不匹配\e[0m"
is_passed=0
echo -e " \e[36m---------- 期望输出 ----------\e[0m"
cat "${output_reference_file}"
echo -e " \e[36m---------- 实际输出 ----------\e[0m"
cat "${output_actual_file}"
echo -e " \e[36m------------------------------\e[0m"
fi
fi
else
# 没有 .out 文件,只报告返回码
echo " 无参考输出文件。程序返回码: ${ACTUAL_RETURN_CODE}"
fi
fi
fi
# 更新通过用例计数
# --- 本次修改点 ---
if [ "$is_passed" -eq 1 ]; then
((PASSED_CASES++))
else
# 将失败的用例名称添加到列表中
FAILED_CASES_LIST+="${relative_path_no_ext}.sy\n"
fi
echo "" # 添加空行以提高可读性
done <<< "$sy_files"
# --- 新增功能: 打印最终总结 ---
echo "========================================"
echo "测试完成"
echo "测试通过率: [${PASSED_CASES}/${TOTAL_CASES}]"
# --- 本次修改点: 打印未通过的测例列表 ---
if [ -n "$FAILED_CASES_LIST" ]; then
echo ""
echo -e "\e[31m未通过的测例:\e[0m"
# 使用 -e 来解释换行符 \n
echo -e "${FAILED_CASES_LIST}"
fi
echo "========================================"
if [ "$PASSED_CASES" -eq "$TOTAL_CASES" ]; then
exit 0
else
exit 1
fi

1
testdata/functional/00_main.out vendored Normal file
View File

@@ -0,0 +1 @@
3

3
testdata/functional/00_main.sy vendored Normal file
View File

@@ -0,0 +1,3 @@
int main(){
return 3;
}

View File

@@ -1,2 +1 @@
10 10
0

8
testdata/functional/01_var_defn2.sy vendored Normal file
View File

@@ -0,0 +1,8 @@
//test domain of global var define and local define
int a = 3;
int b = 5;
int main(){
int a = 5;
return a + b;
}

1
testdata/functional/02_var_defn3.out vendored Normal file
View File

@@ -0,0 +1 @@
5

8
testdata/functional/02_var_defn3.sy vendored Normal file
View File

@@ -0,0 +1,8 @@
//test local var define
int main(){
int a, b0, _c;
a = 1;
b0 = 2;
_c = 3;
return b0 + _c;
}

1
testdata/functional/03_arr_defn2.out vendored Normal file
View File

@@ -0,0 +1 @@
0

4
testdata/functional/03_arr_defn2.sy vendored Normal file
View File

@@ -0,0 +1,4 @@
int a[10][10];
int main(){
return 0;
}

1
testdata/functional/04_arr_defn3.out vendored Normal file
View File

@@ -0,0 +1 @@
14

9
testdata/functional/04_arr_defn3.sy vendored Normal file
View File

@@ -0,0 +1,9 @@
//test array define
int main(){
int a[4][2] = {};
int b[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int c[4][2] = {{1, 2}, {3, 4}, {5, 6}, {7, 8}};
int d[4][2] = {1, 2, {3}, {5}, 7 , 8};
int e[4][2] = {{d[2][1], c[2][1]}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1] + e[0][0] + e[0][1] + a[2][0];
}

1
testdata/functional/05_arr_defn4.out vendored Normal file
View File

@@ -0,0 +1 @@
21

9
testdata/functional/05_arr_defn4.sy vendored Normal file
View File

@@ -0,0 +1,9 @@
int main(){
const int a[4][2] = {{1, 2}, {3, 4}, {}, 7};
int b[4][2] = {};
int c[4][2] = {1, 2, 3, 4, 5, 6, 7, 8};
int d[3 + 1][2] = {1, 2, {3}, {5}, a[3][0], 8};
int e[4][2][1] = {{d[2][1], {c[2][1]}}, {3, 4}, {5, 6}, {7, 8}};
return e[3][1][0] + e[0][0][0] + e[0][1][0] + d[3][0];
}

View File

@@ -0,0 +1 @@
5

View File

@@ -0,0 +1,6 @@
//test const gloal var define
const int a = 10, b = 5;
int main(){
return b;
}

View File

@@ -0,0 +1 @@
5

View File

@@ -0,0 +1,5 @@
//test const local var define
int main(){
const int a = 10, b = 5;
return b;
}

View File

@@ -0,0 +1 @@
4

View File

@@ -0,0 +1,5 @@
const int a[5]={0,1,2,3,4};
int main(){
return a[4];
}

1
testdata/functional/09_func_defn.out vendored Normal file
View File

@@ -0,0 +1 @@
9

11
testdata/functional/09_func_defn.sy vendored Normal file
View File

@@ -0,0 +1,11 @@
int a;
int func(int p){
p = p - 1;
return p;
}
int main(){
int b;
a = 10;
b = func(a);
return b;
}

View File

@@ -0,0 +1 @@
4

View File

@@ -0,0 +1,8 @@
int defn(){
return 4;
}
int main(){
int a=defn();
return a;
}

1
testdata/functional/11_add2.out vendored Normal file
View File

@@ -0,0 +1 @@
9

7
testdata/functional/11_add2.sy vendored Normal file
View File

@@ -0,0 +1,7 @@
//test add
int main(){
int a, b;
a = 10;
b = -1;
return a + b;
}

1
testdata/functional/12_addc.out vendored Normal file
View File

@@ -0,0 +1 @@
15

5
testdata/functional/12_addc.sy vendored Normal file
View File

@@ -0,0 +1,5 @@
//test addc
const int a = 10;
int main(){
return a + 5;
}

1
testdata/functional/13_sub2.out vendored Normal file
View File

@@ -0,0 +1 @@
248

7
testdata/functional/13_sub2.sy vendored Normal file
View File

@@ -0,0 +1,7 @@
//test sub
const int a = 10;
int main(){
int b;
b = 2;
return b - a;
}

1
testdata/functional/14_subc.out vendored Normal file
View File

@@ -0,0 +1 @@
8

6
testdata/functional/14_subc.sy vendored Normal file
View File

@@ -0,0 +1,6 @@
//test subc
int main(){
int a;
a = 10;
return a - 2;
}

1
testdata/functional/15_mul.out vendored Normal file
View File

@@ -0,0 +1 @@
50

7
testdata/functional/15_mul.sy vendored Normal file
View File

@@ -0,0 +1,7 @@
//test mul
int main(){
int a, b;
a = 10;
b = 5;
return a * b;
}

1
testdata/functional/16_mulc.out vendored Normal file
View File

@@ -0,0 +1 @@
25

5
testdata/functional/16_mulc.sy vendored Normal file
View File

@@ -0,0 +1,5 @@
//test mulc
const int a = 5;
int main(){
return a * 5;
}

1
testdata/functional/17_div.out vendored Normal file
View File

@@ -0,0 +1 @@
2

7
testdata/functional/17_div.sy vendored Normal file
View File

@@ -0,0 +1,7 @@
//test div
int main(){
int a, b;
a = 10;
b = 5;
return a / b;
}

1
testdata/functional/18_divc.out vendored Normal file
View File

@@ -0,0 +1 @@
2

5
testdata/functional/18_divc.sy vendored Normal file
View File

@@ -0,0 +1,5 @@
//test divc
const int a = 10;
int main(){
return a / 5;
}

1
testdata/functional/19_mod.out vendored Normal file
View File

@@ -0,0 +1 @@
3

6
testdata/functional/19_mod.sy vendored Normal file
View File

@@ -0,0 +1,6 @@
//test mod
int main(){
int a;
a = 10;
return a / 3;
}

1
testdata/functional/20_rem.out vendored Normal file
View File

@@ -0,0 +1 @@
1

6
testdata/functional/20_rem.sy vendored Normal file
View File

@@ -0,0 +1,6 @@
//test rem
int main(){
int a;
a = 10;
return a % 3;
}

0
testdata/functional/21_if_test2.out vendored Executable file → Normal file
View File

0
testdata/functional/21_if_test2.sy vendored Executable file → Normal file
View File

1
testdata/functional/22_if_test3.out vendored Normal file
View File

@@ -0,0 +1 @@
25

18
testdata/functional/22_if_test3.sy vendored Normal file
View File

@@ -0,0 +1,18 @@
// test if-if-else
int ififElse() {
int a;
a = 5;
int b;
b = 10;
if(a == 5)
if (b == 10)
a = 25;
else
a = a + 15;
return (a);
}
int main(){
return (ififElse());
}

1
testdata/functional/23_if_test4.out vendored Normal file
View File

@@ -0,0 +1 @@
25

18
testdata/functional/23_if_test4.sy vendored Normal file
View File

@@ -0,0 +1,18 @@
// test if-{if-else}
int if_ifElse_() {
int a;
a = 5;
int b;
b = 10;
if(a == 5){
if (b == 10)
a = 25;
else
a = a + 15;
}
return (a);
}
int main(){
return (if_ifElse_());
}

1
testdata/functional/24_if_test5.out vendored Normal file
View File

@@ -0,0 +1 @@
25

18
testdata/functional/24_if_test5.sy vendored Normal file
View File

@@ -0,0 +1,18 @@
// test if-{if}-else
int if_if_Else() {
int a;
a = 5;
int b;
b = 10;
if(a == 5){
if (b == 10)
a = 25;
}
else
a = a + 15;
return (a);
}
int main(){
return (if_if_Else());
}

2
testdata/functional/25_while_if.out vendored Normal file
View File

@@ -0,0 +1,2 @@
88
0

31
testdata/functional/25_while_if.sy vendored Normal file
View File

@@ -0,0 +1,31 @@
int get_one(int a) {
return 1;
}
int deepWhileBr(int a, int b) {
int c;
c = a + b;
while (c < 75) {
int d;
d = 42;
if (c < 100) {
c = c + d;
if (c > 99) {
int e;
e = d * 2;
if (get_one(0) == 1) {
c = e * 2;
}
}
}
}
return (c);
}
int main() {
int p;
p = 2;
p = deepWhileBr(p, p);
putint(p);
return 0;
}

0
testdata/functional/26_while_test1.out vendored Executable file → Normal file
View File

0
testdata/functional/26_while_test1.sy vendored Executable file → Normal file
View File

View File

@@ -0,0 +1 @@
54

31
testdata/functional/27_while_test2.sy vendored Normal file
View File

@@ -0,0 +1,31 @@
int FourWhile() {
int a;
a = 5;
int b;
int c;
b = 6;
c = 7;
int d;
d = 10;
while (a < 20) {
a = a + 3;
while(b < 10){
b = b + 1;
while(c == 7){
c = c - 1;
while(d < 20){
d = d + 3;
}
d = d - 1;
}
c = c + 1;
}
b = b - 2;
}
return (a + (b + d) + c);
}
int main() {
return FourWhile();
}

View File

@@ -0,0 +1 @@
23

55
testdata/functional/28_while_test3.sy vendored Normal file
View File

@@ -0,0 +1,55 @@
int g;
int h;
int f;
int e;
int EightWhile() {
int a;
a = 5;
int b;
int c;
b = 6;
c = 7;
int d;
d = 10;
while (a < 20) {
a = a + 3;
while(b < 10){
b = b + 1;
while(c == 7){
c = c - 1;
while(d < 20){
d = d + 3;
while(e > 1){
e = e-1;
while(f > 2){
f = f -2;
while(g < 3){
g = g +10;
while(h < 10){
h = h + 8;
}
h = h-1;
}
g = g- 8;
}
f = f + 1;
}
e = e + 1;
}
d = d - 1;
}
c = c + 1;
}
b = b - 2;
}
return (a + (b + d) + c)-(e + d - g + h);
}
int main() {
g = 1;
h = 2;
e = 4;
f = 6;
return EightWhile();
}

1
testdata/functional/29_break.out vendored Normal file
View File

@@ -0,0 +1 @@
201

15
testdata/functional/29_break.sy vendored Normal file
View File

@@ -0,0 +1,15 @@
//test break
int main(){
int i;
i = 0;
int sum;
sum = 0;
while(i < 100){
if(i == 50){
break;
}
sum = sum + i;
i = i + 1;
}
return sum;
}

1
testdata/functional/30_continue.out vendored Normal file
View File

@@ -0,0 +1 @@
36

16
testdata/functional/30_continue.sy vendored Normal file
View File

@@ -0,0 +1,16 @@
//test continue
int main(){
int i;
i = 0;
int sum;
sum = 0;
while(i < 100){
if(i == 50){
i = i + 1;
continue;
}
sum = sum + i;
i = i + 1;
}
return sum;
}

View File

@@ -0,0 +1 @@
198

View File

@@ -0,0 +1,25 @@
// test while-if
int whileIf() {
int a;
a = 0;
int b;
b = 0;
while (a < 100) {
if (a == 5) {
b = 25;
}
else if (a == 10) {
b = 42;
}
else {
b = a * 2;
}
a = a + 1;
}
return (b);
}
int main(){
return (whileIf());
}

View File

@@ -0,0 +1 @@
96

View File

@@ -0,0 +1,23 @@
int ifWhile() {
int a;
a = 0;
int b;
b = 3;
if (a == 5) {
while(b == 2){
b = b + 2;
}
b = b + 25;
}
else
while (a < 5) {
b = b * 2;
a = a + 1;
}
return (b);
}
int main(){
return (ifWhile());
}

View File

@@ -0,0 +1 @@
88

View File

@@ -0,0 +1,25 @@
int deepWhileBr(int a, int b) {
int c;
c = a + b;
while (c < 75) {
int d;
d = 42;
if (c < 100) {
c = c + d;
if (c > 99) {
int e;
e = d * 2;
if (1 == 1) {
c = e * 2;
}
}
}
}
return (c);
}
int main() {
int p;
p = 2;
return deepWhileBr(p, p);
}

View File

@@ -0,0 +1 @@
51

11
testdata/functional/34_arr_expr_len.sy vendored Normal file
View File

@@ -0,0 +1,11 @@
//const int N = -1;
int arr[-1 + 2 * 4 - 99 / 99] = {1, 2, 33, 4, 5, 6};
int main() {
int i = 0, sum = 0;
while (i < 6) {
sum = sum + arr[i];
i = i + 1;
}
return sum;
}

0
testdata/functional/35_op_priority1.out vendored Executable file → Normal file
View File

0
testdata/functional/35_op_priority1.sy vendored Executable file → Normal file
View File

View File

@@ -0,0 +1 @@
24

View File

@@ -0,0 +1,9 @@
//test the priority of add and mul
int main(){
int a, b, c, d;
a = 10;
b = 4;
c = 2;
d = 2;
return (c + a) * (b - d);
}

Some files were not shown because too many files have changed in this diff Show More