"""Chrome trace profiler for the VidalBackend fast path.

Runs a single CPU expectation-value computation (MPS ansatz, torch tensor
module) under ``torch.profiler`` and writes two artifacts under
``profiles/``:

* ``<prefix>.json`` -- the Chrome trace exported by
  ``prof.export_chrome_trace``;
* ``<prefix>.txt``  -- the expectation value followed by two operator tables,
  one sorted by ``self_cpu_time_total`` and one by ``cpu_time_total``.

``<prefix>`` encodes nqubits, nlayers, bond and thread count, so runs that
differ only in ``--seed`` or ``--cut-ratio`` overwrite each other's output.
"""

from __future__ import annotations

import argparse
from pathlib import Path

import torch
from torch.profiler import ProfilerActivity, profile

from qibotn.benchmark_cases import build_circuit, terms_to_dict, observable_terms
from qibotn.expectation_runner import ExpectationConfig, run_cpu_expectation


def main(argv: list[str] | None = None) -> None:
    """Profile one expectation run and save the trace and summary tables.

    Args:
        argv: Arguments to parse instead of the process command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            which is the normal CLI behavior.
    """
    parser = argparse.ArgumentParser(
        description="Profile run_cpu_expectation and export a Chrome trace."
    )
    parser.add_argument("--nqubits", type=int, default=34,
                        help="number of qubits passed to build_circuit")
    parser.add_argument("--nlayers", type=int, default=20,
                        help="circuit depth passed to build_circuit")
    parser.add_argument("--bond", type=int, default=512,
                        help="bond dimension passed to ExpectationConfig")
    parser.add_argument("--seed", type=int, default=42,
                        help="seed passed to build_circuit")
    parser.add_argument("--torch-threads", type=int, default=32,
                        help="torch.set_num_threads value (also forwarded to "
                             "ExpectationConfig)")
    parser.add_argument("--cut-ratio", type=float, default=1e-12,
                        help="cut_ratio passed to ExpectationConfig")
    parser.add_argument("--profile-memory", action="store_true",
                        help="also record shapes, memory usage and stacks "
                             "(higher profiling overhead)")
    parser.add_argument("--rows", type=int, default=60,
                        help="row limit for each operator table")
    args = parser.parse_args(argv)

    torch.set_num_threads(args.torch_threads)

    prefix = (
        f"profiles/vidal_n{args.nqubits}_l{args.nlayers}"
        f"_b{args.bond}_t{args.torch_threads}"
    )
    trace_path = Path(f"{prefix}.json")
    table_path = Path(f"{prefix}.txt")
    # Both artifacts share one directory; create it before the (long) run so
    # a missing directory does not surface only after profiling finishes.
    trace_path.parent.mkdir(parents=True, exist_ok=True)

    circuit = build_circuit("brickwall_cnot", args.nqubits, args.nlayers, args.seed)
    observable = terms_to_dict(observable_terms("ring_xz", args.nqubits))
    config = ExpectationConfig(
        ansatz="mps",
        bond=args.bond,
        cut_ratio=args.cut_ratio,
        tensor_module="torch",
        torch_threads=args.torch_threads,
    )

    print(
        f"profile vidal nqubits={args.nqubits} nlayers={args.nlayers} "
        f"bond={args.bond} threads={args.torch_threads}"
    )

    # --profile-memory deliberately gates shape recording and stack capture
    # as well, so the default run keeps profiler overhead to a minimum.
    with profile(
        activities=[ProfilerActivity.CPU],
        record_shapes=args.profile_memory,
        profile_memory=args.profile_memory,
        with_stack=args.profile_memory,
    ) as prof:
        result = run_cpu_expectation(circuit, observable, config)

    # Aggregate the recorded events once and reuse the result for both
    # tables; key_averages() walks every event, so calling it twice doubles
    # that cost on large traces.
    averages = prof.key_averages()
    self_table = averages.table(sort_by="self_cpu_time_total", row_limit=args.rows)
    total_table = averages.table(sort_by="cpu_time_total", row_limit=args.rows)

    table = (
        f"expval={result.value:.16e}\n\n"
        "# sorted by self_cpu_time_total\n"
        f"{self_table}\n\n"
        "# sorted by cpu_time_total\n"
        f"{total_table}\n"
    )
    print(table, end="")
    table_path.write_text(table, encoding="utf-8")
    prof.export_chrome_trace(str(trace_path))
    print(f"trace={trace_path}\ntable={table_path}")


if __name__ == "__main__":
    main()