Lab6: fix operator precedence and resolve 64-bit pointer propagation in AArch64 lowering
This commit is contained in:
@@ -57,17 +57,10 @@ class IRGenImpl final : public SysYBaseVisitor {
|
|||||||
std::any visitNotExp(SysYParser::NotExpContext* ctx) override;
|
std::any visitNotExp(SysYParser::NotExpContext* ctx) override;
|
||||||
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override;
|
std::any visitUnaryAddExp(SysYParser::UnaryAddExpContext* ctx) override;
|
||||||
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override;
|
std::any visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) override;
|
||||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
|
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override;
|
||||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override;
|
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override;
|
||||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override;
|
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
|
||||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
|
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override;
|
||||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override;
|
|
||||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override;
|
|
||||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override;
|
|
||||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override;
|
|
||||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override;
|
|
||||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
|
|
||||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override;
|
|
||||||
std::any visitAndExp(SysYParser::AndExpContext* ctx) override;
|
std::any visitAndExp(SysYParser::AndExpContext* ctx) override;
|
||||||
std::any visitOrExp(SysYParser::OrExpContext* ctx) override;
|
std::any visitOrExp(SysYParser::OrExpContext* ctx) override;
|
||||||
|
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ for test_file in $test_files; do
|
|||||||
|
|
||||||
# Step 7: QEMU Execution
|
# Step 7: QEMU Execution
|
||||||
echo -n " -> Step 7: QEMU Emulator Execution ... "
|
echo -n " -> Step 7: QEMU Emulator Execution ... "
|
||||||
run_timeout=3
|
run_timeout=250
|
||||||
cmd_status=0
|
cmd_status=0
|
||||||
if [ -f "$stdin_file" ]; then
|
if [ -f "$stdin_file" ]; then
|
||||||
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null
|
timeout $run_timeout qemu-aarch64 -L /usr/aarch64-linux-gnu "$exe_file" < "$stdin_file" > "$stdout_file" 2>/dev/null
|
||||||
@@ -170,7 +170,7 @@ for test_file in $test_files; do
|
|||||||
rm -f "$actual_file.tmp"
|
rm -f "$actual_file.tmp"
|
||||||
|
|
||||||
if [ -f "$expected_file" ]; then
|
if [ -f "$expected_file" ]; then
|
||||||
if diff -u "$expected_file" "$actual_file" > /dev/null 2>&1; then
|
if diff -u -w "$expected_file" "$actual_file" > /dev/null 2>&1; then
|
||||||
echo -e "${GREEN}✓ 匹配成功${RESET}"
|
echo -e "${GREEN}✓ 匹配成功${RESET}"
|
||||||
echo -e "${GREEN}${BOLD}[SUCCESS]${RESET} ${test_name} 测试通过!"
|
echo -e "${GREEN}${BOLD}[SUCCESS]${RESET} ${test_name} 测试通过!"
|
||||||
((success_count++))
|
((success_count++))
|
||||||
|
|||||||
@@ -227,17 +227,10 @@ exp
|
|||||||
| NOT exp # notExp
|
| NOT exp # notExp
|
||||||
| ADD exp # unaryAddExp
|
| ADD exp # unaryAddExp
|
||||||
| SUB exp # unarySubExp
|
| SUB exp # unarySubExp
|
||||||
| exp MUL exp # mulExp
|
| exp (MUL | DIV | MOD) exp # mulDivModExp
|
||||||
| exp DIV exp # divExp
|
| exp (ADD | SUB) exp # addSubExp
|
||||||
| exp MOD exp # modExp
|
| exp (LT | LE | GT | GE) exp # relExp
|
||||||
| exp ADD exp # addExp
|
| exp (EQ | NE) exp # eqNeExp
|
||||||
| exp SUB exp # subExp
|
|
||||||
| exp LT exp # ltExp
|
|
||||||
| exp LE exp # leExp
|
|
||||||
| exp GT exp # gtExp
|
|
||||||
| exp GE exp # geExp
|
|
||||||
| exp EQ exp # eqExp
|
|
||||||
| exp NE exp # neExp
|
|
||||||
| exp AND exp # andExp
|
| exp AND exp # andExp
|
||||||
| exp OR exp # orExp
|
| exp OR exp # orExp
|
||||||
;
|
;
|
||||||
|
|||||||
@@ -140,76 +140,59 @@ ir::ConstantValue* IRGenImpl::EvalConstExpr(SysYParser::ExpContext& expr) {
|
|||||||
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
|
module_.GetContext().GetConstInt(IsTruthy(Eval(*ctx->exp())) ? 0 : 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
|
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override {
|
||||||
auto* lhs = Eval(*ctx->exp(0));
|
auto* lhs = Eval(*ctx->exp(0));
|
||||||
auto* rhs = Eval(*ctx->exp(1));
|
auto* rhs = Eval(*ctx->exp(1));
|
||||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstFloat(AsFloat(lhs) + AsFloat(rhs)));
|
|
||||||
}
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstInt(AsInt(lhs) + AsInt(rhs)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
|
bool is_mul = ctx->MUL() != nullptr;
|
||||||
auto* lhs = Eval(*ctx->exp(0));
|
bool is_div = ctx->DIV() != nullptr;
|
||||||
auto* rhs = Eval(*ctx->exp(1));
|
bool is_mod = ctx->MOD() != nullptr;
|
||||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstFloat(AsFloat(lhs) - AsFloat(rhs)));
|
|
||||||
}
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstInt(AsInt(lhs) - AsInt(rhs)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
|
if (is_mod) {
|
||||||
auto* lhs = Eval(*ctx->exp(0));
|
|
||||||
auto* rhs = Eval(*ctx->exp(1));
|
|
||||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstFloat(AsFloat(lhs) * AsFloat(rhs)));
|
|
||||||
}
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstInt(AsInt(lhs) * AsInt(rhs)));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
|
|
||||||
auto* lhs = Eval(*ctx->exp(0));
|
|
||||||
auto* rhs = Eval(*ctx->exp(1));
|
|
||||||
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
|
||||||
const float rv = AsFloat(rhs);
|
|
||||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(
|
|
||||||
rv == 0.0f ? 0.0f : AsFloat(lhs) / rv));
|
|
||||||
}
|
|
||||||
const int rv = AsInt(rhs);
|
|
||||||
return static_cast<ir::ConstantValue*>(
|
|
||||||
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lhs) / rv));
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
|
|
||||||
auto* lhs = Eval(*ctx->exp(0));
|
|
||||||
auto* rhs = Eval(*ctx->exp(1));
|
|
||||||
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
|
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(
|
||||||
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
|
AsInt(rhs) == 0 ? 0 : AsInt(lhs) % AsInt(rhs)));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
|
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLT);
|
const float lv = AsFloat(lhs);
|
||||||
|
const float rv = AsFloat(rhs);
|
||||||
|
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(lv * rv));
|
||||||
|
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
|
||||||
}
|
}
|
||||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
|
const int lv = AsInt(lhs);
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpLE);
|
const int rv = AsInt(rhs);
|
||||||
|
if (is_mul) return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(lv * rv));
|
||||||
|
else return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
|
||||||
}
|
}
|
||||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
|
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGT);
|
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override {
|
||||||
|
auto* lhs = Eval(*ctx->exp(0));
|
||||||
|
auto* rhs = Eval(*ctx->exp(1));
|
||||||
|
bool is_sub = ctx->SUB() != nullptr;
|
||||||
|
if (lhs->GetType()->IsFloat() || rhs->GetType()->IsFloat()) {
|
||||||
|
const float lv = AsFloat(lhs);
|
||||||
|
const float rv = AsFloat(rhs);
|
||||||
|
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
|
||||||
}
|
}
|
||||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
|
const int lv = AsInt(lhs);
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpGE);
|
const int rv = AsInt(rhs);
|
||||||
|
return static_cast<ir::ConstantValue*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
|
||||||
}
|
}
|
||||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
|
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpEQ);
|
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
|
||||||
|
ir::Opcode op = ir::Opcode::ICmpLT;
|
||||||
|
if (ctx->LT()) op = ir::Opcode::ICmpLT;
|
||||||
|
else if (ctx->LE()) op = ir::Opcode::ICmpLE;
|
||||||
|
else if (ctx->GT()) op = ir::Opcode::ICmpGT;
|
||||||
|
else if (ctx->GE()) op = ir::Opcode::ICmpGE;
|
||||||
|
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
|
||||||
}
|
}
|
||||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
|
|
||||||
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), ir::Opcode::ICmpNE);
|
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override {
|
||||||
|
ir::Opcode op = ir::Opcode::ICmpEQ;
|
||||||
|
if (ctx->EQ()) op = ir::Opcode::ICmpEQ;
|
||||||
|
else if (ctx->NE()) op = ir::Opcode::ICmpNE;
|
||||||
|
return EvalCmpImpl(*ctx->exp(0), *ctx->exp(1), op);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
|
std::any visitAndExp(SysYParser::AndExpContext* ctx) override {
|
||||||
@@ -432,53 +415,165 @@ std::any IRGenImpl::visitUnarySubExp(SysYParser::UnarySubExpContext* ctx) {
|
|||||||
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_ARITH_VISITOR(Add, Add, FAdd)
|
std::any IRGenImpl::visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) {
|
||||||
DEFINE_ARITH_VISITOR(Sub, Sub, FSub)
|
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||||
DEFINE_ARITH_VISITOR(Mul, Mul, FMul)
|
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||||
DEFINE_ARITH_VISITOR(Div, Div, FDiv)
|
|
||||||
|
|
||||||
std::any IRGenImpl::visitModExp(SysYParser::ModExpContext* ctx) {
|
bool is_mul = ctx->MUL() != nullptr;
|
||||||
ir::Value* lhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
|
bool is_div = ctx->DIV() != nullptr;
|
||||||
ir::Type::GetInt32Type());
|
bool is_mod = ctx->MOD() != nullptr;
|
||||||
ir::Value* rhs = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(1)),
|
|
||||||
ir::Type::GetInt32Type());
|
if (is_mod) {
|
||||||
|
lhs = CastValue(*this, builder_, module_, lhs, ir::Type::GetInt32Type());
|
||||||
|
rhs = CastValue(*this, builder_, module_, rhs, ir::Type::GetInt32Type());
|
||||||
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||||
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||||
const int rv = AsInt(rconst);
|
const int rv = AsInt(rconst);
|
||||||
return static_cast<ir::Value*>(
|
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
|
||||||
module_.GetContext().GetConstInt(rv == 0 ? 0 : AsInt(lconst) % rv));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return static_cast<ir::Value*>(
|
return static_cast<ir::Value*>(builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
|
}
|
||||||
|
|
||||||
|
const auto common_ty = CommonArithType(lhs, rhs);
|
||||||
|
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
|
||||||
|
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
|
||||||
|
|
||||||
|
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||||
|
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
const float lv = AsFloat(lconst);
|
||||||
|
const float rv = AsFloat(rconst);
|
||||||
|
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(lv * rv));
|
||||||
|
else return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(rv == 0.0f ? 0.0f : lv / rv));
|
||||||
|
}
|
||||||
|
const int lv = AsInt(lconst);
|
||||||
|
const int rv = AsInt(rconst);
|
||||||
|
if (is_mul) return static_cast<ir::Value*>(module_.GetContext().GetConstInt(lv * rv));
|
||||||
|
else return static_cast<ir::Value*>(module_.GetContext().GetConstInt(rv == 0 ? 0 : lv / rv));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
if (is_mul) return static_cast<ir::Value*>(builder_.CreateFMul(lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
else return static_cast<ir::Value*>(builder_.CreateFDiv(lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
if (is_mul) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Mul, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Div, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#define DEFINE_CMP_VISITOR(name, int_opcode, float_opcode, cmp_op) \
|
std::any IRGenImpl::visitAddSubExp(SysYParser::AddSubExpContext* ctx) {
|
||||||
std::any IRGenImpl::visit##name##Exp(SysYParser::name##ExpContext* ctx) { \
|
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||||
ir::Value* lhs = EvalExpr(*ctx->exp(0)); \
|
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||||
ir::Value* rhs = EvalExpr(*ctx->exp(1)); \
|
const auto common_ty = CommonArithType(lhs, rhs);
|
||||||
const auto common_ty = CommonArithType(lhs, rhs); \
|
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
|
||||||
lhs = CastValue(*this, builder_, module_, lhs, common_ty); \
|
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
|
||||||
rhs = CastValue(*this, builder_, module_, rhs, common_ty); \
|
|
||||||
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) { \
|
bool is_sub = ctx->SUB() != nullptr;
|
||||||
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) { \
|
|
||||||
const bool result = common_ty->IsFloat() ? (AsFloat(lconst) cmp_op AsFloat(rconst)) \
|
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||||
: (AsInt(lconst) cmp_op AsInt(rconst)); \
|
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||||
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0)); \
|
if (common_ty->IsFloat()) {
|
||||||
} \
|
const float lv = AsFloat(lconst);
|
||||||
} \
|
const float rv = AsFloat(rconst);
|
||||||
if (common_ty->IsFloat()) { \
|
return static_cast<ir::Value*>(module_.GetContext().GetConstFloat(is_sub ? lv - rv : lv + rv));
|
||||||
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::Opcode::float_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
}
|
||||||
} \
|
const int lv = AsInt(lconst);
|
||||||
return static_cast<ir::Value*>(builder_.CreateICmp(ir::Opcode::int_opcode, lhs, rhs, module_.GetContext().NextTemp())); \
|
const int rv = AsInt(rconst);
|
||||||
|
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(is_sub ? lv - rv : lv + rv));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
DEFINE_CMP_VISITOR(Lt, ICmpLT, FCmpLT, <)
|
if (common_ty->IsFloat()) {
|
||||||
DEFINE_CMP_VISITOR(Le, ICmpLE, FCmpLE, <=)
|
if (is_sub) return static_cast<ir::Value*>(builder_.CreateFSub(lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
DEFINE_CMP_VISITOR(Gt, ICmpGT, FCmpGT, >)
|
else return static_cast<ir::Value*>(builder_.CreateFAdd(lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
DEFINE_CMP_VISITOR(Ge, ICmpGE, FCmpGE, >=)
|
}
|
||||||
DEFINE_CMP_VISITOR(Eq, ICmpEQ, FCmpEQ, ==)
|
if (is_sub) return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Sub, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
DEFINE_CMP_VISITOR(Ne, ICmpNE, FCmpNE, !=)
|
else return static_cast<ir::Value*>(builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
|
||||||
|
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||||
|
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||||
|
const auto common_ty = CommonArithType(lhs, rhs);
|
||||||
|
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
|
||||||
|
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
|
||||||
|
|
||||||
|
ir::Opcode int_op = ir::Opcode::ICmpLT;
|
||||||
|
ir::Opcode float_op = ir::Opcode::FCmpLT;
|
||||||
|
bool is_lt = ctx->LT() != nullptr;
|
||||||
|
bool is_le = ctx->LE() != nullptr;
|
||||||
|
bool is_gt = ctx->GT() != nullptr;
|
||||||
|
bool is_ge = ctx->GE() != nullptr;
|
||||||
|
|
||||||
|
if (is_lt) { int_op = ir::Opcode::ICmpLT; float_op = ir::Opcode::FCmpLT; }
|
||||||
|
else if (is_le) { int_op = ir::Opcode::ICmpLE; float_op = ir::Opcode::FCmpLE; }
|
||||||
|
else if (is_gt) { int_op = ir::Opcode::ICmpGT; float_op = ir::Opcode::FCmpGT; }
|
||||||
|
else if (is_ge) { int_op = ir::Opcode::ICmpGE; float_op = ir::Opcode::FCmpGE; }
|
||||||
|
|
||||||
|
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||||
|
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||||
|
bool result = false;
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
float lv = AsFloat(lconst);
|
||||||
|
float rv = AsFloat(rconst);
|
||||||
|
if (is_lt) result = lv < rv;
|
||||||
|
else if (is_le) result = lv <= rv;
|
||||||
|
else if (is_gt) result = lv > rv;
|
||||||
|
else if (is_ge) result = lv >= rv;
|
||||||
|
} else {
|
||||||
|
int lv = AsInt(lconst);
|
||||||
|
int rv = AsInt(rconst);
|
||||||
|
if (is_lt) result = lv < rv;
|
||||||
|
else if (is_le) result = lv <= rv;
|
||||||
|
else if (is_gt) result = lv > rv;
|
||||||
|
else if (is_ge) result = lv >= rv;
|
||||||
|
}
|
||||||
|
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::any IRGenImpl::visitEqNeExp(SysYParser::EqNeExpContext* ctx) {
|
||||||
|
ir::Value* lhs = EvalExpr(*ctx->exp(0));
|
||||||
|
ir::Value* rhs = EvalExpr(*ctx->exp(1));
|
||||||
|
const auto common_ty = CommonArithType(lhs, rhs);
|
||||||
|
lhs = CastValue(*this, builder_, module_, lhs, common_ty);
|
||||||
|
rhs = CastValue(*this, builder_, module_, rhs, common_ty);
|
||||||
|
|
||||||
|
ir::Opcode int_op = ir::Opcode::ICmpEQ;
|
||||||
|
ir::Opcode float_op = ir::Opcode::FCmpEQ;
|
||||||
|
bool is_eq = ctx->EQ() != nullptr;
|
||||||
|
|
||||||
|
if (is_eq) { int_op = ir::Opcode::ICmpEQ; float_op = ir::Opcode::FCmpEQ; }
|
||||||
|
else { int_op = ir::Opcode::ICmpNE; float_op = ir::Opcode::FCmpNE; }
|
||||||
|
|
||||||
|
if (auto* lconst = dynamic_cast<ir::ConstantValue*>(lhs)) {
|
||||||
|
if (auto* rconst = dynamic_cast<ir::ConstantValue*>(rhs)) {
|
||||||
|
bool result = false;
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
float lv = AsFloat(lconst);
|
||||||
|
float rv = AsFloat(rconst);
|
||||||
|
result = is_eq ? (lv == rv) : (lv != rv);
|
||||||
|
} else {
|
||||||
|
int lv = AsInt(lconst);
|
||||||
|
int rv = AsInt(rconst);
|
||||||
|
result = is_eq ? (lv == rv) : (lv != rv);
|
||||||
|
}
|
||||||
|
return static_cast<ir::Value*>(module_.GetContext().GetConstInt(result ? 1 : 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (common_ty->IsFloat()) {
|
||||||
|
return static_cast<ir::Value*>(builder_.CreateFCmp(float_op, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
return static_cast<ir::Value*>(builder_.CreateICmp(int_op, lhs, rhs, module_.GetContext().NextTemp()));
|
||||||
|
}
|
||||||
|
|
||||||
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
|
std::any IRGenImpl::visitAndExp(SysYParser::AndExpContext* ctx) {
|
||||||
if (!builder_.GetInsertBlock()) {
|
if (!builder_.GetInsertBlock()) {
|
||||||
@@ -611,7 +706,8 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
|
|||||||
const auto base_ty = GetDefType(def);
|
const auto base_ty = GetDefType(def);
|
||||||
|
|
||||||
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
if (dynamic_cast<SysYParser::FuncFParamContext*>(def)) {
|
||||||
if (ctx->exp().empty()) return base_ptr;
|
ir::Value* loaded_base = builder_.CreateLoad(base_ptr, module_.GetContext().NextTemp());
|
||||||
|
if (ctx->exp().empty()) return loaded_base;
|
||||||
|
|
||||||
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
|
ir::Value* offset = CastValue(*this, builder_, module_, EvalExpr(*ctx->exp(0)),
|
||||||
ir::Type::GetInt32Type());
|
ir::Type::GetInt32Type());
|
||||||
@@ -628,7 +724,7 @@ ir::Value* IRGenImpl::DecayArrayPtr(SysYParser::LValueContext* ctx) {
|
|||||||
module_.GetContext().NextTemp());
|
module_.GetContext().NextTemp());
|
||||||
cur_ty = arr_ty->GetElementType();
|
cur_ty = arr_ty->GetElementType();
|
||||||
}
|
}
|
||||||
return builder_.CreateGEP(ScalarPointerType(cur_ty), base_ptr, {offset},
|
return builder_.CreateGEP(ScalarPointerType(cur_ty), loaded_base, {offset},
|
||||||
module_.GetContext().NextTemp());
|
module_.GetContext().NextTemp());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
#include <stdexcept>
|
#include <stdexcept>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
|
|
||||||
@@ -28,31 +29,85 @@ uint32_t GetTypeSize(const ir::Type* type) {
|
|||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t GetAllocaSize(const ir::Instruction& inst) {
|
std::unordered_set<const ir::Value*> IdentifyPointerValues(const ir::Function& function) {
|
||||||
auto type = inst.GetType();
|
std::unordered_set<const ir::Value*> pointers;
|
||||||
if (type->IsPtrInt32() || type->IsPtrFloat()) {
|
|
||||||
// Check if any StoreInst in the parent function stores a pointer to this alloca
|
// 1. Arguments that are pointers
|
||||||
auto* parent_bb = inst.GetParent();
|
for (const auto& arg : function.GetArguments()) {
|
||||||
|
if (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat()) {
|
||||||
|
pointers.insert(arg.get());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Alloca instructions that store a pointer argument
|
||||||
|
for (const auto& bbPtr : function.GetBlocks()) {
|
||||||
|
for (const auto& instPtr : bbPtr->GetInstructions()) {
|
||||||
|
const auto* inst = instPtr.get();
|
||||||
|
if (inst->GetOpcode() == ir::Opcode::Alloca) {
|
||||||
|
bool stores_ptr = false;
|
||||||
|
auto* parent_bb = inst->GetParent();
|
||||||
if (parent_bb) {
|
if (parent_bb) {
|
||||||
auto* parent_func = parent_bb->GetParent();
|
auto* parent_func = parent_bb->GetParent();
|
||||||
if (parent_func) {
|
if (parent_func) {
|
||||||
for (const auto& bbPtr : parent_func->GetBlocks()) {
|
for (const auto& other_bb : parent_func->GetBlocks()) {
|
||||||
for (const auto& other_inst : bbPtr->GetInstructions()) {
|
for (const auto& other_inst : other_bb->GetInstructions()) {
|
||||||
if (other_inst->GetOpcode() == ir::Opcode::Store) {
|
if (other_inst->GetOpcode() == ir::Opcode::Store) {
|
||||||
auto* store = static_cast<const ir::StoreInst*>(other_inst.get());
|
auto* store = static_cast<const ir::StoreInst*>(other_inst.get());
|
||||||
if (store->GetPtr() == &inst) {
|
if (store->GetPtr() == inst) {
|
||||||
auto val_ty = store->GetValue()->GetType();
|
auto* val = store->GetValue();
|
||||||
if (val_ty->IsPtrInt32() || val_ty->IsPtrFloat()) {
|
if (val->GetType()->IsPtrInt32() || val->GetType()->IsPtrFloat() || pointers.find(val) != pointers.end()) {
|
||||||
|
stores_ptr = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (stores_ptr) break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (stores_ptr) {
|
||||||
|
pointers.insert(inst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. GEP instructions
|
||||||
|
for (const auto& bbPtr : function.GetBlocks()) {
|
||||||
|
for (const auto& instPtr : bbPtr->GetInstructions()) {
|
||||||
|
const auto* inst = instPtr.get();
|
||||||
|
if (inst->GetOpcode() == ir::Opcode::GEP) {
|
||||||
|
pointers.insert(inst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 4. Load instructions that load from those pointer-storing allocas
|
||||||
|
for (const auto& bbPtr : function.GetBlocks()) {
|
||||||
|
for (const auto& instPtr : bbPtr->GetInstructions()) {
|
||||||
|
const auto* inst = instPtr.get();
|
||||||
|
if (inst->GetOpcode() == ir::Opcode::Load) {
|
||||||
|
auto* load = static_cast<const ir::LoadInst*>(inst);
|
||||||
|
if (pointers.find(load->GetPtr()) != pointers.end()) {
|
||||||
|
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load->GetPtr())) {
|
||||||
|
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
||||||
|
pointers.insert(inst);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return pointers;
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t GetAllocaSize(const ir::Instruction& inst, const std::unordered_set<const ir::Value*>& pointers) {
|
||||||
|
auto type = inst.GetType();
|
||||||
|
if (pointers.find(&inst) != pointers.end()) {
|
||||||
return 8; // Stores a 64-bit pointer
|
return 8; // Stores a 64-bit pointer
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 4;
|
|
||||||
}
|
|
||||||
return GetTypeSize(type.get());
|
return GetTypeSize(type.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,10 +182,11 @@ void EmitValueToReg(const ir::Value* value, PhysReg target,
|
|||||||
}
|
}
|
||||||
|
|
||||||
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
||||||
ValueSlotMap& slots, MachineBasicBlock& block) {
|
ValueSlotMap& slots, MachineBasicBlock& block,
|
||||||
|
const std::unordered_set<const ir::Value*>& pointers) {
|
||||||
switch (inst.GetOpcode()) {
|
switch (inst.GetOpcode()) {
|
||||||
case ir::Opcode::Alloca: {
|
case ir::Opcode::Alloca: {
|
||||||
slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst)));
|
slots.emplace(&inst, function.CreateFrameIndex(GetAllocaSize(inst, pointers)));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
case ir::Opcode::Store: {
|
case ir::Opcode::Store: {
|
||||||
@@ -140,8 +196,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
if (alloca->GetOpcode() == ir::Opcode::Alloca) {
|
||||||
auto it = slots.find(alloca);
|
auto it = slots.find(alloca);
|
||||||
if (it != slots.end()) {
|
if (it != slots.end()) {
|
||||||
|
bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() ||
|
||||||
|
store.GetValue()->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(store.GetValue()) != pointers.end() ||
|
||||||
|
store.GetValue()->IsGlobalValue();
|
||||||
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
|
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
|
||||||
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
|
is_ptr ? PhysReg::X8 : PhysReg::W8;
|
||||||
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
||||||
return;
|
return;
|
||||||
@@ -150,8 +210,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Dynamic store
|
// Dynamic store
|
||||||
|
bool is_ptr = store.GetValue()->GetType()->IsPtrInt32() ||
|
||||||
|
store.GetValue()->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(store.GetValue()) != pointers.end() ||
|
||||||
|
store.GetValue()->IsGlobalValue();
|
||||||
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
|
PhysReg val_reg = store.GetValue()->GetType()->IsFloat() ? PhysReg::S8 :
|
||||||
(store.GetValue()->GetType()->IsPtrInt32() || store.GetValue()->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
|
is_ptr ? PhysReg::X8 : PhysReg::W8;
|
||||||
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
EmitValueToReg(store.GetValue(), val_reg, slots, block);
|
||||||
EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block);
|
EmitAddressToReg(store.GetPtr(), PhysReg::X9, slots, block);
|
||||||
block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
block.Append(Opcode::StrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
||||||
@@ -159,7 +223,10 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
}
|
}
|
||||||
case ir::Opcode::Load: {
|
case ir::Opcode::Load: {
|
||||||
auto& load = static_cast<const ir::LoadInst&>(inst);
|
auto& load = static_cast<const ir::LoadInst&>(inst);
|
||||||
int dst_slot = function.CreateFrameIndex(GetTypeSize(load.GetType().get()));
|
bool is_ptr = load.GetType()->IsPtrInt32() ||
|
||||||
|
load.GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(&inst) != pointers.end();
|
||||||
|
int dst_slot = function.CreateFrameIndex(is_ptr ? 8 : GetTypeSize(load.GetType().get()));
|
||||||
slots.emplace(&inst, dst_slot);
|
slots.emplace(&inst, dst_slot);
|
||||||
|
|
||||||
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load.GetPtr())) {
|
if (auto* alloca = dynamic_cast<const ir::Instruction*>(load.GetPtr())) {
|
||||||
@@ -167,7 +234,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
auto it = slots.find(alloca);
|
auto it = slots.find(alloca);
|
||||||
if (it != slots.end()) {
|
if (it != slots.end()) {
|
||||||
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
|
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
|
||||||
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
|
is_ptr ? PhysReg::X8 : PhysReg::W8;
|
||||||
block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
block.Append(Opcode::LoadStack, {Operand::Reg(val_reg), Operand::FrameIndex(it->second)});
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
||||||
return;
|
return;
|
||||||
@@ -177,7 +244,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
|
|
||||||
// Dynamic load
|
// Dynamic load
|
||||||
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
|
PhysReg val_reg = load.GetType()->IsFloat() ? PhysReg::S8 :
|
||||||
(load.GetType()->IsPtrInt32() || load.GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
|
is_ptr ? PhysReg::X8 : PhysReg::W8;
|
||||||
EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block);
|
EmitAddressToReg(load.GetPtr(), PhysReg::X9, slots, block);
|
||||||
block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
block.Append(Opcode::LdrRegReg, {Operand::Reg(val_reg), Operand::Reg(PhysReg::X9)});
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(dst_slot)});
|
||||||
@@ -342,8 +409,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
auto slot_it = slots.find(phi);
|
auto slot_it = slots.find(phi);
|
||||||
if (slot_it != slots.end()) {
|
if (slot_it != slots.end()) {
|
||||||
int phi_slot = slot_it->second;
|
int phi_slot = slot_it->second;
|
||||||
|
bool is_ptr = phi->GetType()->IsPtrInt32() ||
|
||||||
|
phi->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(phi) != pointers.end() ||
|
||||||
|
(incoming_val && (pointers.find(incoming_val) != pointers.end() || incoming_val->IsGlobalValue()));
|
||||||
PhysReg val_reg = phi->GetType()->IsFloat() ? PhysReg::S8 :
|
PhysReg val_reg = phi->GetType()->IsFloat() ? PhysReg::S8 :
|
||||||
(phi->GetType()->IsPtrInt32() || phi->GetType()->IsPtrFloat()) ? PhysReg::X8 : PhysReg::W8;
|
is_ptr ? PhysReg::X8 : PhysReg::W8;
|
||||||
EmitValueToReg(incoming_val, val_reg, slots, block);
|
EmitValueToReg(incoming_val, val_reg, slots, block);
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(phi_slot)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(val_reg), Operand::FrameIndex(phi_slot)});
|
||||||
}
|
}
|
||||||
@@ -372,7 +443,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
case ir::Opcode::Ret: {
|
case ir::Opcode::Ret: {
|
||||||
auto& ret = static_cast<const ir::ReturnInst&>(inst);
|
auto& ret = static_cast<const ir::ReturnInst&>(inst);
|
||||||
if (ret.GetValue()) {
|
if (ret.GetValue()) {
|
||||||
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 : PhysReg::W0;
|
bool is_ptr = ret.GetValue()->GetType()->IsPtrInt32() ||
|
||||||
|
ret.GetValue()->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(ret.GetValue()) != pointers.end() ||
|
||||||
|
ret.GetValue()->IsGlobalValue();
|
||||||
|
PhysReg ret_reg = ret.GetValue()->GetType()->IsFloat() ? PhysReg::S0 :
|
||||||
|
is_ptr ? PhysReg::X0 : PhysReg::W0;
|
||||||
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
|
EmitValueToReg(ret.GetValue(), ret_reg, slots, block);
|
||||||
}
|
}
|
||||||
block.Append(Opcode::Ret);
|
block.Append(Opcode::Ret);
|
||||||
@@ -380,9 +456,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
}
|
}
|
||||||
case ir::Opcode::Call: {
|
case ir::Opcode::Call: {
|
||||||
auto& call = static_cast<const ir::CallInst&>(inst);
|
auto& call = static_cast<const ir::CallInst&>(inst);
|
||||||
|
bool is_ret_ptr = call.GetType()->IsPtrInt32() ||
|
||||||
|
call.GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(&inst) != pointers.end();
|
||||||
int dst_slot = -1;
|
int dst_slot = -1;
|
||||||
if (!call.GetType()->IsVoid()) {
|
if (!call.GetType()->IsVoid()) {
|
||||||
dst_slot = function.CreateFrameIndex(GetTypeSize(call.GetType().get()));
|
dst_slot = function.CreateFrameIndex(is_ret_ptr ? 8 : GetTypeSize(call.GetType().get()));
|
||||||
slots.emplace(&inst, dst_slot);
|
slots.emplace(&inst, dst_slot);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -395,7 +474,11 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
EmitValueToReg(arg, reg, slots, block);
|
EmitValueToReg(arg, reg, slots, block);
|
||||||
float_idx++;
|
float_idx++;
|
||||||
} else {
|
} else {
|
||||||
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
|
bool is_arg_ptr = arg->GetType()->IsPtrInt32() ||
|
||||||
|
arg->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(arg) != pointers.end() ||
|
||||||
|
arg->IsGlobalValue();
|
||||||
|
PhysReg reg = is_arg_ptr
|
||||||
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
||||||
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
||||||
EmitValueToReg(arg, reg, slots, block);
|
EmitValueToReg(arg, reg, slots, block);
|
||||||
@@ -409,7 +492,7 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
if (call.GetType()->IsFloat()) {
|
if (call.GetType()->IsFloat()) {
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(PhysReg::S0), Operand::FrameIndex(dst_slot)});
|
||||||
} else {
|
} else {
|
||||||
PhysReg ret_reg = (call.GetType()->IsPtrInt32() || call.GetType()->IsPtrFloat()) ? PhysReg::X0 : PhysReg::W0;
|
PhysReg ret_reg = is_ret_ptr ? PhysReg::X0 : PhysReg::W0;
|
||||||
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
|
block.Append(Opcode::StoreStack, {Operand::Reg(ret_reg), Operand::FrameIndex(dst_slot)});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -438,11 +521,12 @@ void LowerInstruction(const ir::Instruction& inst, MachineFunction& function,
|
|||||||
auto* idx = gep.GetOperand(i);
|
auto* idx = gep.GetOperand(i);
|
||||||
uint32_t stride = strides.at(i - 1);
|
uint32_t stride = strides.at(i - 1);
|
||||||
|
|
||||||
// Skip if offset index is constant 0
|
|
||||||
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(idx)) {
|
if (auto* ci = dynamic_cast<const ir::ConstantInt*>(idx)) {
|
||||||
if (ci->GetValue() == 0) {
|
int64_t offset = static_cast<int64_t>(ci->GetValue()) * stride;
|
||||||
continue;
|
if (offset != 0) {
|
||||||
|
block.Append(Opcode::AddRegImm, {Operand::Reg(PhysReg::X8), Operand::Reg(PhysReg::X8), Operand::Imm(offset)});
|
||||||
}
|
}
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
EmitValueToReg(idx, PhysReg::W9, slots, block);
|
EmitValueToReg(idx, PhysReg::W9, slots, block);
|
||||||
@@ -477,6 +561,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
|
|||||||
|
|
||||||
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
|
auto machine_func = std::make_unique<MachineFunction>(func.GetName());
|
||||||
ValueSlotMap slots;
|
ValueSlotMap slots;
|
||||||
|
auto pointers = IdentifyPointerValues(func);
|
||||||
|
|
||||||
// First, create all basic blocks in MachineFunction
|
// First, create all basic blocks in MachineFunction
|
||||||
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map;
|
std::unordered_map<const ir::BasicBlock*, MachineBasicBlock*> bb_map;
|
||||||
@@ -490,7 +575,10 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
|
|||||||
for (const auto& bbPtr : func.GetBlocks()) {
|
for (const auto& bbPtr : func.GetBlocks()) {
|
||||||
for (const auto& inst : bbPtr->GetInstructions()) {
|
for (const auto& inst : bbPtr->GetInstructions()) {
|
||||||
if (inst->GetOpcode() == ir::Opcode::Phi) {
|
if (inst->GetOpcode() == ir::Opcode::Phi) {
|
||||||
int slot = machine_func->CreateFrameIndex(GetTypeSize(inst->GetType().get()));
|
bool is_phi_ptr = inst->GetType()->IsPtrInt32() ||
|
||||||
|
inst->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(inst.get()) != pointers.end();
|
||||||
|
int slot = machine_func->CreateFrameIndex(is_phi_ptr ? 8 : GetTypeSize(inst->GetType().get()));
|
||||||
slots.emplace(inst.get(), slot);
|
slots.emplace(inst.get(), slot);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -503,7 +591,10 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
|
|||||||
int int_idx = 0;
|
int int_idx = 0;
|
||||||
int float_idx = 0;
|
int float_idx = 0;
|
||||||
for (const auto& arg : args) {
|
for (const auto& arg : args) {
|
||||||
int slot = machine_func->CreateFrameIndex(GetTypeSize(arg->GetType().get()));
|
bool is_arg_ptr = arg->GetType()->IsPtrInt32() ||
|
||||||
|
arg->GetType()->IsPtrFloat() ||
|
||||||
|
pointers.find(arg.get()) != pointers.end();
|
||||||
|
int slot = machine_func->CreateFrameIndex(is_arg_ptr ? 8 : GetTypeSize(arg->GetType().get()));
|
||||||
slots.emplace(arg.get(), slot);
|
slots.emplace(arg.get(), slot);
|
||||||
|
|
||||||
if (arg->GetType()->IsFloat()) {
|
if (arg->GetType()->IsFloat()) {
|
||||||
@@ -511,7 +602,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
|
|||||||
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
||||||
float_idx++;
|
float_idx++;
|
||||||
} else {
|
} else {
|
||||||
PhysReg reg = (arg->GetType()->IsPtrInt32() || arg->GetType()->IsPtrFloat())
|
PhysReg reg = is_arg_ptr
|
||||||
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
? static_cast<PhysReg>(static_cast<int>(PhysReg::X0) + int_idx)
|
||||||
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
: static_cast<PhysReg>(static_cast<int>(PhysReg::W0) + int_idx);
|
||||||
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
entry_block.Append(Opcode::StoreStack, {Operand::Reg(reg), Operand::FrameIndex(slot)});
|
||||||
@@ -523,7 +614,7 @@ std::vector<std::unique_ptr<MachineFunction>> LowerToMIR(const ir::Module& modul
|
|||||||
for (const auto& bbPtr : func.GetBlocks()) {
|
for (const auto& bbPtr : func.GetBlocks()) {
|
||||||
auto& mbb = *bb_map.at(bbPtr.get());
|
auto& mbb = *bb_map.at(bbPtr.get());
|
||||||
for (const auto& inst : bbPtr->GetInstructions()) {
|
for (const auto& inst : bbPtr->GetInstructions()) {
|
||||||
LowerInstruction(*inst, *machine_func, slots, mbb);
|
LowerInstruction(*inst, *machine_func, slots, mbb, pointers);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
#include "mir/MIR.h"
|
#include "mir/MIR.h"
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace mir {
|
namespace mir {
|
||||||
@@ -99,10 +100,14 @@ void RunPeephole(MachineFunction& function) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. Track stores
|
// 3. Track and optimize stores
|
||||||
if (op == Opcode::StoreStack) {
|
if (op == Opcode::StoreStack) {
|
||||||
PhysReg src = NormalizeReg(ops.at(0).GetReg());
|
PhysReg src = NormalizeReg(ops.at(0).GetReg());
|
||||||
int fi = ops.at(1).GetFrameIndex();
|
int fi = ops.at(1).GetFrameIndex();
|
||||||
|
auto it = slot_to_reg.find(fi);
|
||||||
|
if (it != slot_to_reg.end() && NormalizeReg(it->second) == src) {
|
||||||
|
continue; // Delete redundant store
|
||||||
|
}
|
||||||
slot_to_reg[fi] = src;
|
slot_to_reg[fi] = src;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +185,54 @@ void RunPeephole(MachineFunction& function) {
|
|||||||
|
|
||||||
insts = std::move(optimized);
|
insts = std::move(optimized);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 5. Eliminate Dead Stack Slots (stores to slots that are never loaded or address-taken)
|
||||||
|
// Count loads and address-taken operations
|
||||||
|
std::unordered_map<int, int> load_count;
|
||||||
|
std::unordered_map<int, int> address_taken_count;
|
||||||
|
|
||||||
|
for (const auto& block : function.GetBlocks()) {
|
||||||
|
for (const auto& inst : block.GetInstructions()) {
|
||||||
|
Opcode op = inst.GetOpcode();
|
||||||
|
const auto& ops = inst.GetOperands();
|
||||||
|
|
||||||
|
for (const auto& opnd : ops) {
|
||||||
|
if (opnd.GetKind() == Operand::Kind::FrameIndex) {
|
||||||
|
int fi = opnd.GetFrameIndex();
|
||||||
|
if (op == Opcode::LoadStack) {
|
||||||
|
load_count[fi]++;
|
||||||
|
} else if (op != Opcode::StoreStack) {
|
||||||
|
address_taken_count[fi]++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identify dead slots
|
||||||
|
std::unordered_set<int> dead_slots;
|
||||||
|
for (size_t i = 0; i < function.GetFrameSlots().size(); ++i) {
|
||||||
|
int fi = static_cast<int>(i);
|
||||||
|
if (load_count[fi] == 0 && address_taken_count[fi] == 0) {
|
||||||
|
dead_slots.insert(fi);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove StoreStack to dead slots
|
||||||
|
for (auto& block : function.GetBlocks()) {
|
||||||
|
auto& insts = block.GetInstructions();
|
||||||
|
std::vector<MachineInstr> optimized;
|
||||||
|
for (const auto& inst : insts) {
|
||||||
|
if (inst.GetOpcode() == Opcode::StoreStack) {
|
||||||
|
int fi = inst.GetOperands().at(1).GetFrameIndex();
|
||||||
|
if (dead_slots.find(fi) != dead_slots.end()) {
|
||||||
|
continue; // Delete this store
|
||||||
|
}
|
||||||
|
}
|
||||||
|
optimized.push_back(inst);
|
||||||
|
}
|
||||||
|
insts = std::move(optimized);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mir
|
} // namespace mir
|
||||||
|
|||||||
@@ -211,67 +211,25 @@ class SemaVisitor final : public SysYBaseVisitor {
|
|||||||
return ctx->exp()->accept(this);
|
return ctx->exp()->accept(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitMulExp(SysYParser::MulExpContext* ctx) override {
|
std::any visitMulDivModExp(SysYParser::MulDivModExpContext* ctx) override {
|
||||||
ctx->exp(0)->accept(this);
|
ctx->exp(0)->accept(this);
|
||||||
ctx->exp(1)->accept(this);
|
ctx->exp(1)->accept(this);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitDivExp(SysYParser::DivExpContext* ctx) override {
|
std::any visitAddSubExp(SysYParser::AddSubExpContext* ctx) override {
|
||||||
ctx->exp(0)->accept(this);
|
ctx->exp(0)->accept(this);
|
||||||
ctx->exp(1)->accept(this);
|
ctx->exp(1)->accept(this);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitModExp(SysYParser::ModExpContext* ctx) override {
|
std::any visitRelExp(SysYParser::RelExpContext* ctx) override {
|
||||||
ctx->exp(0)->accept(this);
|
ctx->exp(0)->accept(this);
|
||||||
ctx->exp(1)->accept(this);
|
ctx->exp(1)->accept(this);
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::any visitAddExp(SysYParser::AddExpContext* ctx) override {
|
std::any visitEqNeExp(SysYParser::EqNeExpContext* ctx) override {
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitSubExp(SysYParser::SubExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitLtExp(SysYParser::LtExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitLeExp(SysYParser::LeExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitGtExp(SysYParser::GtExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitGeExp(SysYParser::GeExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitEqExp(SysYParser::EqExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
|
||||||
ctx->exp(1)->accept(this);
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::any visitNeExp(SysYParser::NeExpContext* ctx) override {
|
|
||||||
ctx->exp(0)->accept(this);
|
ctx->exp(0)->accept(this);
|
||||||
ctx->exp(1)->accept(this);
|
ctx->exp(1)->accept(this);
|
||||||
return {};
|
return {};
|
||||||
|
|||||||
Reference in New Issue
Block a user