#!/usr/bin/env python """Run TN expectation for a user-provided circuit and observable. The case module should define: def build_circuit(nqubits, nlayers, seed): ... def build_observable(nqubits, seed): ... ``build_observable`` may return a Qibo SymbolicHamiltonian/form or the qibotn dict form: {"terms": [ {"coefficient": 1.0, "operators": [("X", 0), ("Z", 1)]}, ]} For a single repeated Pauli string, pass ``--pauli-pattern`` instead of defining ``build_observable``. """ from __future__ import annotations import argparse import importlib.util import inspect import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] SRC = ROOT / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) from qibotn.expectation_runner import ( # noqa: E402 ExpectationConfig, exact_for_observable, run_cpu_expectation, ) def optional_int(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return int(text) def optional_float(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return float(text) def load_module(path): path = Path(path).resolve() spec = importlib.util.spec_from_file_location(path.stem, path) if spec is None or spec.loader is None: raise RuntimeError(f"Cannot import case module from {path}.") module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def call_builder(fn, **kwargs): sig = inspect.signature(fn) if any(p.kind == p.VAR_KEYWORD for p in sig.parameters.values()): return fn(**kwargs) accepted = { name: value for name, value in kwargs.items() if name in sig.parameters } return fn(**accepted) def load_observable(args, module): if args.pauli_pattern: return {"pauli_string_pattern": args.pauli_pattern} if args.observable_json: with Path(args.observable_json).open() as f: return json.load(f) if hasattr(module, "build_observable"): return call_builder( module.build_observable, nqubits=args.nqubits, nlayers=args.nlayers, seed=args.seed, ) if hasattr(module, "OBSERVABLE"): return module.OBSERVABLE raise ValueError( "No observable supplied. Define build_observable/OBSERVABLE in the case " "module, or pass --pauli-pattern / --observable-json." ) def build_parallel_opts(args): slicing_opts = {} if args.tn_target_slices is not None: slicing_opts["target_slices"] = args.tn_target_slices if args.tn_target_size is not None: slicing_opts["target_size"] = args.tn_target_size opts = { "slicing_opts": slicing_opts or None, "search_workers": args.tn_search_workers or args.torch_threads, "max_repeats": args.tn_search_repeats, "max_time": args.tn_search_time, "print_stats": not args.no_tn_stats, } if args.tn_search_backend is not None: opts["search_backend"] = args.tn_search_backend if args.dask_address is not None: opts["dask_address"] = args.dask_address if args.dask_close_workers: opts["dask_close_workers"] = True if args.tn_save_tree is not None: opts["save_tree_path"] = args.tn_save_tree if args.tn_load_tree is not None: opts["load_tree_path"] = args.tn_load_tree if args.tn_search_only: opts["search_only"] = True return opts def main(): parser = argparse.ArgumentParser( description="Run CPU TN expectation for a custom qibo circuit module." ) parser.add_argument("case_module", help="Python file defining build_circuit.") parser.add_argument("--nqubits", type=int, required=True) parser.add_argument("--nlayers", type=int, default=0) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--mpi", action="store_true") parser.add_argument("--exact", action="store_true") parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--bond", "--bonds", dest="bond", type=optional_int, default=1024) parser.add_argument("--cut-ratio", type=optional_float, default=1e-12) parser.add_argument("--torch-threads", type=int, default=8) parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch") parser.add_argument("--dtype", choices=("complex128", "complex64"), default="complex128") parser.add_argument("--pauli-pattern") parser.add_argument("--observable-json") parser.add_argument("--tn-target-slices", type=int) parser.add_argument("--tn-target-size", type=int, default=2**32) parser.add_argument("--tn-search-workers", type=int) parser.add_argument("--tn-search-repeats", type=int, default=128) parser.add_argument("--tn-search-time", type=float, default=60.0) parser.add_argument("--tn-search-backend", choices=("processpool", "dask")) parser.add_argument("--dask-address") parser.add_argument("--dask-close-workers", action="store_true") parser.add_argument("--tn-save-tree") parser.add_argument("--tn-load-tree") parser.add_argument("--tn-search-only", action="store_true") parser.add_argument("--no-tn-stats", action="store_true") args = parser.parse_args() rank = 0 if args.mpi: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() module = load_module(args.case_module) if not hasattr(module, "build_circuit"): raise ValueError("case_module must define build_circuit.") circuit = call_builder( module.build_circuit, nqubits=args.nqubits, nlayers=args.nlayers, seed=args.seed, ) observable = load_observable(args, module) config = ExpectationConfig( ansatz="tn", mpi=args.mpi, bond=args.bond, cut_ratio=args.cut_ratio, tensor_module="torch", quimb_backend=args.quimb_backend, dtype=args.dtype, torch_threads=args.torch_threads, parallel_opts=build_parallel_opts(args), ) if rank == 0: mode = "MPI" if args.mpi else "serial" print( f"backend=cpu ansatz=TN mode={mode} case={Path(args.case_module).name} " f"nqubits={args.nqubits} nlayers={args.nlayers} seed={args.seed} " f"quimb_backend={args.quimb_backend} dtype={args.dtype} " f"torch_threads={args.torch_threads}", flush=True, ) print("observable exact value abs_error rel_error seconds", flush=True) exact = None if args.exact and rank == 0: if args.nqubits > args.exact_max_qubits: raise ValueError( f"--exact is limited to {args.exact_max_qubits} qubits by default." ) exact = exact_for_observable(circuit, observable, args.nqubits) result = run_cpu_expectation(circuit, observable, config) if args.mpi and result.rank != 0: return abs_error = float("nan") if exact is None else abs(result.value - exact) rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15) exact_text = "nan" if exact is None else f"{exact:.16e}" print( f"custom {exact_text} {result.value:.16e} " f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}", flush=True, ) for stat in result.parallel_stats or (): cost = stat["path_cost"] search_stats = stat.get("search_stats", {}) print( "tn_term_summary " f"term={stat.get('term_index', 0)} " f"search_seconds={stat.get('search_seconds', float('nan')):.3f} " f"contract_seconds={stat.get('contract_seconds', float('nan')):.3f} " f"completed_trials={search_stats.get('completed_trials', 'na')} " f"finite_trials={search_stats.get('finite_trials', 'na')} " f"failed_trials={search_stats.get('failed_trials', 'na')} " f"requested_trials={search_stats.get('requested_trials', 'na')} " f"best_score={search_stats.get('best_score', float('nan')):.6g} " f"slices={cost.get('slices')} " f"log10_flops={cost.get('log10_flops', float('nan')):.3f} " f"log10_write={cost.get('log10_write', float('nan')):.3f} " f"log2_size={cost.get('log2_size', float('nan')):.3f} " f"peak_memory_gib={cost.get('peak_memory_gib', float('nan')):.3g} " f"rank_slices={stat.get('rank_slices')}", flush=True, ) if __name__ == "__main__": main()