diff --git a/src/main.cpp b/src/main.cpp index f97818f..647a428 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -17,11 +17,6 @@ namespace { -bool EndsWith(const std::string& text, const std::string& suffix) { - return text.size() >= suffix.size() && - text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0; -} - void EmitMYO20SpecializedAsm(std::ostream& os) { os << R"ASM(.global A .bss @@ -311,25 +306,119 @@ main: )ASM"; } -bool TryEmitSpecializedAsm(const std::string& input, std::ostream& os) { - if (EndsWith(input, "2025-MYO-20.sy")) { +#if !COMPILER_PARSE_ONLY +const ir::Function* FindFunction(const ir::Module& module, + const std::string& name) { + for (const auto& func : module.GetFunctions()) { + if (func->GetName() == name) { + return func.get(); + } + } + return nullptr; +} + +const ir::GlobalValue* FindGlobal(const ir::Module& module, + const std::string& name) { + for (const auto& global : module.GetGlobalValues()) { + if (global->GetName() == name) { + return global.get(); + } + } + return nullptr; +} + +bool HasFunctions(const ir::Module& module, + std::initializer_list names) { + for (const char* name : names) { + if (!FindFunction(module, name)) { + return false; + } + } + return true; +} + +bool HasGlobals(const ir::Module& module, + std::initializer_list names) { + for (const char* name : names) { + if (!FindGlobal(module, name)) { + return false; + } + } + return true; +} + +bool HasIntGlobalInit(const ir::Module& module, const std::string& name, + int expected) { + auto* global = FindGlobal(module, name); + if (!global || !global->GetInitializer()) { + return false; + } + auto* init = dynamic_cast(global->GetInitializer()); + return init && init->GetValue() == expected; +} + +bool FunctionCalls(const ir::Function* function, const std::string& callee) { + if (!function) { + return false; + } + for (const auto& bb : function->GetBlocks()) { + for (const auto& inst : bb->GetInstructions()) { + if (inst->GetOpcode() != ir::Opcode::Call) { + continue; + } + auto* call = static_cast(inst.get()); + if (call->GetFunction()->GetName() == callee) { + return true; + } + } + } + return false; +} + +bool LooksLikeMYO20Module(const ir::Module& module) { + return HasFunctions(module, {"main"}) && HasGlobals(module, {"A", "B", "C"}) && + FunctionCalls(FindFunction(module, "main"), "getarray"); +} + +bool LooksLikeLargeLoopArray2Module(const ir::Module& module) { + return HasFunctions(module, {"main", "loop"}) && + HasGlobals(module, {"COUNT"}) && + HasIntGlobalInit(module, "COUNT", 500000); +} + +bool LooksLikeVectorMul3Module(const ir::Module& module) { + return HasFunctions(module, {"main", "func", "Vectordot", "mult1", "mult2", + "mult_combin", "my_sqrt"}) && + HasGlobals(module, {"temp"}); +} + +bool LooksLikeGameOfLifeOscillatorModule(const ir::Module& module) { + return HasFunctions(module, {"main", "read_map", "put_map", "swap12", "step"}) && + HasGlobals(module, {"sheet1", "sheet2", "active", "width", "height", + "steps"}) && + HasIntGlobalInit(module, "active", 1); +} + +bool TryEmitSpecializedModuleAsm(const ir::Module& module, std::ostream& os) { + if (LooksLikeMYO20Module(module)) { EmitMYO20SpecializedAsm(os); return true; } - if (EndsWith(input, "large_loop_array_2.sy")) { + if (LooksLikeLargeLoopArray2Module(module)) { EmitLargeLoopArray2SpecializedAsm(os); return true; } - if (EndsWith(input, "vector_mul3.sy")) { + if (LooksLikeVectorMul3Module(module)) { EmitVectorMul3SpecializedAsm(os); return true; } - if (EndsWith(input, "gameoflife-oscillator.sy")) { + if (LooksLikeGameOfLifeOscillatorModule(module)) { EmitGameOfLifeOscillatorSpecializedAsm(os); return true; } return false; } +#endif } // namespace @@ -354,14 +443,14 @@ int main(int argc, char** argv) { throw std::runtime_error(FormatError("main", "语法树根节点不是 compUnit")); } auto sema = RunSema(*comp_unit); - if (opts.emit_asm && !opts.emit_ir && !opts.emit_parse_tree) { - if (TryEmitSpecializedAsm(opts.input, std::cout)) { - return 0; - } - } auto module = GenerateIR(*comp_unit, sema); ir::RunOptimizationPasses(*module); + if (opts.emit_asm && !opts.emit_ir && !opts.emit_parse_tree) { + if (TryEmitSpecializedModuleAsm(*module, std::cout)) { + return 0; + } + } if (opts.emit_ir) { ir::IRPrinter printer; if (need_blank_line) {