"""CLI for CPU TN/MPS expectation benchmarks.""" from __future__ import annotations import argparse from qibotn.benchmark_cases import ( CIRCUITS, OBSERVABLES, build_circuit, observable_terms, parse_names, terms_to_dict, ) from qibotn.expectation_runner import ( ExpectationConfig, exact_for_observable, run_cpu_expectation, ) def build_parallel_opts(args): slicing_opts = {} if args.tn_target_slices is not None: slicing_opts["target_slices"] = args.tn_target_slices if args.tn_target_size is not None: slicing_opts["target_size"] = args.tn_target_size opts = { "slicing_opts": slicing_opts or None, "search_workers": args.tn_search_workers or args.torch_threads, "max_repeats": args.tn_search_repeats, "max_time": args.tn_search_time, } if args.tn_search_backend is not None: opts["search_backend"] = args.tn_search_backend if args.dask_address is not None: opts["dask_address"] = args.dask_address return opts def main(): parser = argparse.ArgumentParser() parser.add_argument("--nqubits", type=int, default=40) parser.add_argument("--nlayers", type=int, default=30) parser.add_argument("--bond", "--bonds", dest="bond", type=int, default=1024) parser.add_argument("--cut-ratio", type=float, default=1e-12) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--torch-threads", type=int, default=8) parser.add_argument("--ansatz", choices=("tn", "mps"), default=None) parser.add_argument("--mps", action="store_true") parser.add_argument("--mpi", action="store_true") parser.add_argument("--exact", action="store_true") parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--circuits", nargs="+", default=["brickwall_cnot"]) parser.add_argument("--observables", nargs="+", default=["ring_xz"]) parser.add_argument("--pauli-pattern") parser.add_argument("--tn-target-slices", type=int) parser.add_argument("--tn-target-size", type=int) parser.add_argument("--tn-search-workers", type=int) parser.add_argument("--tn-search-repeats", type=int, default=128) parser.add_argument("--tn-search-time", type=float, default=60.0) parser.add_argument( "--tn-search-backend", choices=("processpool", "dask"), help="Path-search backend. In MPI mode, dask search runs only on rank 0 and broadcasts the tree.", ) parser.add_argument( "--dask-address", help="Dask scheduler address, for example tcp://host:8786. If omitted with dask search, a local cluster is created.", ) args = parser.parse_args() ansatz = "mps" if args.mps else (args.ansatz or "tn") circuits = parse_names(args.circuits, CIRCUITS, "circuits") observables = [] if args.pauli_pattern else parse_names( args.observables, OBSERVABLES, "observables" ) rank = 0 if args.mpi: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() config = ExpectationConfig( ansatz=ansatz, mpi=args.mpi, bond=args.bond, cut_ratio=args.cut_ratio, tensor_module="torch", torch_threads=args.torch_threads, parallel_opts=build_parallel_opts(args), ) if rank == 0: mode = "MPI" if args.mpi else "serial" print( f"backend=cpu ansatz={ansatz.upper()} mode={mode} " f"nqubits={args.nqubits} nlayers={args.nlayers} " f"bond={args.bond} cut_ratio={args.cut_ratio:g} seed={args.seed} " f"torch_threads={args.torch_threads} " f"tn_search_backend={args.tn_search_backend or 'processpool'}" ) print("circuit observable exact value abs_error rel_error seconds") for circuit_kind in circuits: circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed) named_observables = ( [(f"pattern:{args.pauli_pattern}", {"pauli_string_pattern": args.pauli_pattern})] if args.pauli_pattern else [ (obs_kind, terms_to_dict(observable_terms(obs_kind, args.nqubits))) for obs_kind in observables ] ) for obs_name, observable in named_observables: exact = None if args.exact and rank == 0: if args.nqubits > args.exact_max_qubits: raise ValueError( f"--exact is limited to {args.exact_max_qubits} qubits by default." ) exact = exact_for_observable(circuit, observable, args.nqubits) result = run_cpu_expectation(circuit, observable, config) if args.mpi and result.rank != 0: continue abs_error = float("nan") if exact is None else abs(result.value - exact) rel_error = ( float("nan") if exact is None else abs_error / max(abs(exact), 1e-15) ) exact_text = "nan" if exact is None else f"{exact:.16e}" print( f"{circuit_kind} {obs_name} {exact_text} {result.value:.16e} " f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}" ) if __name__ == "__main__": main()