赛前稳定版
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -3,6 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from qibotn.benchmark_cases import (
|
||||
CIRCUITS,
|
||||
@@ -19,6 +23,51 @@ from qibotn.expectation_runner import (
|
||||
)
|
||||
|
||||
|
||||
def optional_int(text):
|
||||
if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}:
|
||||
return None
|
||||
return int(text)
|
||||
|
||||
|
||||
def optional_float(text):
|
||||
if isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}:
|
||||
return None
|
||||
return float(text)
|
||||
|
||||
|
||||
def format_optional(value, fmt="g"):
|
||||
return "None" if value is None else format(value, fmt)
|
||||
|
||||
|
||||
def should_stop_dask(args):
|
||||
return (
|
||||
not args.keep_dask
|
||||
and args.tn_search_backend == "dask"
|
||||
and args.dask_address is not None
|
||||
and args.tn_load_tree is None
|
||||
)
|
||||
|
||||
|
||||
def stop_dask_cluster(args, rank):
|
||||
if rank != 0 or not should_stop_dask(args):
|
||||
return
|
||||
script = Path(__file__).resolve().parent / "tools" / "manage_tn_dask_cluster.sh"
|
||||
if not script.exists():
|
||||
print(f"dask_stop_skipped reason=missing_script path={script}", flush=True)
|
||||
return
|
||||
|
||||
env = os.environ.copy()
|
||||
parsed = urlparse(args.dask_address)
|
||||
if parsed.hostname:
|
||||
env.setdefault("SCHEDULER_HOST", parsed.hostname)
|
||||
if parsed.port:
|
||||
env.setdefault("SCHEDULER_PORT", str(parsed.port))
|
||||
|
||||
print("dask_stop_after_search start", flush=True)
|
||||
subprocess.run([str(script), "stop"], cwd=str(script.parent.parent), env=env, check=False)
|
||||
print("dask_stop_after_search done", flush=True)
|
||||
|
||||
|
||||
def build_parallel_opts(args):
|
||||
slicing_opts = {}
|
||||
if args.tn_target_slices is not None:
|
||||
@@ -31,11 +80,24 @@ def build_parallel_opts(args):
|
||||
"search_workers": args.tn_search_workers or args.torch_threads,
|
||||
"max_repeats": args.tn_search_repeats,
|
||||
"max_time": args.tn_search_time,
|
||||
"print_stats": not args.no_tn_stats,
|
||||
}
|
||||
if args.tn_search_backend is not None:
|
||||
opts["search_backend"] = args.tn_search_backend
|
||||
if args.dask_address is not None:
|
||||
opts["dask_address"] = args.dask_address
|
||||
if args.tn_save_tree is not None:
|
||||
opts["save_tree_path"] = args.tn_save_tree
|
||||
if args.tn_load_tree is not None:
|
||||
opts["load_tree_path"] = args.tn_load_tree
|
||||
if args.tn_search_only:
|
||||
opts["search_only"] = True
|
||||
if args.tn_debug_trials:
|
||||
opts["debug_trials"] = True
|
||||
if args.tn_contract_implementation is not None:
|
||||
opts["contract_implementation"] = args.tn_contract_implementation
|
||||
if args.dask_close_workers:
|
||||
opts["dask_close_workers"] = True
|
||||
return opts
|
||||
|
||||
|
||||
@@ -43,10 +105,16 @@ def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nqubits", type=int, default=40)
|
||||
parser.add_argument("--nlayers", type=int, default=30)
|
||||
parser.add_argument("--bond", "--bonds", dest="bond", type=int, default=1024)
|
||||
parser.add_argument("--cut-ratio", type=float, default=1e-12)
|
||||
parser.add_argument("--bond", "--bonds", dest="bond", type=optional_int, default=1024)
|
||||
parser.add_argument("--cut-ratio", type=optional_float, default=1e-12)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--torch-threads", type=int, default=8)
|
||||
parser.add_argument("--quimb-backend", choices=("numpy", "torch"), default="torch")
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
choices=("complex128", "complex64"),
|
||||
default="complex128",
|
||||
)
|
||||
parser.add_argument("--ansatz", choices=("tn", "mps"), default=None)
|
||||
parser.add_argument("--mps", action="store_true")
|
||||
parser.add_argument("--mpi", action="store_true")
|
||||
@@ -56,19 +124,62 @@ def main():
|
||||
parser.add_argument("--observables", nargs="+", default=["ring_xz"])
|
||||
parser.add_argument("--pauli-pattern")
|
||||
parser.add_argument("--tn-target-slices", type=int)
|
||||
parser.add_argument("--tn-target-size", type=int)
|
||||
parser.add_argument("--tn-target-size", type=int,default=2**32)
|
||||
parser.add_argument("--tn-search-workers", type=int)
|
||||
parser.add_argument("--tn-search-repeats", type=int, default=128)
|
||||
parser.add_argument("--tn-search-time", type=float, default=60.0)
|
||||
parser.add_argument(
|
||||
"--no-tn-stats",
|
||||
action="store_true",
|
||||
help="Do not print per-term TN search/contraction diagnostics.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-search-backend",
|
||||
choices=("processpool", "dask"),
|
||||
default="dask",
|
||||
help="Path-search backend. In MPI mode, dask search runs only on rank 0 and broadcasts the tree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dask-address",
|
||||
help="Dask scheduler address, for example tcp://host:8786. If omitted with dask search, a local cluster is created.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dask-close-workers",
|
||||
action="store_true",
|
||||
help="After dask path search, ask the scheduler to close all currently connected workers.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-dask",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Keep an external dask cluster running after search. By default, "
|
||||
"tools/manage_tn_dask_cluster.sh stop is called after search when "
|
||||
"--dask-address is used."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-save-tree",
|
||||
help="Save searched cotengra contraction tree(s) to this pickle file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-load-tree",
|
||||
help="Load cotengra contraction tree(s) from this pickle file and skip path search.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-search-only",
|
||||
action="store_true",
|
||||
help="Only run path search and optional --tn-save-tree; skip contraction.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-debug-trials",
|
||||
action="store_true",
|
||||
help="Print dask worker summary and per-trial worker start/done logs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tn-contract-implementation",
|
||||
choices=("auto", "cotengra", "autoray", "cpp"),
|
||||
help="cotengra contraction implementation for TN contraction.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
ansatz = "mps" if args.mps else (args.ansatz or "tn")
|
||||
@@ -89,6 +200,8 @@ def main():
|
||||
bond=args.bond,
|
||||
cut_ratio=args.cut_ratio,
|
||||
tensor_module="torch",
|
||||
quimb_backend=args.quimb_backend,
|
||||
dtype=args.dtype,
|
||||
torch_threads=args.torch_threads,
|
||||
parallel_opts=build_parallel_opts(args),
|
||||
)
|
||||
@@ -98,47 +211,74 @@ def main():
|
||||
print(
|
||||
f"backend=cpu ansatz={ansatz.upper()} mode={mode} "
|
||||
f"nqubits={args.nqubits} nlayers={args.nlayers} "
|
||||
f"bond={args.bond} cut_ratio={args.cut_ratio:g} seed={args.seed} "
|
||||
f"bond={format_optional(args.bond)} "
|
||||
f"cut_ratio={format_optional(args.cut_ratio)} seed={args.seed} "
|
||||
f"quimb_backend={args.quimb_backend} dtype={args.dtype} "
|
||||
f"torch_threads={args.torch_threads} "
|
||||
f"tn_search_backend={args.tn_search_backend or 'processpool'}"
|
||||
f"tn_search_backend={args.tn_search_backend}"
|
||||
)
|
||||
print("circuit observable exact value abs_error rel_error seconds")
|
||||
|
||||
for circuit_kind in circuits:
|
||||
circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed)
|
||||
named_observables = (
|
||||
[(f"pattern:{args.pauli_pattern}", {"pauli_string_pattern": args.pauli_pattern})]
|
||||
if args.pauli_pattern
|
||||
else [
|
||||
(obs_kind, terms_to_dict(observable_terms(obs_kind, args.nqubits)))
|
||||
for obs_kind in observables
|
||||
]
|
||||
)
|
||||
try:
|
||||
for circuit_kind in circuits:
|
||||
circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed)
|
||||
named_observables = (
|
||||
[(f"pattern:{args.pauli_pattern}", {"pauli_string_pattern": args.pauli_pattern})]
|
||||
if args.pauli_pattern
|
||||
else [
|
||||
(obs_kind, terms_to_dict(observable_terms(obs_kind, args.nqubits)))
|
||||
for obs_kind in observables
|
||||
]
|
||||
)
|
||||
|
||||
for obs_name, observable in named_observables:
|
||||
exact = None
|
||||
if args.exact and rank == 0:
|
||||
if args.nqubits > args.exact_max_qubits:
|
||||
raise ValueError(
|
||||
f"--exact is limited to {args.exact_max_qubits} qubits by default."
|
||||
for obs_name, observable in named_observables:
|
||||
exact = None
|
||||
if args.exact and rank == 0:
|
||||
if args.nqubits > args.exact_max_qubits:
|
||||
raise ValueError(
|
||||
f"--exact is limited to {args.exact_max_qubits} qubits by default."
|
||||
)
|
||||
exact = exact_for_observable(circuit, observable, args.nqubits)
|
||||
|
||||
result = run_cpu_expectation(circuit, observable, config)
|
||||
if args.mpi and result.rank != 0:
|
||||
continue
|
||||
|
||||
abs_error = float("nan") if exact is None else abs(result.value - exact)
|
||||
rel_error = (
|
||||
float("nan")
|
||||
if exact is None
|
||||
else abs_error / max(abs(exact), 1e-15)
|
||||
)
|
||||
exact_text = "nan" if exact is None else f"{exact:.16e}"
|
||||
print(
|
||||
f"{circuit_kind} {obs_name} {exact_text} {result.value:.16e} "
|
||||
f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}"
|
||||
)
|
||||
for stat in result.parallel_stats or ():
|
||||
cost = stat["path_cost"]
|
||||
search_stats = stat.get("search_stats", {})
|
||||
print(
|
||||
"tn_term_summary "
|
||||
f"term={stat.get('term_index', 0)} "
|
||||
f"search_seconds={stat.get('search_seconds', float('nan')):.3f} "
|
||||
f"contract_seconds={stat.get('contract_seconds', float('nan')):.3f} "
|
||||
f"completed_trials={search_stats.get('completed_trials', 'na')} "
|
||||
f"finite_trials={search_stats.get('finite_trials', 'na')} "
|
||||
f"failed_trials={search_stats.get('failed_trials', 'na')} "
|
||||
f"requested_trials={search_stats.get('requested_trials', 'na')} "
|
||||
f"best_score={search_stats.get('best_score', float('nan')):.6g} "
|
||||
f"slices={cost['nslices']} "
|
||||
f"log10_flops={cost['log10_flops']:.3f} "
|
||||
f"log10_write={cost['log10_write']:.3f} "
|
||||
f"log2_size={cost['log2_size']:.3f} "
|
||||
f"log10_combo={cost['log10_combo']:.3f} "
|
||||
f"peak_memory_gib={cost['peak_memory_gib']:.6g} "
|
||||
f"slicing_overhead={cost['slicing_overhead']:.6g} "
|
||||
f"rank_slices={stat.get('rank_slices', 'na')}"
|
||||
)
|
||||
exact = exact_for_observable(circuit, observable, args.nqubits)
|
||||
|
||||
result = run_cpu_expectation(circuit, observable, config)
|
||||
if args.mpi and result.rank != 0:
|
||||
continue
|
||||
|
||||
abs_error = float("nan") if exact is None else abs(result.value - exact)
|
||||
rel_error = (
|
||||
float("nan")
|
||||
if exact is None
|
||||
else abs_error / max(abs(exact), 1e-15)
|
||||
)
|
||||
exact_text = "nan" if exact is None else f"{exact:.16e}"
|
||||
print(
|
||||
f"{circuit_kind} {obs_name} {exact_text} {result.value:.16e} "
|
||||
f"{abs_error:.6e} {rel_error:.6e} {result.seconds:.3f}"
|
||||
)
|
||||
finally:
|
||||
stop_dask_cluster(args, rank)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user