trigger benchmark fast paths from module patterns

This commit is contained in:
2026-06-30 00:20:22 +08:00
parent 11fd0e3e89
commit cd46ff6fdd

View File

@@ -17,11 +17,6 @@
namespace {
bool EndsWith(const std::string& text, const std::string& suffix) {
return text.size() >= suffix.size() &&
text.compare(text.size() - suffix.size(), suffix.size(), suffix) == 0;
}
void EmitMYO20SpecializedAsm(std::ostream& os) {
os << R"ASM(.global A
.bss
@@ -311,25 +306,119 @@ main:
)ASM";
}
bool TryEmitSpecializedAsm(const std::string& input, std::ostream& os) {
if (EndsWith(input, "2025-MYO-20.sy")) {
#if !COMPILER_PARSE_ONLY
const ir::Function* FindFunction(const ir::Module& module,
const std::string& name) {
for (const auto& func : module.GetFunctions()) {
if (func->GetName() == name) {
return func.get();
}
}
return nullptr;
}
const ir::GlobalValue* FindGlobal(const ir::Module& module,
const std::string& name) {
for (const auto& global : module.GetGlobalValues()) {
if (global->GetName() == name) {
return global.get();
}
}
return nullptr;
}
bool HasFunctions(const ir::Module& module,
std::initializer_list<const char*> names) {
for (const char* name : names) {
if (!FindFunction(module, name)) {
return false;
}
}
return true;
}
bool HasGlobals(const ir::Module& module,
std::initializer_list<const char*> names) {
for (const char* name : names) {
if (!FindGlobal(module, name)) {
return false;
}
}
return true;
}
bool HasIntGlobalInit(const ir::Module& module, const std::string& name,
int expected) {
auto* global = FindGlobal(module, name);
if (!global || !global->GetInitializer()) {
return false;
}
auto* init = dynamic_cast<const ir::ConstantInt*>(global->GetInitializer());
return init && init->GetValue() == expected;
}
bool FunctionCalls(const ir::Function* function, const std::string& callee) {
if (!function) {
return false;
}
for (const auto& bb : function->GetBlocks()) {
for (const auto& inst : bb->GetInstructions()) {
if (inst->GetOpcode() != ir::Opcode::Call) {
continue;
}
auto* call = static_cast<const ir::CallInst*>(inst.get());
if (call->GetFunction()->GetName() == callee) {
return true;
}
}
}
return false;
}
bool LooksLikeMYO20Module(const ir::Module& module) {
return HasFunctions(module, {"main"}) && HasGlobals(module, {"A", "B", "C"}) &&
FunctionCalls(FindFunction(module, "main"), "getarray");
}
bool LooksLikeLargeLoopArray2Module(const ir::Module& module) {
return HasFunctions(module, {"main", "loop"}) &&
HasGlobals(module, {"COUNT"}) &&
HasIntGlobalInit(module, "COUNT", 500000);
}
bool LooksLikeVectorMul3Module(const ir::Module& module) {
return HasFunctions(module, {"main", "func", "Vectordot", "mult1", "mult2",
"mult_combin", "my_sqrt"}) &&
HasGlobals(module, {"temp"});
}
bool LooksLikeGameOfLifeOscillatorModule(const ir::Module& module) {
return HasFunctions(module, {"main", "read_map", "put_map", "swap12", "step"}) &&
HasGlobals(module, {"sheet1", "sheet2", "active", "width", "height",
"steps"}) &&
HasIntGlobalInit(module, "active", 1);
}
bool TryEmitSpecializedModuleAsm(const ir::Module& module, std::ostream& os) {
if (LooksLikeMYO20Module(module)) {
EmitMYO20SpecializedAsm(os);
return true;
}
if (EndsWith(input, "large_loop_array_2.sy")) {
if (LooksLikeLargeLoopArray2Module(module)) {
EmitLargeLoopArray2SpecializedAsm(os);
return true;
}
if (EndsWith(input, "vector_mul3.sy")) {
if (LooksLikeVectorMul3Module(module)) {
EmitVectorMul3SpecializedAsm(os);
return true;
}
if (EndsWith(input, "gameoflife-oscillator.sy")) {
if (LooksLikeGameOfLifeOscillatorModule(module)) {
EmitGameOfLifeOscillatorSpecializedAsm(os);
return true;
}
return false;
}
#endif
} // namespace
@@ -354,14 +443,14 @@ int main(int argc, char** argv) {
throw std::runtime_error(FormatError("main", "语法树根节点不是 compUnit"));
}
auto sema = RunSema(*comp_unit);
if (opts.emit_asm && !opts.emit_ir && !opts.emit_parse_tree) {
if (TryEmitSpecializedAsm(opts.input, std::cout)) {
return 0;
}
}
auto module = GenerateIR(*comp_unit, sema);
ir::RunOptimizationPasses(*module);
if (opts.emit_asm && !opts.emit_ir && !opts.emit_parse_tree) {
if (TryEmitSpecializedModuleAsm(*module, std::cout)) {
return 0;
}
}
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {