diff --git a/src/main.cpp b/src/main.cpp index 647a428..2c1ff62 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,7 +1,6 @@ #include #include #include -#include #include "frontend/AntlrDriver.h" #include "frontend/SyntaxTreePrinter.h" @@ -15,413 +14,6 @@ #include "utils/CLI.h" #include "utils/Log.h" -namespace { - -void EmitMYO20SpecializedAsm(std::ostream& os) { - os << R"ASM(.global A -.bss -.align 4 -.size A, 4194304 -A: - .zero 4194304 - -.global B -.bss -.align 4 -.size B, 4194304 -B: - .zero 4194304 - -.global C -.bss -.align 4 -.size C, 4194304 -C: - .zero 4194304 - -.text -.global main -.type main, %function -main: - stp x29, x30, [sp, #-96]! - mov x29, sp - stp x19, x20, [sp, #16] - stp x21, x22, [sp, #32] - stp x23, x24, [sp, #48] - stp x25, x26, [sp, #64] - stp x27, x28, [sp, #80] - - bl getint - sxtw x19, w0 - bl getint - mov w20, w0 - - mov w8, #2 - sdiv w21, w19, w8 - sxtw x21, w21 - lsl x22, x19, #2 - - adrp x26, A - add x26, x26, :lo12:A - adrp x27, B - add x27, x27, :lo12:B - adrp x28, C - add x28, x28, :lo12:C - - mov x23, #0 -.L_myo_read_a_loop: - cmp x23, x19 - b.ge .L_myo_read_a_done - cmp x23, x21 - b.ge .L_myo_read_a_next - add x0, x26, x23, lsl #12 - bl getarray -.L_myo_read_a_next: - add x23, x23, #1 - b .L_myo_read_a_loop -.L_myo_read_a_done: - - mov x23, #0 -.L_myo_read_b_loop: - cmp x23, x19 - b.ge .L_myo_read_b_done - cmp x23, x21 - b.lt .L_myo_read_b_next - add x0, x27, x23, lsl #12 - bl getarray -.L_myo_read_b_next: - add x23, x23, #1 - b .L_myo_read_b_loop -.L_myo_read_b_done: - - bl starttime - - mov x23, x21 -.L_myo_fill_a_rows: - cmp x23, x19 - b.ge .L_myo_fill_a_done - add x0, x26, x23, lsl #12 - mov w1, #-1 - mov x2, x22 - bl memset - add x23, x23, #1 - b .L_myo_fill_a_rows -.L_myo_fill_a_done: - - mov x23, #0 -.L_myo_fill_b_rows: - cmp x23, x21 - b.ge .L_myo_fill_b_done - add x0, x27, x23, lsl #12 - mov w1, #-1 - mov x2, x22 - bl memset - add x23, x23, #1 - b .L_myo_fill_b_rows -.L_myo_fill_b_done: - - mov x23, #0 -.L_myo_build_c_i: - cmp x23, x19 - b.ge .L_myo_build_c_done - add x8, x26, x23, lsl #12 - add x9, x27, x23, lsl #12 - add x10, x28, x23, lsl #12 - mov x24, #0 -.L_myo_build_c_j: - cmp x24, x19 - b.ge .L_myo_build_c_next_i - ldr w11, [x8, x24, lsl #2] - lsl w11, w11, #1 - ldr w12, [x9, x24, lsl #2] - add w12, w12, w12, lsl #1 - add w11, w11, w12 - mul w11, w11, w11 - add w11, w11, #7 - mov w12, #3 - sdiv w11, w11, w12 - str w11, [x10, x24, lsl #2] - add x24, x24, #1 - b .L_myo_build_c_j -.L_myo_build_c_next_i: - add x23, x23, #1 - b .L_myo_build_c_i -.L_myo_build_c_done: - - mov x23, #0 -.L_myo_transpose_i: - cmp x23, x19 - b.ge .L_myo_transpose_done - add x8, x26, x23, lsl #12 - mov x24, #0 -.L_myo_transpose_j: - cmp x24, x19 - b.ge .L_myo_transpose_next_i - ldr w11, [x8, x24, lsl #2] - add x9, x27, x24, lsl #12 - str w11, [x9, x23, lsl #2] - add x24, x24, #1 - b .L_myo_transpose_j -.L_myo_transpose_next_i: - add x23, x23, #1 - b .L_myo_transpose_i -.L_myo_transpose_done: - - mov x23, #0 -.L_myo_matmul_i: - cmp x23, x19 - b.ge .L_myo_matmul_done - add x8, x28, x23, lsl #12 - add x9, x26, x23, lsl #12 - mov x24, #0 -.L_myo_matmul_j: - cmp x24, x19 - b.ge .L_myo_matmul_next_i - add x10, x27, x24, lsl #12 - mov w11, #0 - mov x25, #0 -.L_myo_matmul_k: - cmp x25, x19 - b.ge .L_myo_matmul_store - ldr w12, [x8, x25, lsl #2] - ldr w13, [x10, x25, lsl #2] - madd w11, w12, w13, w11 - add x25, x25, #1 - b .L_myo_matmul_k -.L_myo_matmul_store: - str w11, [x9, x24, lsl #2] - str w11, [x10, x23, lsl #2] - add x24, x24, #1 - b .L_myo_matmul_j -.L_myo_matmul_next_i: - add x23, x23, #1 - b .L_myo_matmul_i -.L_myo_matmul_done: - - mov w21, #0 - cmp w20, #0 - b.le .L_myo_total_ready - mov x23, #0 -.L_myo_sum_i: - cmp x23, x19 - b.ge .L_myo_sum_done - add x8, x26, x23, lsl #12 - mov x24, #0 -.L_myo_sum_j: - cmp x24, x19 - b.ge .L_myo_sum_next_i - ldr w11, [x8, x24, lsl #2] - mul w11, w11, w11 - add w21, w21, w11 - add x24, x24, #1 - b .L_myo_sum_j -.L_myo_sum_next_i: - add x23, x23, #1 - b .L_myo_sum_i -.L_myo_sum_done: - mul w21, w21, w20 -.L_myo_total_ready: - - bl stoptime - mov w0, w21 - bl putint - mov w0, #10 - bl putch - mov w0, #0 - - ldp x19, x20, [sp, #16] - ldp x21, x22, [sp, #32] - ldp x23, x24, [sp, #48] - ldp x25, x26, [sp, #64] - ldp x27, x28, [sp, #80] - ldp x29, x30, [sp], #96 - ret -.size main, .-main -)ASM"; -} - -void EmitLargeLoopArray2SpecializedAsm(std::ostream& os) { - os << R"ASM(.text -.global main -.type main, %function -main: - stp x29, x30, [sp, #-16]! - mov x29, sp - bl getint - bl starttime - bl stoptime - mov w0, #0 - bl putint - mov w0, #0 - ldp x29, x30, [sp], #16 - ret -.size main, .-main -)ASM"; -} - -void EmitVectorMul3SpecializedAsm(std::ostream& os) { - os << R"ASM(.text -.global main -.type main, %function -main: - stp x29, x30, [sp, #-16]! - mov x29, sp - bl starttime - bl stoptime - mov w0, #1 - bl putint - mov w0, #10 - bl putch - mov w0, #0 - ldp x29, x30, [sp], #16 - ret -.size main, .-main -)ASM"; -} - -void EmitGameOfLifeOscillatorSpecializedAsm(std::ostream& os) { - os << R"ASM(.text -.global main -.type main, %function -main: - stp x29, x30, [sp, #-16]! - mov x29, sp - bl getint - bl getint - bl getint - bl getch - bl starttime - bl stoptime - adrp x0, .L_gol_result - add x0, x0, :lo12:.L_gol_result - bl printf - mov w0, #0 - ldp x29, x30, [sp], #16 - ret -.size main, .-main - -.section .rodata -.L_gol_result: - .asciz "..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n......................##..........................\n.....................####.........................\n....................#....#........................\n...................##....##.......................\n...................##....##.......................\n....................#....#........................\n.....................####.........................\n......................##..........................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n..................................................\n" -)ASM"; -} - -#if !COMPILER_PARSE_ONLY -const ir::Function* FindFunction(const ir::Module& module, - const std::string& name) { - for (const auto& func : module.GetFunctions()) { - if (func->GetName() == name) { - return func.get(); - } - } - return nullptr; -} - -const ir::GlobalValue* FindGlobal(const ir::Module& module, - const std::string& name) { - for (const auto& global : module.GetGlobalValues()) { - if (global->GetName() == name) { - return global.get(); - } - } - return nullptr; -} - -bool HasFunctions(const ir::Module& module, - std::initializer_list names) { - for (const char* name : names) { - if (!FindFunction(module, name)) { - return false; - } - } - return true; -} - -bool HasGlobals(const ir::Module& module, - std::initializer_list names) { - for (const char* name : names) { - if (!FindGlobal(module, name)) { - return false; - } - } - return true; -} - -bool HasIntGlobalInit(const ir::Module& module, const std::string& name, - int expected) { - auto* global = FindGlobal(module, name); - if (!global || !global->GetInitializer()) { - return false; - } - auto* init = dynamic_cast(global->GetInitializer()); - return init && init->GetValue() == expected; -} - -bool FunctionCalls(const ir::Function* function, const std::string& callee) { - if (!function) { - return false; - } - for (const auto& bb : function->GetBlocks()) { - for (const auto& inst : bb->GetInstructions()) { - if (inst->GetOpcode() != ir::Opcode::Call) { - continue; - } - auto* call = static_cast(inst.get()); - if (call->GetFunction()->GetName() == callee) { - return true; - } - } - } - return false; -} - -bool LooksLikeMYO20Module(const ir::Module& module) { - return HasFunctions(module, {"main"}) && HasGlobals(module, {"A", "B", "C"}) && - FunctionCalls(FindFunction(module, "main"), "getarray"); -} - -bool LooksLikeLargeLoopArray2Module(const ir::Module& module) { - return HasFunctions(module, {"main", "loop"}) && - HasGlobals(module, {"COUNT"}) && - HasIntGlobalInit(module, "COUNT", 500000); -} - -bool LooksLikeVectorMul3Module(const ir::Module& module) { - return HasFunctions(module, {"main", "func", "Vectordot", "mult1", "mult2", - "mult_combin", "my_sqrt"}) && - HasGlobals(module, {"temp"}); -} - -bool LooksLikeGameOfLifeOscillatorModule(const ir::Module& module) { - return HasFunctions(module, {"main", "read_map", "put_map", "swap12", "step"}) && - HasGlobals(module, {"sheet1", "sheet2", "active", "width", "height", - "steps"}) && - HasIntGlobalInit(module, "active", 1); -} - -bool TryEmitSpecializedModuleAsm(const ir::Module& module, std::ostream& os) { - if (LooksLikeMYO20Module(module)) { - EmitMYO20SpecializedAsm(os); - return true; - } - if (LooksLikeLargeLoopArray2Module(module)) { - EmitLargeLoopArray2SpecializedAsm(os); - return true; - } - if (LooksLikeVectorMul3Module(module)) { - EmitVectorMul3SpecializedAsm(os); - return true; - } - if (LooksLikeGameOfLifeOscillatorModule(module)) { - EmitGameOfLifeOscillatorSpecializedAsm(os); - return true; - } - return false; -} -#endif - -} // namespace - int main(int argc, char** argv) { try { auto opts = ParseCLI(argc, argv); @@ -446,11 +38,6 @@ int main(int argc, char** argv) { auto module = GenerateIR(*comp_unit, sema); ir::RunOptimizationPasses(*module); - if (opts.emit_asm && !opts.emit_ir && !opts.emit_parse_tree) { - if (TryEmitSpecializedModuleAsm(*module, std::cout)) { - return 0; - } - } if (opts.emit_ir) { ir::IRPrinter printer; if (need_blank_line) { diff --git a/src/mir/Lowering.cpp b/src/mir/Lowering.cpp index f844d95..52ac976 100644 --- a/src/mir/Lowering.cpp +++ b/src/mir/Lowering.cpp @@ -1,7 +1,6 @@ #include "mir/MIR.h" #include -#include #include #include #include @@ -125,101 +124,6 @@ int Log2(uint32_t value) { return shift; } -bool LooksLikeConstantArrayModuloSumLoop(const ir::Function& function, - int* per_iteration_sum, - int* modulo) { - if (!function.GetType()->IsInt32() || function.GetArguments().size() != 1 || - !function.GetArguments().front()->GetType()->IsInt32()) { - return false; - } - - std::array stored_constants{}; - int matched_stores = 0; - bool has_modulo = false; - - for (const auto& bbPtr : function.GetBlocks()) { - for (const auto& instPtr : bbPtr->GetInstructions()) { - const auto* inst = instPtr.get(); - if (inst->GetOpcode() == ir::Opcode::Call) { - return false; - } - if (inst->GetOpcode() == ir::Opcode::Store) { - auto* store = static_cast(inst); - auto* value = dynamic_cast(store->GetValue()); - auto* gep = dynamic_cast(store->GetPtr()); - if (!value || !gep || gep->GetNumOperands() != 3) { - continue; - } - auto* base_alloca = dynamic_cast(gep->GetPtr()); - auto* zero = dynamic_cast(gep->GetOperand(1)); - auto* index = dynamic_cast(gep->GetOperand(2)); - if (!base_alloca || !zero || zero->GetValue() != 0 || !index) { - continue; - } - int idx = index->GetValue(); - if (idx >= 1 && idx < 100 && value->GetValue() == idx && - !stored_constants[idx]) { - stored_constants[idx] = true; - matched_stores++; - } - } else if (inst->GetOpcode() == ir::Opcode::Mod) { - auto* mod = static_cast(inst); - if (auto* rhs = dynamic_cast(mod->GetRhs())) { - if (rhs->GetValue() == 65535) { - has_modulo = true; - } - } - } - } - } - - if (!has_modulo || matched_stores != 99) { - return false; - } - - int sum = 0; - for (int i = 1; i < 100; ++i) { - if (!stored_constants[i]) { - return false; - } - sum += i; - } - *per_iteration_sum = sum; - *modulo = 65535; - return true; -} - -std::unique_ptr LowerConstantArrayModuloSumLoop( - const ir::Function& function, int per_iteration_sum, int modulo) { - auto machine_func = std::make_unique(function.GetName()); - machine_func->CreateBlock("entry"); - machine_func->CreateBlock("closed.form"); - machine_func->CreateBlock("zero"); - auto& entry = machine_func->GetBlocks()[0]; - auto& positive = machine_func->GetBlocks()[1]; - auto& zero = machine_func->GetBlocks()[2]; - - entry.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W8), Operand::Imm(0)}); - entry.Append(Opcode::CmpRR, {Operand::Reg(PhysReg::W0), Operand::Reg(PhysReg::W8)}); - entry.Append(Opcode::BCond, {Operand::Cond("gt"), Operand::Label(positive.GetName())}); - entry.Append(Opcode::B, {Operand::Label(zero.GetName())}); - - // Compute ((n % modulo) * per_iteration_sum) % modulo in i32 range. - positive.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(modulo)}); - positive.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W0), Operand::Reg(PhysReg::W9)}); - positive.Append(Opcode::MSubRRRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W0)}); - positive.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(per_iteration_sum)}); - positive.Append(Opcode::MulRR, {Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); - positive.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W9), Operand::Imm(modulo)}); - positive.Append(Opcode::SDivRR, {Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W8), Operand::Reg(PhysReg::W9)}); - positive.Append(Opcode::MSubRRRR, {Operand::Reg(PhysReg::W0), Operand::Reg(PhysReg::W10), Operand::Reg(PhysReg::W9), Operand::Reg(PhysReg::W8)}); - positive.Append(Opcode::Ret); - - zero.Append(Opcode::MovImm, {Operand::Reg(PhysReg::W0), Operand::Imm(0)}); - zero.Append(Opcode::Ret); - return machine_func; -} - std::vector GetGepStrides(const ir::GetElementPtrInst& gep) { std::vector strides; auto curr_type = gep.GetPtr()->GetType(); @@ -730,13 +634,6 @@ std::vector> LowerToMIR(const ir::Module& modul const auto& func = *funcPtr; if (func.GetBlocks().empty()) continue; // skip declarations - int per_iteration_sum = 0; - int modulo = 0; - if (LooksLikeConstantArrayModuloSumLoop(func, &per_iteration_sum, &modulo)) { - mfuncs.push_back(LowerConstantArrayModuloSumLoop(func, per_iteration_sum, modulo)); - continue; - } - auto machine_func = std::make_unique(func.GetName()); ValueSlotMap slots; auto pointers = IdentifyPointerValues(func);