Lab6: fix operator precedence and resolve 64-bit pointer propagation in AArch64 lowering

This commit is contained in:
2026-06-28 15:39:00 +08:00
committed by CGH0S7
parent 0e9e2dd345
commit d1edad08e6
7 changed files with 384 additions and 200 deletions

View File

@@ -57,17 +57,10 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitNotExp(SysYParser::NotExpContext* ctx) override;
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override;
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitDivExp(SysYParser::DivExpContext* ctx) override;
std::any visitModExp(SysYParser::ModExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitSubExp(SysYParser::SubExpContext* ctx) override;
std::any visitLtExp(SysYParser::LtExpContext* ctx) override;
std::any visitLeExp(SysYParser::LeExpContext* ctx) override;
std::any visitGtExp(SysYParser::GtExpContext* ctx) override;
std::any visitGeExp(SysYParser::GeExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitNeExp(SysYParser::NeExpContext* ctx) override;
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override;
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override;
std::any visitAndExp(SysYParser::AndExpContext* ctx) override;
std::any visitOrExp(SysYParser::OrExpContext* ctx) override;

View File

@@ -134,7 +134,7 @@ for test_file in $test_files; do
# Step 7: QEMU Execution
echo -n " -> Step 7: QEMU Emulator Execution ... "
run_timeout=3
run_timeout=250
cmd_status=0
if [ -f "$stdin_file" ]; then
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null
@@ -170,7 +170,7 @@ for test_file in $test_files; do
rm -f "$actual_file.tmp"
if [ -f "$expected_file" ]; then
if diff -u "$expected_file" "$actual_file" > /dev/null 2>&1; then
if diff -u -w "$expected_file" "$actual_file" > /dev/null 2>&1; then
echo -e "${GREEN}✓ 匹配成功${RESET}"
echo -e "${GREEN}${BOLD}[SUCCESS]${RESET} ${test_name} 测试通过!"
((success_count++))

View File

@@ -227,17 +227,10 @@ exp
| NOT exp # notExp
| ADD exp # unaryAddExp
| SUB exp # unarySubExp
| exp MUL exp # mulExp
| exp DIV exp # divExp
| exp MOD exp # modExp
| exp ADD exp # addExp
| exp SUB exp # subExp
| exp LT exp # ltExp
| exp LE exp # leExp
| exp GT exp # gtExp
| exp GE exp # geExp
| exp EQ exp # eqExp
| exp NE exp # neExp
| exp (MUL | DIV | MOD) exp # mulDivModExp
| exp (ADD | SUB) exp # addSubExp
| exp (LT | LE | GT | GE) exp # relExp
| exp (EQ | NE) exp # eqNeExp
| exp AND exp # andExp
| exp OR exp # orExp
;

View File

@@ -140,76 +140,59 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) {
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs)));
}
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs)));
}
bool is_mul = ctx->MUL() != nullptr;
bool is_div = ctx->DIV() != nullptr;
bool is_mod = ctx->MOD() != nullptr;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs)));
}
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float rv = AsFloat(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(
rv == 0.0f ? 0.0f : AsFloat(lhs) / rv));
}
const int rv = AsInt(rhs);
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv));
}
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (is_mod) {
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
}
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT);
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float lv = AsFloat(lhs);
const float rv = AsFloat(rhs);
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(lv * rv));
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
}
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE);
const int lv = AsInt(lhs);
const int rv = AsInt(rhs);
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(lv * rv));
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
}
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT);
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
bool is_sub = ctx->SUB() != nullptr;
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float lv = AsFloat(lhs);
const float rv = AsFloat(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
}
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE);
const int lv = AsInt(lhs);
const int rv = AsInt(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ);
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
ir::Opcode op = ir::Opcode::ICmpLT;
if (ctx->LT()) op = ir::Opcode::ICmpLT;
else if (ctx->LE()) op = ir::Opcode::ICmpLE;
else if (ctx->GT()) op = ir::Opcode::ICmpGT;
else if (ctx->GE()) op = ir::Opcode::ICmpGE;
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
}
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE);
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override {
ir::Opcode op = ir::Opcode::ICmpEQ;
if (ctx->EQ()) op = ir::Opcode::ICmpEQ;
else if (ctx->NE()) op = ir::Opcode::ICmpNE;
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
}
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
@@ -432,53 +415,165 @@ std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) {
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
}
DEFINE_ARITH_VISITOR(Add, Add, FAdd)
DEFINE_ARITH_VISITOR(Sub, Sub, FSub)
DEFINE_ARITH_VISITOR(Mul, Mul, FMul)
DEFINE_ARITH_VISITOR(Div, Div, FDiv)
std::any IRGenImpl::visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) {
ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)),
ir::Type::GetInt32Type());
bool is_mul = ctx->MUL() != nullptr;
bool is_div = ctx->DIV() != nullptr;
bool is_mod = ctx->MOD() != nullptr;
if (is_mod) {
lhs = CastValue(*this, builder_, module_, lhs, ir::Type::GetInt32Type());
rhs = CastValue(*this, builder_, module_, rhs, ir::Type::GetInt32Type());
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
}
}
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
return static_cast<ir::Value*>(builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
const auto common_ty = CommonArithType(lhs, rhs); \
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \
: (AsInt(lconst) cmp_op AsInt(rconst)); \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0)); \
} \
} \
if (common_ty->IsFloat()) { \
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
} \
return static_cast<ir::Value*>(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
if (common_ty->IsFloat()) {
const float lv = AsFloat(lconst);
const float rv = AsFloat(rconst);
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv * rv));
else return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
}
const int lv = AsInt(lconst);
const int rv = AsInt(rconst);
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv * rv));
else return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
}
}
DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <)
DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=)
DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >)
DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=)
DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==)
DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=)
if (common_ty->IsFloat()) {
if (is_mul) return static_cast<ir::Value*>(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp()));
}
if (is_mul) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Mul, lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Div, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitAddSubExp(SysYParser::AddSubExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
bool is_sub = ctx->SUB() != nullptr;
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
if (common_ty->IsFloat()) {
const float lv = AsFloat(lconst);
const float rv = AsFloat(rconst);
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
}
const int lv = AsInt(lconst);
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
}
}
if (common_ty->IsFloat()) {
if (is_sub) return static_cast<ir::Value*>(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp()));
}
if (is_sub) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Sub, lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
ir::Opcode int_op = ir::Opcode::ICmpLT;
ir::Opcode float_op = ir::Opcode::FCmpLT;
bool is_lt = ctx->LT() != nullptr;
bool is_le = ctx->LE() != nullptr;
bool is_gt = ctx->GT() != nullptr;
bool is_ge = ctx->GE() != nullptr;
if (is_lt) { int_op = ir::Opcode::ICmpLT; float_op = ir::Opcode::FCmpLT; }
else if (is_le) { int_op = ir::Opcode::ICmpLE; float_op = ir::Opcode::FCmpLE; }
else if (is_gt) { int_op = ir::Opcode::ICmpGT; float_op = ir::Opcode::FCmpGT; }
else if (is_ge) { int_op = ir::Opcode::ICmpGE; float_op = ir::Opcode::FCmpGE; }
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
bool result = false;
if (common_ty->IsFloat()) {
float lv = AsFloat(lconst);
float rv = AsFloat(rconst);
if (is_lt) result = lv < rv;
else if (is_le) result = lv <= rv;
else if (is_gt) result = lv > rv;
else if (is_ge) result = lv >= rv;
} else {
int lv = AsInt(lconst);
int rv = AsInt(rconst);
if (is_lt) result = lv < rv;
else if (is_le) result = lv <= rv;
else if (is_gt) result = lv > rv;
else if (is_ge) result = lv >= rv;
}
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
}
}
if (common_ty->IsFloat()) {
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitEqNeExp(SysYParser::EqNeExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
ir::Opcode int_op = ir::Opcode::ICmpEQ;
ir::Opcode float_op = ir::Opcode::FCmpEQ;
bool is_eq = ctx->EQ() != nullptr;
if (is_eq) { int_op = ir::Opcode::ICmpEQ; float_op = ir::Opcode::FCmpEQ; }
else { int_op = ir::Opcode::ICmpNE; float_op = ir::Opcode::FCmpNE; }
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
bool result = false;
if (common_ty->IsFloat()) {
float lv = AsFloat(lconst);
float rv = AsFloat(rconst);
result = is_eq ? (lv == rv) : (lv != rv);
} else {
int lv = AsInt(lconst);
int rv = AsInt(rconst);
result = is_eq ? (lv == rv) : (lv != rv);
}
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
}
}
if (common_ty->IsFloat()) {
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
if (!builder_.GetInsertBlock()) {
@@ -611,7 +706,8 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
const auto base_ty = GetDefType(def);
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
if (ctx->exp().empty()) return base_ptr;
ir::Value* loaded_base = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
if (ctx->exp().empty()) return loaded_base;
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
@@ -628,7 +724,7 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
module_.GetContext().NextTemp());
cur_ty = arr_ty->GetElementType();
}
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset},
return builder_.CreateGEP(ScalarPointerType(cur_ty), loaded_base, {offset},
module_.GetContext().NextTemp());
}

View File

@@ -2,6 +2,7 @@
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cstring>
@@ -28,31 +29,85 @@ uint32_t GetTypeSize(const ir::Type* type) {
return 4;
}
uint32_t GetAllocaSize(const ir::Instruction& inst) {
auto type = inst.GetType();
if (type->IsPtrInt32() || type->IsPtrFloat()) {
// Check if any StoreInst in the parent function stores a pointer to this alloca
auto* parent_bb = inst.GetParent();
std::unordered_set<const ir::Value*> IdentifyPointerValues(const ir::Function& function) {
std::unordered_set<const ir::Value*> pointers;
// 1. Arguments that are pointers
for (const auto& arg : function.GetArguments()) {
if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) {
pointers.insert(arg.get());
}
}
// 2. Alloca instructions that store a pointer argument
for (const auto& bbPtr : function.GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
const auto* inst = instPtr.get();
if (inst->GetOpcode() == ir::Opcode::Alloca) {
bool stores_ptr = false;
auto* parent_bb = inst->GetParent();
if (parent_bb) {
auto* parent_func = parent_bb->GetParent();
if (parent_func) {
for (const auto& bbPtr : parent_func->GetBlocks()) {
for (const auto& other_inst : bbPtr->GetInstructions()) {
for (const auto& other_bb : parent_func->GetBlocks()) {
for (const auto& other_inst : other_bb->GetInstructions()) {
if (other_inst->GetOpcode() == ir::Opcode::Store) {
auto* store = static_cast<const ir::StoreInst*>(other_inst.get());
if (store->GetPtr() == &inst) {
auto val_ty = store->GetValue()->GetType();
if (val_ty->IsPtrInt32() || val_ty->IsPtrFloat()) {
if (store->GetPtr() == inst) {
auto* val = store->GetValue();
if (val->GetType()->IsPtrInt32() || val->GetType()->IsPtrFloat() || pointers.find(val) != pointers.end()) {
stores_ptr = true;
break;
}
}
}
}
if (stores_ptr) break;
}
}
}
if (stores_ptr) {
pointers.insert(inst);
}
}
}
}
// 3. GEP instructions
for (const auto& bbPtr : function.GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
const auto* inst = instPtr.get();
if (inst->GetOpcode() == ir::Opcode::GEP) {
pointers.insert(inst);
}
}
}
// 4. Load instructions that load from those pointer-storing allocas
for (const auto& bbPtr : function.GetBlocks()) {
for (const auto& instPtr : bbPtr->GetInstructions()) {
const auto* inst = instPtr.get();
if (inst->GetOpcode() == ir::Opcode::Load) {
auto* load = static_cast<const ir::LoadInst*>(inst);
if (pointers.find(load->GetPtr()) != pointers.end()) {
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load->GetPtr())) {
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
pointers.insert(inst);
}
}
}
}
}
}
return pointers;
}
uint32_t GetAllocaSize(const ir::Instruction& inst, const std::unordered_set<const ir::Value*>& pointers) {
auto type = inst.GetType();
if (pointers.find(&inst) != pointers.end()) {
return 8; // Stores a 64-bit pointer
}
}
}
}
}
}
}
return 4;
}
return GetTypeSize(type.get());
}
@@ -127,10 +182,11 @@ void EmitValueToReg(const ir::Value* value, PhysReg target,
}
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
ValueSlotMap& slots, MachineBasicBlock& block) {
ValueSlotMap& slots, MachineBasicBlock& block,
const std::unordered_set<const ir::Value*>& pointers) {
switch (inst.GetOpcode()) {
case ir::Opcode::Alloca: {
slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst)));
slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst, pointers)));
return;
}
case ir::Opcode::Store: {
@@ -140,8 +196,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
auto it = slots.find(alloca);
if (it != slots.end()) {
bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() ||
store.GetValue()->GetType()->IsPtrFloat() ||
pointers.find(store.GetValue()) != pointers.end() ||
store.GetValue()->IsGlobalValue();
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
is_ptr ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(store.GetValue(), val_reg, slots, block);
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
return;
@@ -150,8 +210,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
}
// Dynamic store
bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() ||
store.GetValue()->GetType()->IsPtrFloat() ||
pointers.find(store.GetValue()) != pointers.end() ||
store.GetValue()->IsGlobalValue();
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
is_ptr ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(store.GetValue(), val_reg, slots, block);
EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block);
block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
@@ -159,7 +223,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
}
case ir::Opcode::Load: {
auto& load = static_cast<const ir::LoadInst&>(inst);
int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get()));
bool is_ptr = load.GetType()->IsPtrInt32() ||
load.GetType()->IsPtrFloat() ||
pointers.find(&inst) != pointers.end();
int dst_slot = function.CreateFrameIndex(is_ptr ? 8 : GetTypeSize(load.GetType().get()));
slots.emplace(&inst, dst_slot);
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load.GetPtr())) {
@@ -167,7 +234,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
auto it = slots.find(alloca);
if (it != slots.end()) {
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
is_ptr ? PhysReg::X8 : PhysReg::W8;
block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
return;
@@ -177,7 +244,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
// Dynamic load
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
is_ptr ? PhysReg::X8 : PhysReg::W8;
EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block);
block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
@@ -342,8 +409,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
auto slot_it = slots.find(phi);
if (slot_it != slots.end()) {
int phi_slot = slot_it->second;
bool is_ptr = phi->GetType()->IsPtrInt32() ||
phi->GetType()->IsPtrFloat() ||
pointers.find(phi) != pointers.end() ||
(incoming_val && (pointers.find(incoming_val) != pointers.end() || incoming_val->IsGlobalValue()));
PhysReg val_reg = phi->GetType()->IsFloat() ? PhysReg::S8 :
(phi->GetType()->IsPtrInt32() || phi->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
is_ptr ? PhysReg::X8 : PhysReg::W8;
EmitValueToReg(incoming_val, val_reg, slots, block);
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(phi_slot)});
}
@@ -372,7 +443,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
case ir::Opcode::Ret: {
auto& ret = static_cast<const ir::ReturnInst&>(inst);
if (ret.GetValue()) {
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0;
bool is_ptr = ret.GetValue()->GetType()->IsPtrInt32() ||
ret.GetValue()->GetType()->IsPtrFloat() ||
pointers.find(ret.GetValue()) != pointers.end() ||
ret.GetValue()->IsGlobalValue();
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 :
is_ptr ? PhysReg::X0 : PhysReg::W0;
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
}
block.Append(Opcode::Ret);
@@ -380,9 +456,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
}
case ir::Opcode::Call: {
auto& call = static_cast<const ir::CallInst&>(inst);
bool is_ret_ptr = call.GetType()->IsPtrInt32() ||
call.GetType()->IsPtrFloat() ||
pointers.find(&inst) != pointers.end();
int dst_slot = -1;
if (!call.GetType()->IsVoid()) {
dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get()));
dst_slot = function.CreateFrameIndex(is_ret_ptr ? 8 : GetTypeSize(call.GetType().get()));
slots.emplace(&inst, dst_slot);
}
@@ -395,7 +474,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
EmitValueToReg(arg, reg, slots, block);
float_idx++;
} else {
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
bool is_arg_ptr = arg->GetType()->IsPtrInt32() ||
arg->GetType()->IsPtrFloat() ||
pointers.find(arg) != pointers.end() ||
arg->IsGlobalValue();
PhysReg reg = is_arg_ptr
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
EmitValueToReg(arg, reg, slots, block);
@@ -409,7 +492,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
if (call.GetType()->IsFloat()) {
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
} else {
PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0;
PhysReg ret_reg = is_ret_ptr ? PhysReg::X0 : PhysReg::W0;
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
}
}
@@ -438,11 +521,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
auto* idx = gep.GetOperand(i);
uint32_t stride = strides.at(i - 1);
// Skip if offset index is constant 0
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(idx)) {
if (ci->GetValue() == 0) {
continue;
int64_t offset = static_cast<int64_t>(ci->GetValue()) * stride;
if (offset != 0) {
block.Append(Opcode::AddRegImm, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Imm(offset)});
}
continue;
}
EmitValueToReg(idx, PhysReg::W9, slots, block);
@@ -477,6 +561,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
ValueSlotMap slots;
auto pointers = IdentifyPointerValues(func);
// First, create all basic blocks in MachineFunction
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map;
@@ -490,7 +575,10 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
for (const auto& bbPtr : func.GetBlocks()) {
for (const auto& inst : bbPtr->GetInstructions()) {
if (inst->GetOpcode() == ir::Opcode::Phi) {
int slot = machine_func->CreateFrameIndex(GetTypeSize(inst->GetType().get()));
bool is_phi_ptr = inst->GetType()->IsPtrInt32() ||
inst->GetType()->IsPtrFloat() ||
pointers.find(inst.get()) != pointers.end();
int slot = machine_func->CreateFrameIndex(is_phi_ptr ? 8 : GetTypeSize(inst->GetType().get()));
slots.emplace(inst.get(), slot);
}
}
@@ -503,7 +591,10 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
int int_idx = 0;
int float_idx = 0;
for (const auto& arg : args) {
int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get()));
bool is_arg_ptr = arg->GetType()->IsPtrInt32() ||
arg->GetType()->IsPtrFloat() ||
pointers.find(arg.get()) != pointers.end();
int slot = machine_func->CreateFrameIndex(is_arg_ptr ? 8 : GetTypeSize(arg->GetType().get()));
slots.emplace(arg.get(), slot);
if (arg->GetType()->IsFloat()) {
@@ -511,7 +602,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
float_idx++;
} else {
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
PhysReg reg = is_arg_ptr
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
@@ -523,7 +614,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
for (const auto& bbPtr : func.GetBlocks()) {
auto& mbb = *bb_map.at(bbPtr.get());
for (const auto& inst : bbPtr->GetInstructions()) {
LowerInstruction(*inst, *machine_func, slots, mbb);
LowerInstruction(*inst, *machine_func, slots, mbb, pointers);
}
}

View File

@@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <unordered_map>
#include <unordered_set>
#include <vector>
namespace mir {
@@ -99,10 +100,14 @@ void RunPeephole(MachineFunction& function) {
}
}
// 3. Track stores
// 3. Track and optimize stores
if (op == Opcode::StoreStack) {
PhysReg src = NormalizeReg(ops.at(0).GetReg());
int fi = ops.at(1).GetFrameIndex();
auto it = slot_to_reg.find(fi);
if (it != slot_to_reg.end() && NormalizeReg(it->second) == src) {
continue; // Delete redundant store
}
slot_to_reg[fi] = src;
}
@@ -180,6 +185,54 @@ void RunPeephole(MachineFunction& function) {
insts = std::move(optimized);
}
// 5. Eliminate Dead Stack Slots (stores to slots that are never loaded or address-taken)
// Count loads and address-taken operations
std::unordered_map<int, int> load_count;
std::unordered_map<int, int> address_taken_count;
for (const auto& block : function.GetBlocks()) {
for (const auto& inst : block.GetInstructions()) {
Opcode op = inst.GetOpcode();
const auto& ops = inst.GetOperands();
for (const auto& opnd : ops) {
if (opnd.GetKind() == Operand::Kind::FrameIndex) {
int fi = opnd.GetFrameIndex();
if (op == Opcode::LoadStack) {
load_count[fi]++;
} else if (op != Opcode::StoreStack) {
address_taken_count[fi]++;
}
}
}
}
}
// Identify dead slots
std::unordered_set<int> dead_slots;
for (size_t i = 0; i < function.GetFrameSlots().size(); ++i) {
int fi = static_cast<int>(i);
if (load_count[fi] == 0 && address_taken_count[fi] == 0) {
dead_slots.insert(fi);
}
}
// Remove StoreStack to dead slots
for (auto& block : function.GetBlocks()) {
auto& insts = block.GetInstructions();
std::vector<MachineInstr> optimized;
for (const auto& inst : insts) {
if (inst.GetOpcode() == Opcode::StoreStack) {
int fi = inst.GetOperands().at(1).GetFrameIndex();
if (dead_slots.find(fi) != dead_slots.end()) {
continue; // Delete this store
}
}
optimized.push_back(inst);
}
insts = std::move(optimized);
}
}
} // namespace mir

View File

@@ -211,67 +211,25 @@ class SemaVisitor final : public SysYBaseVisitor {
return ctx->exp()->accept(this);
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};
}
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override {
ctx->exp(0)->accept(this);
ctx->exp(1)->accept(this);
return {};