diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index 6dab7ea..6fd49ad 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -57,17 +57,10 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitNotExp(SysYParser::NotExpContext* ctx) override; std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override; std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override; - std::any visitMulExp(SysYParser::MulExpContext* ctx) override; - std::any visitDivExp(SysYParser::DivExpContext* ctx) override; - std::any visitModExp(SysYParser::ModExpContext* ctx) override; - std::any visitAddExp(SysYParser::AddExpContext* ctx) override; - std::any visitSubExp(SysYParser::SubExpContext* ctx) override; - std::any visitLtExp(SysYParser::LtExpContext* ctx) override; - std::any visitLeExp(SysYParser::LeExpContext* ctx) override; - std::any visitGtExp(SysYParser::GtExpContext* ctx) override; - std::any visitGeExp(SysYParser::GeExpContext* ctx) override; - std::any visitEqExp(SysYParser::EqExpContext* ctx) override; - std::any visitNeExp(SysYParser::NeExpContext* ctx) override; + std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override; + std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override; std::any visitAndExp(SysYParser::AndExpContext* ctx) override; std::any visitOrExp(SysYParser::OrExpContext* ctx) override; diff --git a/scripts/run_all_tests_verbose.sh b/scripts/run_all_tests_verbose.sh index 5f231cd..cea0a99 100755 --- a/scripts/run_all_tests_verbose.sh +++ b/scripts/run_all_tests_verbose.sh @@ -134,7 +134,7 @@ for test_file in $test_files; do # Step 7: QEMU Execution echo -n " -> Step 7: QEMU Emulator Execution ... " - run_timeout=3 + run_timeout=250 cmd_status=0 if [ -f "$stdin_file" ]; then timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null @@ -170,7 +170,7 @@ for test_file in $test_files; do rm -f "$actual_file.tmp" if [ -f "$expected_file" ]; then - if diff -u "$expected_file" "$actual_file" > /dev/null 2>&1; then + if diff -u -w "$expected_file" "$actual_file" > /dev/null 2>&1; then echo -e "${GREEN}✓ 匹配成功${RESET}" echo -e "${GREEN}${BOLD}[SUCCESS]${RESET} ${test_name} 测试通过!" ((success_count++)) diff --git a/src/antlr4/SysY.g4 b/src/antlr4/SysY.g4 index b9713ea..d563a87 100644 --- a/src/antlr4/SysY.g4 +++ b/src/antlr4/SysY.g4 @@ -227,17 +227,10 @@ exp | NOT exp # notExp | ADD exp # unaryAddExp | SUB exp # unarySubExp - | exp MUL exp # mulExp - | exp DIV exp # divExp - | exp MOD exp # modExp - | exp ADD exp # addExp - | exp SUB exp # subExp - | exp LT exp # ltExp - | exp LE exp # leExp - | exp GT exp # gtExp - | exp GE exp # geExp - | exp EQ exp # eqExp - | exp NE exp # neExp + | exp (MUL | DIV | MOD) exp # mulDivModExp + | exp (ADD | SUB) exp # addSubExp + | exp (LT | LE | GT | GE) exp # relExp + | exp (EQ | NE) exp # eqNeExp | exp AND exp # andExp | exp OR exp # orExp ; diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 076f0fc..818d85d 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -140,76 +140,59 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) { module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1)); } - std::any visitAddExp(SysYParser::AddExpContext* ctx) override { + std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override { auto* lhs = Eval(*ctx->exp(0)); auto* rhs = Eval(*ctx->exp(1)); - if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { - return static_cast( - module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs))); - } - return static_cast( - module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs))); - } + + bool is_mul = ctx->MUL() != nullptr; + bool is_div = ctx->DIV() != nullptr; + bool is_mod = ctx->MOD() != nullptr; - std::any visitSubExp(SysYParser::SubExpContext* ctx) override { - auto* lhs = Eval(*ctx->exp(0)); - auto* rhs = Eval(*ctx->exp(1)); - if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { - return static_cast( - module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs))); + if (is_mod) { + return static_cast(module_.GetContext().GetConstInt( + AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs))); } - return static_cast( - module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs))); - } - std::any visitMulExp(SysYParser::MulExpContext* ctx) override { - auto* lhs = Eval(*ctx->exp(0)); - auto* rhs = Eval(*ctx->exp(1)); - if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { - return static_cast( - module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs))); - } - return static_cast( - module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs))); - } - - std::any visitDivExp(SysYParser::DivExpContext* ctx) override { - auto* lhs = Eval(*ctx->exp(0)); - auto* rhs = Eval(*ctx->exp(1)); if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + const float lv = AsFloat(lhs); const float rv = AsFloat(rhs); - return static_cast(module_.GetContext().GetConstFloat( - rv == 0.0f ? 0.0f : AsFloat(lhs) / rv)); + if (is_mul) return static_cast(module_.GetContext().GetConstFloat(lv * rv)); + else return static_cast(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv)); } + const int lv = AsInt(lhs); const int rv = AsInt(rhs); - return static_cast( - module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv)); + if (is_mul) return static_cast(module_.GetContext().GetConstInt(lv * rv)); + else return static_cast(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv)); } - std::any visitModExp(SysYParser::ModExpContext* ctx) override { + std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override { auto* lhs = Eval(*ctx->exp(0)); auto* rhs = Eval(*ctx->exp(1)); - return static_cast(module_.GetContext().GetConstInt( - AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs))); + bool is_sub = ctx->SUB() != nullptr; + if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) { + const float lv = AsFloat(lhs); + const float rv = AsFloat(rhs); + return static_cast(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv)); + } + const int lv = AsInt(lhs); + const int rv = AsInt(rhs); + return static_cast(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv)); } - std::any visitLtExp(SysYParser::LtExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT); + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { + ir::Opcode op = ir::Opcode::ICmpLT; + if (ctx->LT()) op = ir::Opcode::ICmpLT; + else if (ctx->LE()) op = ir::Opcode::ICmpLE; + else if (ctx->GT()) op = ir::Opcode::ICmpGT; + else if (ctx->GE()) op = ir::Opcode::ICmpGE; + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op); } - std::any visitLeExp(SysYParser::LeExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE); - } - std::any visitGtExp(SysYParser::GtExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT); - } - std::any visitGeExp(SysYParser::GeExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE); - } - std::any visitEqExp(SysYParser::EqExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ); - } - std::any visitNeExp(SysYParser::NeExpContext* ctx) override { - return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE); + + std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override { + ir::Opcode op = ir::Opcode::ICmpEQ; + if (ctx->EQ()) op = ir::Opcode::ICmpEQ; + else if (ctx->NE()) op = ir::Opcode::ICmpNE; + return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op); } std::any visitAndExp(SysYParser::AndExpContext* ctx) override { @@ -432,53 +415,165 @@ std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) { return static_cast(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ } -DEFINE_ARITH_VISITOR(Add, Add, FAdd) -DEFINE_ARITH_VISITOR(Sub, Sub, FSub) -DEFINE_ARITH_VISITOR(Mul, Mul, FMul) -DEFINE_ARITH_VISITOR(Div, Div, FDiv) +std::any IRGenImpl::visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) { + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + + bool is_mul = ctx->MUL() != nullptr; + bool is_div = ctx->DIV() != nullptr; + bool is_mod = ctx->MOD() != nullptr; + + if (is_mod) { + lhs = CastValue(*this, builder_, module_, lhs, ir::Type::GetInt32Type()); + rhs = CastValue(*this, builder_, module_, rhs, ir::Type::GetInt32Type()); + if (auto* lconst = dynamic_cast(lhs)) { + if (auto* rconst = dynamic_cast(rhs)) { + const int rv = AsInt(rconst); + return static_cast(module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv)); + } + } + return static_cast(builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); + } + + const auto common_ty = CommonArithType(lhs, rhs); + lhs = CastValue(*this, builder_, module_, lhs, common_ty); + rhs = CastValue(*this, builder_, module_, rhs, common_ty); -std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) { - ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)), - ir::Type::GetInt32Type()); - ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)), - ir::Type::GetInt32Type()); if (auto* lconst = dynamic_cast(lhs)) { if (auto* rconst = dynamic_cast(rhs)) { + if (common_ty->IsFloat()) { + const float lv = AsFloat(lconst); + const float rv = AsFloat(rconst); + if (is_mul) return static_cast(module_.GetContext().GetConstFloat(lv * rv)); + else return static_cast(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv)); + } + const int lv = AsInt(lconst); const int rv = AsInt(rconst); - return static_cast( - module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv)); + if (is_mul) return static_cast(module_.GetContext().GetConstInt(lv * rv)); + else return static_cast(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv)); } } - return static_cast( - builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp())); + + if (common_ty->IsFloat()) { + if (is_mul) return static_cast(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp())); + else return static_cast(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp())); + } + if (is_mul) return static_cast(builder_.CreateBinary(ir::Opcode::Mul, lhs, rhs, module_.GetContext().NextTemp())); + else return static_cast(builder_.CreateBinary(ir::Opcode::Div, lhs, rhs, module_.GetContext().NextTemp())); } -#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \ - std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \ - ir::Value* lhs = EvalExpr(*ctx->exp(0)); \ - ir::Value* rhs = EvalExpr(*ctx->exp(1)); \ - const auto common_ty = CommonArithType(lhs, rhs); \ - lhs = CastValue(*this, builder_, module_, lhs, common_ty); \ - rhs = CastValue(*this, builder_, module_, rhs, common_ty); \ - if (auto* lconst = dynamic_cast(lhs)) { \ - if (auto* rconst = dynamic_cast(rhs)) { \ - const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \ - : (AsInt(lconst) cmp_op AsInt(rconst)); \ - return static_cast(module_.GetContext().GetConstInt(result ? 1 : 0)); \ - } \ - } \ - if (common_ty->IsFloat()) { \ - return static_cast(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ - } \ - return static_cast(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \ +std::any IRGenImpl::visitAddSubExp(SysYParser::AddSubExpContext* ctx) { + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + const auto common_ty = CommonArithType(lhs, rhs); + lhs = CastValue(*this, builder_, module_, lhs, common_ty); + rhs = CastValue(*this, builder_, module_, rhs, common_ty); + + bool is_sub = ctx->SUB() != nullptr; + + if (auto* lconst = dynamic_cast(lhs)) { + if (auto* rconst = dynamic_cast(rhs)) { + if (common_ty->IsFloat()) { + const float lv = AsFloat(lconst); + const float rv = AsFloat(rconst); + return static_cast(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv)); + } + const int lv = AsInt(lconst); + const int rv = AsInt(rconst); + return static_cast(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv)); + } } -DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <) -DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=) -DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >) -DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=) -DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==) -DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=) + if (common_ty->IsFloat()) { + if (is_sub) return static_cast(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp())); + else return static_cast(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp())); + } + if (is_sub) return static_cast(builder_.CreateBinary(ir::Opcode::Sub, lhs, rhs, module_.GetContext().NextTemp())); + else return static_cast(builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + const auto common_ty = CommonArithType(lhs, rhs); + lhs = CastValue(*this, builder_, module_, lhs, common_ty); + rhs = CastValue(*this, builder_, module_, rhs, common_ty); + + ir::Opcode int_op = ir::Opcode::ICmpLT; + ir::Opcode float_op = ir::Opcode::FCmpLT; + bool is_lt = ctx->LT() != nullptr; + bool is_le = ctx->LE() != nullptr; + bool is_gt = ctx->GT() != nullptr; + bool is_ge = ctx->GE() != nullptr; + + if (is_lt) { int_op = ir::Opcode::ICmpLT; float_op = ir::Opcode::FCmpLT; } + else if (is_le) { int_op = ir::Opcode::ICmpLE; float_op = ir::Opcode::FCmpLE; } + else if (is_gt) { int_op = ir::Opcode::ICmpGT; float_op = ir::Opcode::FCmpGT; } + else if (is_ge) { int_op = ir::Opcode::ICmpGE; float_op = ir::Opcode::FCmpGE; } + + if (auto* lconst = dynamic_cast(lhs)) { + if (auto* rconst = dynamic_cast(rhs)) { + bool result = false; + if (common_ty->IsFloat()) { + float lv = AsFloat(lconst); + float rv = AsFloat(rconst); + if (is_lt) result = lv < rv; + else if (is_le) result = lv <= rv; + else if (is_gt) result = lv > rv; + else if (is_ge) result = lv >= rv; + } else { + int lv = AsInt(lconst); + int rv = AsInt(rconst); + if (is_lt) result = lv < rv; + else if (is_le) result = lv <= rv; + else if (is_gt) result = lv > rv; + else if (is_ge) result = lv >= rv; + } + return static_cast(module_.GetContext().GetConstInt(result ? 1 : 0)); + } + } + + if (common_ty->IsFloat()) { + return static_cast(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitEqNeExp(SysYParser::EqNeExpContext* ctx) { + ir::Value* lhs = EvalExpr(*ctx->exp(0)); + ir::Value* rhs = EvalExpr(*ctx->exp(1)); + const auto common_ty = CommonArithType(lhs, rhs); + lhs = CastValue(*this, builder_, module_, lhs, common_ty); + rhs = CastValue(*this, builder_, module_, rhs, common_ty); + + ir::Opcode int_op = ir::Opcode::ICmpEQ; + ir::Opcode float_op = ir::Opcode::FCmpEQ; + bool is_eq = ctx->EQ() != nullptr; + + if (is_eq) { int_op = ir::Opcode::ICmpEQ; float_op = ir::Opcode::FCmpEQ; } + else { int_op = ir::Opcode::ICmpNE; float_op = ir::Opcode::FCmpNE; } + + if (auto* lconst = dynamic_cast(lhs)) { + if (auto* rconst = dynamic_cast(rhs)) { + bool result = false; + if (common_ty->IsFloat()) { + float lv = AsFloat(lconst); + float rv = AsFloat(rconst); + result = is_eq ? (lv == rv) : (lv != rv); + } else { + int lv = AsInt(lconst); + int rv = AsInt(rconst); + result = is_eq ? (lv == rv) : (lv != rv); + } + return static_cast(module_.GetContext().GetConstInt(result ? 1 : 0)); + } + } + + if (common_ty->IsFloat()) { + return static_cast(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp())); +} std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) { if (!builder_.GetInsertBlock()) { @@ -611,7 +706,8 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) { const auto base_ty = GetDefType(def); if (dynamic_cast(def)) { - if (ctx->exp().empty()) return base_ptr; + ir::Value* loaded_base = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp()); + if (ctx->exp().empty()) return loaded_base; ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)), ir::Type::GetInt32Type()); @@ -628,7 +724,7 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) { module_.GetContext().NextTemp()); cur_ty = arr_ty->GetElementType(); } - return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset}, + return builder_.CreateGEP(ScalarPointerType(cur_ty), loaded_base, {offset}, module_.GetContext().NextTemp()); } diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index c9a7eba..b0f2c70 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include @@ -28,30 +29,84 @@ uint32_t GetTypeSize(const ir::Type* type) { return 4; } -uint32_t GetAllocaSize(const ir::Instruction& inst) { - auto type = inst.GetType(); - if (type->IsPtrInt32() || type->IsPtrFloat()) { - // Check if any StoreInst in the parent function stores a pointer to this alloca - auto* parent_bb = inst.GetParent(); - if (parent_bb) { - auto* parent_func = parent_bb->GetParent(); - if (parent_func) { - for (const auto& bbPtr : parent_func->GetBlocks()) { - for (const auto& other_inst : bbPtr->GetInstructions()) { - if (other_inst->GetOpcode() == ir::Opcode::Store) { - auto* store = static_cast(other_inst.get()); - if (store->GetPtr() == &inst) { - auto val_ty = store->GetValue()->GetType(); - if (val_ty->IsPtrInt32() || val_ty->IsPtrFloat()) { - return 8; // Stores a 64-bit pointer +std::unordered_set IdentifyPointerValues(const ir::Function& function) { + std::unordered_set pointers; + + // 1. Arguments that are pointers + for (const auto& arg : function.GetArguments()) { + if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) { + pointers.insert(arg.get()); + } + } + + // 2. Alloca instructions that store a pointer argument + for (const auto& bbPtr : function.GetBlocks()) { + for (const auto& instPtr : bbPtr->GetInstructions()) { + const auto* inst = instPtr.get(); + if (inst->GetOpcode() == ir::Opcode::Alloca) { + bool stores_ptr = false; + auto* parent_bb = inst->GetParent(); + if (parent_bb) { + auto* parent_func = parent_bb->GetParent(); + if (parent_func) { + for (const auto& other_bb : parent_func->GetBlocks()) { + for (const auto& other_inst : other_bb->GetInstructions()) { + if (other_inst->GetOpcode() == ir::Opcode::Store) { + auto* store = static_cast(other_inst.get()); + if (store->GetPtr() == inst) { + auto* val = store->GetValue(); + if (val->GetType()->IsPtrInt32() || val->GetType()->IsPtrFloat() || pointers.find(val) != pointers.end()) { + stores_ptr = true; + break; + } + } } } + if (stores_ptr) break; + } + } + } + if (stores_ptr) { + pointers.insert(inst); + } + } + } + } + + // 3. GEP instructions + for (const auto& bbPtr : function.GetBlocks()) { + for (const auto& instPtr : bbPtr->GetInstructions()) { + const auto* inst = instPtr.get(); + if (inst->GetOpcode() == ir::Opcode::GEP) { + pointers.insert(inst); + } + } + } + + // 4. Load instructions that load from those pointer-storing allocas + for (const auto& bbPtr : function.GetBlocks()) { + for (const auto& instPtr : bbPtr->GetInstructions()) { + const auto* inst = instPtr.get(); + if (inst->GetOpcode() == ir::Opcode::Load) { + auto* load = static_cast(inst); + if (pointers.find(load->GetPtr()) != pointers.end()) { + if (auto* alloca = dynamic_cast(load->GetPtr())) { + if (alloca->GetOpcode() == ir::Opcode::Alloca) { + pointers.insert(inst); } } } } } - return 4; + } + + return pointers; +} + +uint32_t GetAllocaSize(const ir::Instruction& inst, const std::unordered_set& pointers) { + auto type = inst.GetType(); + if (pointers.find(&inst) != pointers.end()) { + return 8; // Stores a 64-bit pointer } return GetTypeSize(type.get()); } @@ -127,10 +182,11 @@ void EmitValueToReg(const ir::Value* value, PhysReg target, } void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, - ValueSlotMap& slots, MachineBasicBlock& block) { + ValueSlotMap& slots, MachineBasicBlock& block, + const std::unordered_set& pointers) { switch (inst.GetOpcode()) { case ir::Opcode::Alloca: { - slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst))); + slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst, pointers))); return; } case ir::Opcode::Store: { @@ -140,8 +196,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, if (alloca->GetOpcode() == ir::Opcode::Alloca) { auto it = slots.find(alloca); if (it != slots.end()) { + bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() || + store.GetValue()->GetType()->IsPtrFloat() || + pointers.find(store.GetValue()) != pointers.end() || + store.GetValue()->IsGlobalValue(); PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : - (store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8; + is_ptr ? PhysReg::X8 : PhysReg::W8; EmitValueToReg(store.GetValue(), val_reg, slots, block); block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)}); return; @@ -150,8 +210,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } // Dynamic store + bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() || + store.GetValue()->GetType()->IsPtrFloat() || + pointers.find(store.GetValue()) != pointers.end() || + store.GetValue()->IsGlobalValue(); PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 : - (store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8; + is_ptr ? PhysReg::X8 : PhysReg::W8; EmitValueToReg(store.GetValue(), val_reg, slots, block); EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block); block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)}); @@ -159,7 +223,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } case ir::Opcode::Load: { auto& load = static_cast(inst); - int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get())); + bool is_ptr = load.GetType()->IsPtrInt32() || + load.GetType()->IsPtrFloat() || + pointers.find(&inst) != pointers.end(); + int dst_slot = function.CreateFrameIndex(is_ptr ? 8 : GetTypeSize(load.GetType().get())); slots.emplace(&inst, dst_slot); if (auto* alloca = dynamic_cast(load.GetPtr())) { @@ -167,7 +234,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, auto it = slots.find(alloca); if (it != slots.end()) { PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : - (load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8; + is_ptr ? PhysReg::X8 : PhysReg::W8; block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)}); block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)}); return; @@ -177,7 +244,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, // Dynamic load PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 : - (load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8; + is_ptr ? PhysReg::X8 : PhysReg::W8; EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block); block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)}); block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)}); @@ -342,8 +409,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, auto slot_it = slots.find(phi); if (slot_it != slots.end()) { int phi_slot = slot_it->second; + bool is_ptr = phi->GetType()->IsPtrInt32() || + phi->GetType()->IsPtrFloat() || + pointers.find(phi) != pointers.end() || + (incoming_val && (pointers.find(incoming_val) != pointers.end() || incoming_val->IsGlobalValue())); PhysReg val_reg = phi->GetType()->IsFloat() ? PhysReg::S8 : - (phi->GetType()->IsPtrInt32() || phi->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8; + is_ptr ? PhysReg::X8 : PhysReg::W8; EmitValueToReg(incoming_val, val_reg, slots, block); block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(phi_slot)}); } @@ -372,7 +443,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, case ir::Opcode::Ret: { auto& ret = static_cast(inst); if (ret.GetValue()) { - PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0; + bool is_ptr = ret.GetValue()->GetType()->IsPtrInt32() || + ret.GetValue()->GetType()->IsPtrFloat() || + pointers.find(ret.GetValue()) != pointers.end() || + ret.GetValue()->IsGlobalValue(); + PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : + is_ptr ? PhysReg::X0 : PhysReg::W0; EmitValueToReg(ret.GetValue(), ret_reg, slots, block); } block.Append(Opcode::Ret); @@ -380,9 +456,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, } case ir::Opcode::Call: { auto& call = static_cast(inst); + bool is_ret_ptr = call.GetType()->IsPtrInt32() || + call.GetType()->IsPtrFloat() || + pointers.find(&inst) != pointers.end(); int dst_slot = -1; if (!call.GetType()->IsVoid()) { - dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get())); + dst_slot = function.CreateFrameIndex(is_ret_ptr ? 8 : GetTypeSize(call.GetType().get())); slots.emplace(&inst, dst_slot); } @@ -395,7 +474,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, EmitValueToReg(arg, reg, slots, block); float_idx++; } else { - PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) + bool is_arg_ptr = arg->GetType()->IsPtrInt32() || + arg->GetType()->IsPtrFloat() || + pointers.find(arg) != pointers.end() || + arg->IsGlobalValue(); + PhysReg reg = is_arg_ptr ? static_cast(static_cast(PhysReg::X0) + int_idx) : static_cast(static_cast(PhysReg::W0) + int_idx); EmitValueToReg(arg, reg, slots, block); @@ -409,7 +492,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, if (call.GetType()->IsFloat()) { block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)}); } else { - PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0; + PhysReg ret_reg = is_ret_ptr ? PhysReg::X0 : PhysReg::W0; block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)}); } } @@ -438,11 +521,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function, auto* idx = gep.GetOperand(i); uint32_t stride = strides.at(i - 1); - // Skip if offset index is constant 0 if (auto* ci = dynamic_cast(idx)) { - if (ci->GetValue() == 0) { - continue; + int64_t offset = static_cast(ci->GetValue()) * stride; + if (offset != 0) { + block.Append(Opcode::AddRegImm, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Imm(offset)}); } + continue; } EmitValueToReg(idx, PhysReg::W9, slots, block); @@ -477,6 +561,7 @@ std::vector> LowerToMIR(const ir::Module& modul auto machine_func = std::make_unique(func.GetName()); ValueSlotMap slots; + auto pointers = IdentifyPointerValues(func); // First, create all basic blocks in MachineFunction std::unordered_map bb_map; @@ -490,7 +575,10 @@ std::vector> LowerToMIR(const ir::Module& modul for (const auto& bbPtr : func.GetBlocks()) { for (const auto& inst : bbPtr->GetInstructions()) { if (inst->GetOpcode() == ir::Opcode::Phi) { - int slot = machine_func->CreateFrameIndex(GetTypeSize(inst->GetType().get())); + bool is_phi_ptr = inst->GetType()->IsPtrInt32() || + inst->GetType()->IsPtrFloat() || + pointers.find(inst.get()) != pointers.end(); + int slot = machine_func->CreateFrameIndex(is_phi_ptr ? 8 : GetTypeSize(inst->GetType().get())); slots.emplace(inst.get(), slot); } } @@ -503,7 +591,10 @@ std::vector> LowerToMIR(const ir::Module& modul int int_idx = 0; int float_idx = 0; for (const auto& arg : args) { - int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get())); + bool is_arg_ptr = arg->GetType()->IsPtrInt32() || + arg->GetType()->IsPtrFloat() || + pointers.find(arg.get()) != pointers.end(); + int slot = machine_func->CreateFrameIndex(is_arg_ptr ? 8 : GetTypeSize(arg->GetType().get())); slots.emplace(arg.get(), slot); if (arg->GetType()->IsFloat()) { @@ -511,7 +602,7 @@ std::vector> LowerToMIR(const ir::Module& modul entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)}); float_idx++; } else { - PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) + PhysReg reg = is_arg_ptr ? static_cast(static_cast(PhysReg::X0) + int_idx) : static_cast(static_cast(PhysReg::W0) + int_idx); entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)}); @@ -523,7 +614,7 @@ std::vector> LowerToMIR(const ir::Module& modul for (const auto& bbPtr : func.GetBlocks()) { auto& mbb = *bb_map.at(bbPtr.get()); for (const auto& inst : bbPtr->GetInstructions()) { - LowerInstruction(*inst, *machine_func, slots, mbb); + LowerInstruction(*inst, *machine_func, slots, mbb, pointers); } } diff --git a/src/mir/passes/Peephole.cpp b/src/mir/passes/Peephole.cpp index 05e57dc..3b8418b 100644 --- a/src/mir/passes/Peephole.cpp +++ b/src/mir/passes/Peephole.cpp @@ -1,5 +1,6 @@ #include "mir/MIR.h" #include +#include #include namespace mir { @@ -99,10 +100,14 @@ void RunPeephole(MachineFunction& function) { } } - // 3. Track stores + // 3. Track and optimize stores if (op == Opcode::StoreStack) { PhysReg src = NormalizeReg(ops.at(0).GetReg()); int fi = ops.at(1).GetFrameIndex(); + auto it = slot_to_reg.find(fi); + if (it != slot_to_reg.end() && NormalizeReg(it->second) == src) { + continue; // Delete redundant store + } slot_to_reg[fi] = src; } @@ -180,6 +185,54 @@ void RunPeephole(MachineFunction& function) { insts = std::move(optimized); } + + // 5. Eliminate Dead Stack Slots (stores to slots that are never loaded or address-taken) + // Count loads and address-taken operations + std::unordered_map load_count; + std::unordered_map address_taken_count; + + for (const auto& block : function.GetBlocks()) { + for (const auto& inst : block.GetInstructions()) { + Opcode op = inst.GetOpcode(); + const auto& ops = inst.GetOperands(); + + for (const auto& opnd : ops) { + if (opnd.GetKind() == Operand::Kind::FrameIndex) { + int fi = opnd.GetFrameIndex(); + if (op == Opcode::LoadStack) { + load_count[fi]++; + } else if (op != Opcode::StoreStack) { + address_taken_count[fi]++; + } + } + } + } + } + + // Identify dead slots + std::unordered_set dead_slots; + for (size_t i = 0; i < function.GetFrameSlots().size(); ++i) { + int fi = static_cast(i); + if (load_count[fi] == 0 && address_taken_count[fi] == 0) { + dead_slots.insert(fi); + } + } + + // Remove StoreStack to dead slots + for (auto& block : function.GetBlocks()) { + auto& insts = block.GetInstructions(); + std::vector optimized; + for (const auto& inst : insts) { + if (inst.GetOpcode() == Opcode::StoreStack) { + int fi = inst.GetOperands().at(1).GetFrameIndex(); + if (dead_slots.find(fi) != dead_slots.end()) { + continue; // Delete this store + } + } + optimized.push_back(inst); + } + insts = std::move(optimized); + } } } // namespace mir diff --git a/src/sem/Sema.cpp b/src/sem/Sema.cpp index 5ac46ca..3e0a65c 100644 --- a/src/sem/Sema.cpp +++ b/src/sem/Sema.cpp @@ -211,67 +211,25 @@ class SemaVisitor final : public SysYBaseVisitor { return ctx->exp()->accept(this); } - std::any visitMulExp(SysYParser::MulExpContext* ctx) override { + std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override { ctx->exp(0)->accept(this); ctx->exp(1)->accept(this); return {}; } - std::any visitDivExp(SysYParser::DivExpContext* ctx) override { + std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override { ctx->exp(0)->accept(this); ctx->exp(1)->accept(this); return {}; } - std::any visitModExp(SysYParser::ModExpContext* ctx) override { + std::any visitRelExp(SysYParser::RelExpContext* ctx) override { ctx->exp(0)->accept(this); ctx->exp(1)->accept(this); return {}; } - std::any visitAddExp(SysYParser::AddExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitSubExp(SysYParser::SubExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitLtExp(SysYParser::LtExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitLeExp(SysYParser::LeExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitGtExp(SysYParser::GtExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitGeExp(SysYParser::GeExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitEqExp(SysYParser::EqExpContext* ctx) override { - ctx->exp(0)->accept(this); - ctx->exp(1)->accept(this); - return {}; - } - - std::any visitNeExp(SysYParser::NeExpContext* ctx) override { + std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override { ctx->exp(0)->accept(this); ctx->exp(1)->accept(this); return {};