Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
146 lines
5.2 KiB
Python
146 lines
5.2 KiB
Python
"""CLI for CPU TN/MPS expectation benchmarks."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
|
|
from qibotn.benchmark_cases import (
|
|
CIRCUITS,
|
|
OBSERVABLES,
|
|
build_circuit,
|
|
observable_terms,
|
|
parse_names,
|
|
terms_to_dict,
|
|
)
|
|
from qibotn.expectation_runner import (
|
|
ExpectationConfig,
|
|
exact_for_observable,
|
|
run_cpu_expectation,
|
|
)
|
|
|
|
|
|
def build_parallel_opts(args):
    """Translate the tensor-network search/slicing CLI flags into runner options.

    Args:
        args: Parsed CLI namespace. Reads ``tn_target_slices``,
            ``tn_target_size``, ``tn_search_workers``, ``torch_threads``,
            ``tn_search_repeats``, ``tn_search_time``, ``tn_search_backend``
            and ``dask_address``.

    Returns:
        dict: Always holds ``slicing_opts``, ``search_workers``,
        ``max_repeats`` and ``max_time``. It also holds ``search_backend`` and
        ``dask_address`` when those flags were given. ``slicing_opts`` is
        ``None`` when neither slicing flag was supplied.
    """
    # Slicing targets: keep only the flags the user actually set.
    slicing = {
        name: value
        for name, value in (
            ("target_slices", args.tn_target_slices),
            ("target_size", args.tn_target_size),
        )
        if value is not None
    }
    # Optional keys are appended last and omitted entirely when unset.
    extras = {
        name: value
        for name, value in (
            ("search_backend", args.tn_search_backend),
            ("dask_address", args.dask_address),
        )
        if value is not None
    }
    return {
        "slicing_opts": slicing if slicing else None,
        # Truthiness fallback: an unset *or* zero worker count uses torch_threads.
        "search_workers": args.tn_search_workers or args.torch_threads,
        "max_repeats": args.tn_search_repeats,
        "max_time": args.tn_search_time,
        **extras,
    }
|
|
|
|
|
|
def _build_parser():
    """Return the argument parser for the CPU expectation benchmark CLI."""
    parser = argparse.ArgumentParser(
        description="CPU TN/MPS expectation-value benchmarks."
    )
    parser.add_argument("--nqubits", type=int, default=40)
    parser.add_argument("--nlayers", type=int, default=30)
    parser.add_argument("--bond", "--bonds", dest="bond", type=int, default=1024)
    parser.add_argument("--cut-ratio", type=float, default=1e-12)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--torch-threads",
        type=int,
        default=8,
        help="Torch thread count; also the default for --tn-search-workers.",
    )
    parser.add_argument(
        "--ansatz",
        choices=("tn", "mps"),
        default=None,
        help="Simulation ansatz (default: tn).",
    )
    parser.add_argument(
        "--mps",
        action="store_true",
        help="Shorthand for --ansatz mps; takes precedence over --ansatz.",
    )
    parser.add_argument(
        "--mpi",
        action="store_true",
        help="Run under MPI via mpi4py; results are printed by rank 0 only.",
    )
    parser.add_argument(
        "--exact",
        action="store_true",
        help="Also compute the exact value on rank 0 and report abs/rel errors.",
    )
    parser.add_argument(
        "--exact-max-qubits",
        type=int,
        default=24,
        help="Largest --nqubits accepted together with --exact.",
    )
    parser.add_argument("--circuits", nargs="+", default=["brickwall_cnot"])
    parser.add_argument("--observables", nargs="+", default=["ring_xz"])
    parser.add_argument(
        "--pauli-pattern",
        help="Pauli-string pattern observable; replaces --observables when given.",
    )
    parser.add_argument("--tn-target-slices", type=int)
    parser.add_argument("--tn-target-size", type=int)
    parser.add_argument(
        "--tn-search-workers",
        type=int,
        help="Path-search workers (default: --torch-threads).",
    )
    parser.add_argument("--tn-search-repeats", type=int, default=128)
    parser.add_argument("--tn-search-time", type=float, default=60.0)
    parser.add_argument(
        "--tn-search-backend",
        choices=("processpool", "dask"),
        help="Path-search backend. In MPI mode, dask search runs only on rank 0 and broadcasts the tree.",
    )
    parser.add_argument(
        "--dask-address",
        help="Dask scheduler address, for example tcp://host:8786. If omitted with dask search, a local cluster is created.",
    )
    return parser


def _print_header(args, ansatz):
    """Print the run-configuration line followed by the result column names."""
    mode = "MPI" if args.mpi else "serial"
    print(
        f"backend=cpu ansatz={ansatz.upper()} mode={mode} "
        f"nqubits={args.nqubits} nlayers={args.nlayers} "
        f"bond={args.bond} cut_ratio={args.cut_ratio:g} seed={args.seed} "
        f"torch_threads={args.torch_threads} "
        f"tn_search_backend={args.tn_search_backend or 'processpool'}"
    )
    print("circuit observable exact value abs_error rel_error seconds")


def _named_observables(args, observables):
    """Return ``(display_name, observable)`` pairs to evaluate for one circuit.

    A ``--pauli-pattern`` yields a single pattern observable. Otherwise every
    named observable is expanded into its term dictionary for ``args.nqubits``.
    """
    if args.pauli_pattern:
        return [
            (f"pattern:{args.pauli_pattern}", {"pauli_string_pattern": args.pauli_pattern})
        ]
    return [
        (obs_kind, terms_to_dict(observable_terms(obs_kind, args.nqubits)))
        for obs_kind in observables
    ]


def _format_row(circuit_kind, obs_name, exact, value, seconds):
    """Format one result row; the exact and error columns are ``nan`` without a reference."""
    if exact is None:
        exact_text = "nan"
        abs_error = rel_error = float("nan")
    else:
        exact_text = f"{exact:.16e}"
        abs_error = abs(value - exact)
        # Floor the denominator so an exact value of 0 does not divide by zero.
        rel_error = abs_error / max(abs(exact), 1e-15)
    return (
        f"{circuit_kind} {obs_name} {exact_text} {value:.16e} "
        f"{abs_error:.6e} {rel_error:.6e} {seconds:.3f}"
    )


def main():
    """Run every requested (circuit, observable) benchmark and print a results table.

    Rank 0 prints a configuration header, a column header, and one row per
    (circuit, observable) pair. Other MPI ranks only take part in
    ``run_cpu_expectation`` and print nothing.

    Raises:
        ValueError: If ``--exact`` is requested with ``--nqubits`` above
            ``--exact-max-qubits``.
    """
    args = _build_parser().parse_args()

    # Validate up front, before any MPI setup or circuit construction, and on
    # every rank. argv is identical across ranks, so all ranks stop together.
    # Before, only rank 0 raised (inside the loop) while the other ranks went
    # on into run_cpu_expectation alone, which can hang an MPI job.
    if args.exact and args.nqubits > args.exact_max_qubits:
        raise ValueError(
            f"--exact is limited to {args.exact_max_qubits} qubits "
            f"(got --nqubits {args.nqubits}); raise --exact-max-qubits to override."
        )

    # --mps overrides --ansatz; TN is used when neither is given.
    ansatz = "mps" if args.mps else (args.ansatz or "tn")
    circuits = parse_names(args.circuits, CIRCUITS, "circuits")
    # A --pauli-pattern replaces the named observables, so they are not parsed.
    observables = (
        []
        if args.pauli_pattern
        else parse_names(args.observables, OBSERVABLES, "observables")
    )

    rank = 0
    if args.mpi:
        from mpi4py import MPI

        rank = MPI.COMM_WORLD.Get_rank()

    config = ExpectationConfig(
        ansatz=ansatz,
        mpi=args.mpi,
        bond=args.bond,
        cut_ratio=args.cut_ratio,
        tensor_module="torch",
        torch_threads=args.torch_threads,
        parallel_opts=build_parallel_opts(args),
    )

    if rank == 0:
        _print_header(args, ansatz)

    for circuit_kind in circuits:
        circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed)
        # Observable dicts are built fresh for each circuit, so none is shared
        # between runs.
        for obs_name, observable in _named_observables(args, observables):
            # Only rank 0 computes the exact reference value.
            exact = (
                exact_for_observable(circuit, observable, args.nqubits)
                if args.exact and rank == 0
                else None
            )

            result = run_cpu_expectation(circuit, observable, config)
            if args.mpi and result.rank != 0:
                continue

            # Flush each row: under mpirun stdout is usually a pipe (block
            # buffered), and already-finished rows should not be lost if a
            # later, larger case crashes or is killed.
            print(
                _format_row(circuit_kind, obs_name, exact, result.value, result.seconds),
                flush=True,
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()
|