Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
202 lines
7.0 KiB
Python
202 lines
7.0 KiB
Python
"""MPS expectation benchmark for qmatchatea and Vidal backends."""
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import os
|
|
import socket
|
|
import time
|
|
|
|
import numpy as np
|
|
|
|
from qibotn.benchmark_cases import (
|
|
build_circuit as build_benchmark_circuit,
|
|
exact_pauli_sum,
|
|
observable_terms,
|
|
terms_to_dict,
|
|
)
|
|
from qibotn.backends.qmatchatea import QMatchaTeaBackend
|
|
from qibotn.backends.vidal_tebd import run_vidal_ring_xz
|
|
|
|
|
|
def optional_int(text):
    """Convert a command-line value to ``int``, or ``None`` for a sentinel word.

    Args:
        text: Raw value. Strings equal (case-insensitively) to ``"none"``,
            ``"null"``, ``"inf"`` or ``"unlimited"`` are treated as "no value".

    Returns:
        ``None`` for a sentinel word, otherwise ``int(text)``.

    Raises:
        ValueError: If ``text`` is not a sentinel and ``int`` cannot parse it.
    """
    sentinel_words = {"none", "null", "inf", "unlimited"}
    is_sentinel = isinstance(text, str) and text.lower() in sentinel_words
    return None if is_sentinel else int(text)
|
|
|
|
|
|
def optional_float(text):
    """Convert a command-line value to ``float``, or ``None`` for a sentinel word.

    Args:
        text: Raw value. Strings equal (case-insensitively) to ``"none"``,
            ``"null"``, ``"inf"`` or ``"unlimited"`` are treated as "no value".
            Note that ``"inf"`` therefore yields ``None``, not ``float("inf")``.

    Returns:
        ``None`` for a sentinel word, otherwise ``float(text)``.

    Raises:
        ValueError: If ``text`` is not a sentinel and ``float`` cannot parse it.
    """
    # Non-string inputs can never be a sentinel word; convert them directly.
    if not isinstance(text, str):
        return float(text)
    if text.lower() in ("none", "null", "inf", "unlimited"):
        return None
    return float(text)
|
|
|
|
|
|
def format_optional(value, fmt="g"):
    """Format ``value`` with ``fmt``, rendering ``None`` as the text ``"None"``.

    Args:
        value: The value to render; may be ``None``.
        fmt: Format specification passed to the built-in ``format``.

    Returns:
        The string ``"None"`` when ``value`` is ``None``, otherwise
        ``format(value, fmt)``.
    """
    if value is None:
        return "None"
    return format(value, fmt)
|
|
|
|
|
|
def build_circuit(nqubits, nlayers, seed):
    """Build the benchmark circuit for the ``"brickwall_cnot"`` case.

    Forwards ``nqubits``, ``nlayers`` and ``seed`` unchanged to
    ``qibotn.benchmark_cases.build_circuit`` and returns its result.
    """
    case_name = "brickwall_cnot"
    return build_benchmark_circuit(case_name, nqubits, nlayers, seed)
|
|
|
|
|
|
def build_observable(nqubits):
    """Return the ``"ring_xz"`` observable for ``nqubits`` in dictionary form.

    The terms come from ``observable_terms`` and are converted with
    ``terms_to_dict``.
    """
    ring_terms = observable_terms("ring_xz", nqubits)
    return terms_to_dict(ring_terms)
|
|
|
|
|
|
def exact_expectation(circuit, nqubits):
    """Evaluate the ``"ring_xz"`` observable on ``circuit`` via ``exact_pauli_sum``.

    Args:
        circuit: Circuit object passed through to ``exact_pauli_sum``.
        nqubits: Number of qubits, used both to build the observable terms and
            as the last argument of ``exact_pauli_sum``.

    Returns:
        Whatever ``exact_pauli_sum`` returns for these inputs.
    """
    ring_terms = observable_terms("ring_xz", nqubits)
    return exact_pauli_sum(circuit, ring_terms, nqubits)
|
|
|
|
|
|
def _build_parser():
    """Create the command-line parser holding every benchmark option."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--nqubits", type=int, default=40)
    parser.add_argument("--nlayers", type=int, default=30)
    parser.add_argument("--bond", "--bonds", dest="bond", type=optional_int, default=512)
    parser.add_argument("--cut-ratio", type=optional_float, default=1e-12)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tensor-module", choices=("numpy", "torch"), default="torch")
    parser.add_argument("--torch-threads", type=int, default=32)
    parser.add_argument(
        "--executor",
        choices=("qmatchatea", "vidal", "vidal-mpi"),
        default="qmatchatea",
    )
    parser.add_argument("--mpi-ct", action="store_true")
    parser.add_argument("--mpi-barriers", type=int, default=-1)
    parser.add_argument("--mpi-isometrization", type=int, default=-1)
    parser.add_argument("--exact", action="store_true")
    parser.add_argument("--exact-max-qubits", type=int, default=24)
    parser.add_argument("--reference-file")
    parser.add_argument(
        "--mpi-rank-map",
        action="store_true",
        help="Print MPI rank, host, pid, and torch thread placement metadata.",
    )
    return parser


def _print_rank_map(args, rank, size, comm=None):
    """Print one placement line per process; output is emitted by rank 0 only.

    With a communicator the per-rank records are gathered on rank 0. Without
    one (no ``--mpi-ct``) only the current process is reported, so the flag
    still produces output instead of depending on an MPI module that was
    never imported.
    """
    rank_info = {
        "rank": rank,
        "size": size,
        "host": socket.gethostname(),
        "pid": os.getpid(),
        "torch_threads": args.torch_threads,
        "omp_num_threads": os.environ.get("OMP_NUM_THREADS", ""),
        "mkl_num_threads": os.environ.get("MKL_NUM_THREADS", ""),
    }
    rank_infos = [rank_info] if comm is None else comm.gather(rank_info, root=0)
    if rank != 0:
        return
    print("mpi_rank_map")
    for item in sorted(rank_infos, key=lambda row: row["rank"]):
        print(
            "rank={rank} size={size} host={host} pid={pid} "
            "torch_threads={torch_threads} "
            "OMP_NUM_THREADS={omp_num_threads} "
            "MKL_NUM_THREADS={mkl_num_threads}".format(**item)
        )


def _reference_expectation(args, circuit):
    """Return the reference expectation value, or ``None`` if none was requested.

    ``--reference-file`` takes precedence over ``--exact``. The file is read as
    JSON and its ``"expectation"`` entry is converted to ``float``.

    Raises:
        ValueError: If ``--exact`` is set and ``--nqubits`` exceeds
            ``--exact-max-qubits``.
    """
    if args.reference_file:
        with open(args.reference_file, "r", encoding="utf-8") as f:
            return float(json.load(f)["expectation"])
    if args.exact:
        if args.nqubits > args.exact_max_qubits:
            raise ValueError(
                f"--exact is limited to {args.exact_max_qubits} qubits by default."
            )
        return exact_expectation(circuit, args.nqubits)
    return None


def _print_header(args, size, exact):
    """Print the run configuration, the optional reference value, and the column header."""
    if args.mpi_ct and args.executor in ("vidal", "vidal-mpi"):
        mpi_label = f"VidalSegment/{size}"
    else:
        mpi_label = f"MPIMPS/{size}" if args.mpi_ct else "SR"
    print(
        f"nqubits={args.nqubits} nlayers={args.nlayers} "
        f"bond={format_optional(args.bond)} "
        f"cut_ratio={format_optional(args.cut_ratio)} seed={args.seed} "
        f"tensor_module={args.tensor_module} svd_control=E! "
        f"compile_circuit=True mpi={mpi_label} executor={args.executor}"
    )
    if exact is not None:
        print(f"exact={exact:.16e}")
    print("expval abs_error rel_error seconds")


def _run_executor(args, circuit, observable, size, comm):
    """Run the selected executor and return ``(value, timings)``.

    ``timings`` is only produced by the MPI segment Vidal runner; it is
    ``None`` for every other path.
    """
    if args.executor in ("vidal", "vidal-mpi"):
        if comm is not None:
            # Imported lazily so non-MPI runs do not need this module.
            from qibotn.backends.vidal_mpi_segment import run_segment_vidal_mpi_ring_xz

            value, timings = run_segment_vidal_mpi_ring_xz(
                circuit,
                max_bond=args.bond,
                cut_ratio=args.cut_ratio,
                tensor_module=args.tensor_module,
                comm=comm,
            )
            return value, timings
        value = run_vidal_ring_xz(
            circuit,
            max_bond=args.bond,
            cut_ratio=args.cut_ratio,
            tensor_module=args.tensor_module,
        )
        return value, None

    backend = QMatchaTeaBackend()
    backend.configure_tn_simulation(
        ansatz="MPS",
        max_bond_dimension=args.bond,
        cut_ratio=args.cut_ratio,
        svd_control="E!",
        tensor_module=args.tensor_module,
        compile_circuit=True,
        track_memory=False,
        mpi_approach="CT" if args.mpi_ct else "SR",
        mpi_num_procs=size,
        mpi_where_barriers=args.mpi_barriers if args.mpi_ct else -1,
        mpi_isometrization=args.mpi_isometrization,
    )
    value = backend.expectation(
        circuit,
        observable,
        preprocess=False,
        compile_circuit=True,
    )
    return value, None


def _reduce_max_timings(timings, comm):
    """Reduce each timing entry to its maximum over all ranks (result on rank 0).

    ``reduce`` is a collective call, so every rank is expected to reach this
    with the same keys in the same order.
    """
    from mpi4py import MPI

    return {
        key: comm.reduce(local_value, op=MPI.MAX, root=0)
        for key, local_value in timings.items()
    }


def main():
    """Run the MPS expectation benchmark from the command line.

    Parses the options, builds the circuit and observable, optionally obtains
    a reference expectation value, runs the selected executor, and prints the
    result together with absolute/relative error and elapsed seconds. With
    ``--mpi-ct`` only rank 0 prints the results.

    Raises:
        ValueError: If ``--executor vidal-mpi`` is used without ``--mpi-ct``,
            or if ``--exact`` is requested for more qubits than
            ``--exact-max-qubits``.
    """
    args = _build_parser().parse_args()

    # Fail fast on an inconsistent option combination, before any output or
    # the potentially expensive exact reference computation.
    if args.executor == "vidal-mpi" and not args.mpi_ct:
        raise ValueError("--executor vidal-mpi requires --mpi-ct.")

    logging.getLogger("qibo.config").setLevel(logging.ERROR)
    logging.getLogger("qtealeaves").setLevel(logging.ERROR)
    import torch

    torch.set_num_threads(args.torch_threads)

    rank = 0
    size = 1
    comm = None
    if args.mpi_ct:
        from mpi4py import MPI

        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        size = comm.Get_size()
    if args.mpi_rank_map:
        _print_rank_map(args, rank, size, comm)

    circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
    observable = build_observable(args.nqubits)
    exact = _reference_expectation(args, circuit)

    if rank == 0:
        _print_header(args, size, exact)

    start = time.perf_counter()
    value, timings = _run_executor(args, circuit, observable, size, comm)

    # Timings are only returned by the MPI segment runner, so ``comm`` is set
    # whenever this branch is taken.
    max_timings = _reduce_max_timings(timings, comm) if timings else None

    if rank != 0:
        return

    value = float(np.real(value))
    elapsed = time.perf_counter() - start
    abs_error = float("nan") if exact is None else abs(value - exact)
    rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15)
    print(f"{value:.16e} {abs_error:.6e} {rel_error:.6e} {elapsed:.3f}")
    if max_timings:
        print("timing_section max_seconds")
        for key, max_value in max_timings.items():
            print(f"{key} {max_value:.6f}")
|
|
|
|
|
|
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|