Lab6: fix operator precedence and resolve 64-bit pointer propagation in AArch64 lowering

This commit is contained in:
2026-06-28 15:39:00 +08:00
committed by CGH0S7
parent 0e9e2dd345
commit d1edad08e6
7 changed files with 384 additions and 200 deletions

View File

@@ -140,76 +140,59 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) {
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
}
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs)));
}
bool is_mul = ctx->MUL() != nullptr;
bool is_div = ctx->DIV() != nullptr;
bool is_mod = ctx->MOD() != nullptr;
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs)));
if (is_mod) {
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs)));
}
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs)));
}
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs)));
}
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float lv = AsFloat(lhs);
const float rv = AsFloat(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(
rv == 0.0f ? 0.0f : AsFloat(lhs) / rv));
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(lv * rv));
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
}
const int lv = AsInt(lhs);
const int rv = AsInt(rhs);
return static_cast<ir::ConstantValue*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv));
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(lv * rv));
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
}
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override {
auto* lhs = Eval(*ctx->exp(0));
auto* rhs = Eval(*ctx->exp(1));
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
bool is_sub = ctx->SUB() != nullptr;
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
const float lv = AsFloat(lhs);
const float rv = AsFloat(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
}
const int lv = AsInt(lhs);
const int rv = AsInt(rhs);
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
}
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT);
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
ir::Opcode op = ir::Opcode::ICmpLT;
if (ctx->LT()) op = ir::Opcode::ICmpLT;
else if (ctx->LE()) op = ir::Opcode::ICmpLE;
else if (ctx->GT()) op = ir::Opcode::ICmpGT;
else if (ctx->GE()) op = ir::Opcode::ICmpGE;
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
}
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE);
}
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT);
}
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE);
}
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ);
}
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE);
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override {
ir::Opcode op = ir::Opcode::ICmpEQ;
if (ctx->EQ()) op = ir::Opcode::ICmpEQ;
else if (ctx->NE()) op = ir::Opcode::ICmpNE;
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
}
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
@@ -432,53 +415,165 @@ std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) {
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
}
DEFINE_ARITH_VISITOR(Add, Add, FAdd)
DEFINE_ARITH_VISITOR(Sub, Sub, FSub)
DEFINE_ARITH_VISITOR(Mul, Mul, FMul)
DEFINE_ARITH_VISITOR(Div, Div, FDiv)
std::any IRGenImpl::visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
bool is_mul = ctx->MUL() != nullptr;
bool is_div = ctx->DIV() != nullptr;
bool is_mod = ctx->MOD() != nullptr;
if (is_mod) {
lhs = CastValue(*this, builder_, module_, lhs, ir::Type::GetInt32Type());
rhs = CastValue(*this, builder_, module_, rhs, ir::Type::GetInt32Type());
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
}
}
return static_cast<ir::Value*>(builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) {
ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)),
ir::Type::GetInt32Type());
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
if (common_ty->IsFloat()) {
const float lv = AsFloat(lconst);
const float rv = AsFloat(rconst);
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv * rv));
else return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
}
const int lv = AsInt(lconst);
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv * rv));
else return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
}
}
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
if (common_ty->IsFloat()) {
if (is_mul) return static_cast<ir::Value*>(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp()));
}
if (is_mul) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Mul, lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Div, lhs, rhs, module_.GetContext().NextTemp()));
}
#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
const auto common_ty = CommonArithType(lhs, rhs); \
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \
: (AsInt(lconst) cmp_op AsInt(rconst)); \
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0)); \
} \
} \
if (common_ty->IsFloat()) { \
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
} \
return static_cast<ir::Value*>(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
std::any IRGenImpl::visitAddSubExp(SysYParser::AddSubExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
bool is_sub = ctx->SUB() != nullptr;
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
if (common_ty->IsFloat()) {
const float lv = AsFloat(lconst);
const float rv = AsFloat(rconst);
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
}
const int lv = AsInt(lconst);
const int rv = AsInt(rconst);
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
}
}
DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <)
DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=)
DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >)
DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=)
DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==)
DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=)
if (common_ty->IsFloat()) {
if (is_sub) return static_cast<ir::Value*>(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp()));
}
if (is_sub) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Sub, lhs, rhs, module_.GetContext().NextTemp()));
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
ir::Opcode int_op = ir::Opcode::ICmpLT;
ir::Opcode float_op = ir::Opcode::FCmpLT;
bool is_lt = ctx->LT() != nullptr;
bool is_le = ctx->LE() != nullptr;
bool is_gt = ctx->GT() != nullptr;
bool is_ge = ctx->GE() != nullptr;
if (is_lt) { int_op = ir::Opcode::ICmpLT; float_op = ir::Opcode::FCmpLT; }
else if (is_le) { int_op = ir::Opcode::ICmpLE; float_op = ir::Opcode::FCmpLE; }
else if (is_gt) { int_op = ir::Opcode::ICmpGT; float_op = ir::Opcode::FCmpGT; }
else if (is_ge) { int_op = ir::Opcode::ICmpGE; float_op = ir::Opcode::FCmpGE; }
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
bool result = false;
if (common_ty->IsFloat()) {
float lv = AsFloat(lconst);
float rv = AsFloat(rconst);
if (is_lt) result = lv < rv;
else if (is_le) result = lv <= rv;
else if (is_gt) result = lv > rv;
else if (is_ge) result = lv >= rv;
} else {
int lv = AsInt(lconst);
int rv = AsInt(rconst);
if (is_lt) result = lv < rv;
else if (is_le) result = lv <= rv;
else if (is_gt) result = lv > rv;
else if (is_ge) result = lv >= rv;
}
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
}
}
if (common_ty->IsFloat()) {
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitEqNeExp(SysYParser::EqNeExpContext* ctx) {
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
const auto common_ty = CommonArithType(lhs, rhs);
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
ir::Opcode int_op = ir::Opcode::ICmpEQ;
ir::Opcode float_op = ir::Opcode::FCmpEQ;
bool is_eq = ctx->EQ() != nullptr;
if (is_eq) { int_op = ir::Opcode::ICmpEQ; float_op = ir::Opcode::FCmpEQ; }
else { int_op = ir::Opcode::ICmpNE; float_op = ir::Opcode::FCmpNE; }
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
bool result = false;
if (common_ty->IsFloat()) {
float lv = AsFloat(lconst);
float rv = AsFloat(rconst);
result = is_eq ? (lv == rv) : (lv != rv);
} else {
int lv = AsInt(lconst);
int rv = AsInt(rconst);
result = is_eq ? (lv == rv) : (lv != rv);
}
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
}
}
if (common_ty->IsFloat()) {
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
if (!builder_.GetInsertBlock()) {
@@ -611,7 +706,8 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
const auto base_ty = GetDefType(def);
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
if (ctx->exp().empty()) return base_ptr;
ir::Value* loaded_base = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
if (ctx->exp().empty()) return loaded_base;
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
ir::Type::GetInt32Type());
@@ -628,7 +724,7 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
module_.GetContext().NextTemp());
cur_ty = arr_ty->GetElementType();
}
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset},
return builder_.CreateGEP(ScalarPointerType(cur_ty), loaded_base, {offset},
module_.GetContext().NextTemp());
}