#!/usr/bin/env python """Contest-style CPU TN path search and contraction runner. This file is intentionally self-contained: define contest circuits and observables here, run path search once, then load the saved trees for repeated MPI contractions. """ from __future__ import annotations import argparse import math import os import subprocess import sys from dataclasses import dataclass from pathlib import Path from urllib.parse import urlparse import numpy as np from qibo import Circuit, gates, hamiltonians from qibo.symbols import X, Y, Z ROOT = Path(__file__).resolve().parents[1] SRC = ROOT / "src" if str(SRC) not in sys.path: sys.path.insert(0, str(SRC)) from qibotn.expectation_runner import ( # noqa: E402 ExpectationConfig, exact_for_observable, run_cpu_expectation, ) @dataclass(frozen=True) class CaseSpec: circuit_kind: str observables: tuple[str, ...] nqubits: int nlayers: int seed: int target_slices: int | None = None CASES = { "main1": CaseSpec( circuit_kind="rxx_rzz_chain", observables=("ring_xz",), nqubits=34, nlayers=20, seed=31001, target_slices=None, ), "main2": CaseSpec( circuit_kind="scramble_chain", observables=("open_zz", "range2_xx"), nqubits=36, nlayers=18, seed=31002, target_slices=None, ), "strong": CaseSpec( circuit_kind="reversed_cnot", observables=("ring_xz", "long_z_string"), nqubits=40, nlayers=24, seed=41001, target_slices=None, ), } def optional_int(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return int(text) def optional_float(text): if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}: return None return float(text) def set_torch_threads(nthreads): try: import torch torch.set_num_threads(nthreads) except Exception: pass def add_single_qubit_layer(circuit, nqubits, rng, include_rx=False): for qubit in range(nqubits): circuit.add(gates.RY(qubit, theta=rng.uniform(-math.pi, math.pi))) circuit.add(gates.RZ(qubit, theta=rng.uniform(-math.pi, math.pi))) if include_rx: circuit.add(gates.RX(qubit, theta=rng.uniform(-math.pi, math.pi))) def build_circuit(kind, nqubits, nlayers, seed): """Define contest circuits here.""" rng = np.random.default_rng(seed) circuit = Circuit(nqubits) for layer in range(nlayers): if kind == "rxx_rzz_chain": add_single_qubit_layer(circuit, nqubits, rng, include_rx=True) for qubit in range(layer % 2, nqubits - 1, 2): circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9))) circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9))) elif kind == "scramble_chain": add_single_qubit_layer(circuit, nqubits, rng, include_rx=True) for qubit in range(layer % 2, nqubits - 1, 2): circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8))) circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8))) if layer % 5 == 4: circuit.add(gates.SWAP(qubit, qubit + 1)) elif kind == "reversed_cnot": add_single_qubit_layer(circuit, nqubits, rng) for qubit in range(0, nqubits - 1, 2): gate = gates.CNOT(qubit + 1, qubit) if layer % 2 else gates.CNOT(qubit, qubit + 1) circuit.add(gate) for qubit in range(1, nqubits - 1, 2): gate = gates.CNOT(qubit + 1, qubit) if layer % 2 == 0 else gates.CNOT(qubit, qubit + 1) circuit.add(gate) else: raise ValueError(f"Unknown circuit kind {kind!r}.") return circuit def pauli_sum_observable(kind, nqubits, seed): """Define contest observables here. TN path currently expects Pauli products / SymbolicHamiltonian terms. Keep production contest observables Hermitian unless complex output is explicitly required by the scoring rule. """ del seed if kind == "ring_xz": form = 0 for qubit in range(nqubits): form += 0.5 * X(qubit) * Z((qubit + 1) % nqubits) return hamiltonians.SymbolicHamiltonian(form=form) if kind == "open_zz": form = 0 for qubit in range(nqubits - 1): form += (1.0 / max(1, nqubits - 1)) * Z(qubit) * Z(qubit + 1) return hamiltonians.SymbolicHamiltonian(form=form) if kind == "range2_xx": form = 0 for qubit in range(nqubits - 2): form += (1.0 / max(1, nqubits - 2)) * X(qubit) * X(qubit + 2) return hamiltonians.SymbolicHamiltonian(form=form) if kind == "long_z_string": stride = max(1, nqubits // 16) form = None for qubit in range(0, nqubits, stride): form = Z(qubit) if form is None else form * Z(qubit) return hamiltonians.SymbolicHamiltonian(form=form) if kind == "mixed_local": q1 = nqubits // 4 q2 = nqubits // 2 q3 = (3 * nqubits) // 4 form = 0.25 * X(0) - 0.5 * Z(nqubits - 1) form += 0.125 * X(q1) * Z(q2) * Y(q3) return hamiltonians.SymbolicHamiltonian(form=form) raise ValueError(f"Unknown observable kind {kind!r}.") def tree_path(tree_dir, case_name, obs_name, nqubits, nlayers, target_slices): slice_label = "auto" if target_slices is None else f"s{target_slices}" return ( Path(tree_dir) / f"{case_name}_{obs_name}_{nqubits}q{nlayers}l_{slice_label}.pkl" ) def build_parallel_opts(args, tree_file=None, search_only=False): slicing_opts = {} if args.tn_target_slices is not None: slicing_opts["target_slices"] = args.tn_target_slices if args.tn_target_size is not None: slicing_opts["target_size"] = args.tn_target_size opts = { "slicing_opts": slicing_opts or None, "search_workers": args.tn_search_workers or args.torch_threads, "max_repeats": args.tn_search_repeats, "max_time": args.tn_search_time, "print_stats": False, } if args.tn_search_backend is not None: opts["search_backend"] = args.tn_search_backend if args.dask_address is not None: opts["dask_address"] = args.dask_address if args.dask_close_workers: opts["dask_close_workers"] = True if args.tn_debug_trials: opts["debug_trials"] = True if search_only: opts["search_only"] = True opts["save_tree_path"] = str(tree_file) elif tree_file is not None: opts["load_tree_path"] = str(tree_file) return opts def run_one(args, case_name, obs_name, mode): case = CASES[case_name] circuit = build_circuit(case.circuit_kind, args.nqubits, args.nlayers, args.seed) observable = pauli_sum_observable(obs_name, args.nqubits, args.seed) path = tree_path( args.tree_dir, case_name, obs_name, args.nqubits, args.nlayers, args.tn_target_slices, ) path.parent.mkdir(parents=True, exist_ok=True) rank = 0 if args.mpi: from mpi4py import MPI rank = MPI.COMM_WORLD.Get_rank() if rank == 0: print("=" * 88, flush=True) print( f"mode={mode} case={case_name} circuit={case.circuit_kind} " f"observable={obs_name} nqubits={args.nqubits} nlayers={args.nlayers} " f"seed={args.seed} gates={len(circuit.queue)} tree={path}", flush=True, ) if mode == "contract" and not path.exists(): raise FileNotFoundError(f"Missing tree file: {path}. Run search first.") exact = None if args.exact and rank == 0 and mode != "search": if args.nqubits > args.exact_max_qubits: raise ValueError( f"--exact is limited to {args.exact_max_qubits} qubits by default." ) exact = exact_for_observable(circuit, observable, args.nqubits) config = ExpectationConfig( ansatz="tn", mpi=args.mpi, bond=args.bond, cut_ratio=args.cut_ratio, tensor_module="torch", quimb_backend=args.quimb_backend, dtype=args.dtype, torch_threads=args.torch_threads, parallel_opts=build_parallel_opts( args, tree_file=path, search_only=(mode == "search"), ), ) result = run_cpu_expectation(circuit, observable, config) if args.mpi and result.rank != 0: return if mode == "search": print(f"searched observable={obs_name} tree={path}", flush=True) else: abs_error = float("nan") if exact is None else abs(result.value - exact) rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15) exact_text = "nan" if exact is None else f"{exact:.16e}" print( f"result observable={obs_name} exact={exact_text} " f"value={result.value:.16e} abs_error={abs_error:.6e} " f"rel_error={rel_error:.6e} seconds={result.seconds:.3f}", flush=True, ) for stat in result.parallel_stats or (): cost = stat["path_cost"] search_stats = stat.get("search_stats", {}) print( "tn_term_summary " f"observable={obs_name} " f"term={stat.get('term_index', 0)} " f"search_seconds={stat.get('search_seconds', float('nan')):.3f} " f"contract_seconds={stat.get('contract_seconds', float('nan')):.3f} " f"completed_trials={search_stats.get('completed_trials', 'na')} " f"finite_trials={search_stats.get('finite_trials', 'na')} " f"failed_trials={search_stats.get('failed_trials', 'na')} " f"requested_trials={search_stats.get('requested_trials', 'na')} " f"best_score={search_stats.get('best_score', float('nan')):.6g} " f"slices={cost.get('nslices')} " f"log10_flops={cost.get('log10_flops', float('nan')):.3f} " f"log10_write={cost.get('log10_write', float('nan')):.3f} " f"log2_size={cost.get('log2_size', float('nan')):.3f} " f"peak_memory_gib={cost.get('peak_memory_gib', float('nan')):.3g} " f"rank_slices={stat.get('rank_slices')}", flush=True, ) def selected_observables(args, case): if args.observables: return tuple(args.observables) if args.obs_filter: return tuple(x.strip() for x in args.obs_filter.split(",") if x.strip()) return case.observables def apply_case_defaults(args): case = CASES[args.case] if args.nqubits is None: args.nqubits = case.nqubits if args.nlayers is None: args.nlayers = case.nlayers if args.seed is None: args.seed = case.seed if args.tn_target_slices is None: args.tn_target_slices = case.target_slices args.observables = selected_observables(args, case) def stop_dask_cluster(args): if args.keep_dask or args.tn_search_backend != "dask" or not args.dask_address: return if args.mpi: from mpi4py import MPI if MPI.COMM_WORLD.Get_rank() != 0: return script = ROOT / "tools" / "manage_tn_dask_cluster.sh" if not script.exists(): print(f"dask_stop_skipped reason=missing_script path={script}", flush=True) return env = os.environ.copy() parsed = urlparse(args.dask_address) if parsed.hostname: env.setdefault("SCHEDULER_HOST", parsed.hostname) if parsed.port: env.setdefault("SCHEDULER_PORT", str(parsed.port)) print("dask_stop_after_search start", flush=True) subprocess.run([str(script), "stop"], cwd=str(ROOT), env=env, check=False) print("dask_stop_after_search done", flush=True) def main(): parser = argparse.ArgumentParser() parser.add_argument("mode", choices=("search", "contract", "all", "validate", "list")) parser.add_argument("--case", choices=sorted(CASES), default="main1") parser.add_argument("--observables", nargs="+") parser.add_argument("--obs-filter", default="") parser.add_argument("--tree-dir", default="trees/contest_tn") parser.add_argument("--nqubits", type=int) parser.add_argument("--nlayers", type=int) parser.add_argument("--seed", type=int) parser.add_argument("--mpi", action="store_true") parser.add_argument("--exact", action="store_true") parser.add_argument("--exact-max-qubits", type=int, default=24) parser.add_argument("--bond", "--bonds", dest="bond", type=optional_int, default=1024) parser.add_argument("--cut-ratio", type=optional_float, default=1e-12) parser.add_argument("--torch-threads", type=int, default=8) parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch") parser.add_argument("--dtype", choices=("complex128", "complex64"), default="complex64") parser.add_argument("--tn-target-slices", type=int) parser.add_argument("--tn-target-size", type=int, default=2**32) parser.add_argument("--tn-search-workers", type=int) parser.add_argument("--tn-search-repeats", type=int, default=2048) parser.add_argument("--tn-search-time", type=float, default=300.0) parser.add_argument( "--tn-search-backend", choices=("processpool", "dask"), default="dask", help=( "Path-search backend. Defaults to dask. Without --dask-address, " "non-MPI search starts a local dask cluster." ), ) parser.add_argument("--dask-address") parser.add_argument("--dask-close-workers", action="store_true") parser.add_argument( "--keep-dask", action="store_true", help=( "Keep an external dask cluster running after search. By default, " "tools/manage_tn_dask_cluster.sh stop is called after search when " "--dask-address is used." ), ) parser.add_argument( "--tn-debug-trials", action="store_true", help="Print dask worker summary and per-trial start/done logs.", ) parser.add_argument("--no-tn-stats", action="store_true") args = parser.parse_args() if args.mode == "list": for name, case in CASES.items(): print( f"{name}: circuit={case.circuit_kind} " f"observables={','.join(case.observables)} " f"nqubits={case.nqubits} nlayers={case.nlayers} " f"seed={case.seed} target_slices={case.target_slices}" ) return apply_case_defaults(args) set_torch_threads(args.torch_threads) modes = ("search", "contract") if args.mode == "all" else (args.mode,) if args.mode == "validate": args.exact = True args.nqubits = min(args.nqubits, args.exact_max_qubits) modes = ("search", "contract") for mode in modes: for obs_name in args.observables: run_one(args, args.case, obs_name, mode) if mode == "search": stop_dask_cluster(args) if __name__ == "__main__": main()