Files
qibotn/tools/baseline_mps_expectation.py
jaunatisblue 72f95599bb
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
2026-05-12 15:44:19 +08:00

184 lines
6.5 KiB
Python

"""MPS expectation benchmark for qmatchatea and Vidal backends."""
import argparse
import json
import logging
import os
import socket
import time
import numpy as np
from qibotn.benchmark_cases import (
build_circuit as build_benchmark_circuit,
exact_pauli_sum,
observable_terms,
terms_to_dict,
)
from qibotn.backends.qmatchatea import QMatchaTeaBackend
from qibotn.backends.vidal_tebd import run_vidal_ring_xz
def build_circuit(nqubits, nlayers, seed):
    """Build the "brickwall_cnot" benchmark circuit for this script.

    Thin wrapper around ``qibotn.benchmark_cases.build_circuit`` that pins
    the benchmark case name; all other arguments are forwarded unchanged.
    """
    case_name = "brickwall_cnot"
    return build_benchmark_circuit(case_name, nqubits, nlayers, seed)
def build_observable(nqubits):
    """Return the "ring_xz" observable for ``nqubits`` qubits.

    The terms from ``observable_terms`` are converted with ``terms_to_dict``
    (presumably the form the backend's ``expectation`` accepts; verify).
    """
    terms = observable_terms("ring_xz", nqubits)
    return terms_to_dict(terms)
def exact_expectation(circuit, nqubits):
    """Compute the reference "ring_xz" expectation of ``circuit``.

    Delegates to ``exact_pauli_sum`` with the same "ring_xz" terms that
    ``build_observable`` uses.
    """
    terms = observable_terms("ring_xz", nqubits)
    return exact_pauli_sum(circuit, terms, nqubits)
_VIDAL_EXECUTORS = ("vidal", "vidal-mpi")


def _parse_args(argv=None):
    """Parse the benchmark command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--nqubits", type=int, default=40)
    parser.add_argument("--nlayers", type=int, default=30)
    parser.add_argument("--bond", "--bonds", dest="bond", type=int, default=512)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--tensor-module", choices=("numpy", "torch"), default="torch")
    parser.add_argument("--torch-threads", type=int, default=32)
    parser.add_argument(
        "--executor",
        choices=("qmatchatea", "vidal", "vidal-mpi"),
        default="qmatchatea",
    )
    parser.add_argument("--mpi-ct", action="store_true")
    parser.add_argument("--mpi-barriers", type=int, default=-1)
    parser.add_argument("--mpi-isometrization", type=int, default=-1)
    parser.add_argument("--exact", action="store_true")
    parser.add_argument("--exact-max-qubits", type=int, default=24)
    parser.add_argument("--reference-file")
    parser.add_argument(
        "--mpi-rank-map",
        action="store_true",
        help="Print MPI rank, host, pid, and torch thread placement metadata.",
    )
    return parser.parse_args(argv)


def _init_mpi(args):
    """Return ``(comm, rank, size)``; ``comm`` is ``None`` without ``--mpi-ct``."""
    if not args.mpi_ct:
        return None, 0, 1
    # Imported lazily so single-process runs never touch mpi4py.
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    return comm, comm.Get_rank(), comm.Get_size()


def _print_rank_map(comm, rank, size, torch_threads):
    """Gather per-rank host/pid/thread metadata and print it on rank 0."""
    rank_info = {
        "rank": rank,
        "size": size,
        "host": socket.gethostname(),
        "pid": os.getpid(),
        "torch_threads": torch_threads,
        "omp_num_threads": os.environ.get("OMP_NUM_THREADS", ""),
        "mkl_num_threads": os.environ.get("MKL_NUM_THREADS", ""),
    }
    # Collective call: every rank participates, only the root receives the list.
    rank_infos = comm.gather(rank_info, root=0)
    if rank != 0:
        return
    print("mpi_rank_map")
    for item in sorted(rank_infos, key=lambda row: row["rank"]):
        print(
            "rank={rank} size={size} host={host} pid={pid} "
            "torch_threads={torch_threads} "
            "OMP_NUM_THREADS={omp_num_threads} "
            "MKL_NUM_THREADS={mkl_num_threads}".format(**item)
        )


def _load_exact(args, circuit):
    """Return the reference expectation value, or ``None`` if none was requested.

    ``--reference-file`` (a JSON file with an ``"expectation"`` key) takes
    precedence over ``--exact``.

    Raises:
        ValueError: If ``--exact`` is requested for more than
            ``--exact-max-qubits`` qubits.
    """
    if args.reference_file:
        with open(args.reference_file, "r", encoding="utf-8") as f:
            return float(json.load(f)["expectation"])
    if args.exact:
        if args.nqubits > args.exact_max_qubits:
            raise ValueError(
                f"--exact is limited to {args.exact_max_qubits} qubits by default."
            )
        return exact_expectation(circuit, args.nqubits)
    return None


def _print_header(args, size, exact):
    """Print the run configuration and the result column header."""
    if args.mpi_ct and args.executor in _VIDAL_EXECUTORS:
        mpi_label = f"VidalSegment/{size}"
    elif args.mpi_ct:
        mpi_label = f"MPIMPS/{size}"
    else:
        mpi_label = "SR"
    print(
        f"nqubits={args.nqubits} nlayers={args.nlayers} "
        f"bond={args.bond} seed={args.seed} "
        f"tensor_module={args.tensor_module} svd_control=E! "
        f"compile_circuit=True mpi={mpi_label} executor={args.executor}"
    )
    if exact is not None:
        print(f"exact={exact:.16e}")
    print("expval abs_error rel_error seconds")


def _run_executor(args, circuit, observable, comm, size):
    """Run the selected executor and return ``(value, timings)``.

    ``timings`` comes only from the MPI Vidal segment executor. Every other
    path returns ``None`` for it.
    """
    if args.executor in _VIDAL_EXECUTORS:
        if comm is not None:
            from qibotn.backends.vidal_mpi_segment import run_segment_vidal_mpi_ring_xz

            value, timings = run_segment_vidal_mpi_ring_xz(
                circuit,
                max_bond=args.bond,
                cut_ratio=1e-12,
                tensor_module=args.tensor_module,
                comm=comm,
            )
            return value, timings
        value = run_vidal_ring_xz(
            circuit,
            max_bond=args.bond,
            cut_ratio=1e-12,
            tensor_module=args.tensor_module,
        )
        return value, None

    backend = QMatchaTeaBackend()
    backend.configure_tn_simulation(
        ansatz="MPS",
        max_bond_dimension=args.bond,
        cut_ratio=1e-12,
        svd_control="E!",
        tensor_module=args.tensor_module,
        compile_circuit=True,
        track_memory=False,
        mpi_approach="CT" if args.mpi_ct else "SR",
        mpi_num_procs=size,
        mpi_where_barriers=args.mpi_barriers if args.mpi_ct else -1,
        mpi_isometrization=args.mpi_isometrization,
    )
    value = backend.expectation(
        circuit,
        observable,
        preprocess=False,
        compile_circuit=True,
    )
    return value, None


def _reduce_max_timings(comm, timings):
    """Reduce each timing section to its maximum over ranks.

    The result is meaningful on rank 0 only; other ranks get ``None`` values.
    This is a collective call, so every rank must reach it with the same keys
    in the same order. ``timings`` is assumed to map section name to seconds
    (TODO confirm against ``run_segment_vidal_mpi_ring_xz``).
    """
    if not timings:
        return None
    from mpi4py import MPI

    return {
        key: comm.reduce(local_value, op=MPI.MAX, root=0)
        for key, local_value in timings.items()
    }


def main(argv=None):
    """Run the MPS ring-XZ expectation benchmark from the command line.

    Only rank 0 prints the configuration header, the result line
    (``expval abs_error rel_error seconds``) and, for the MPI Vidal executor,
    the per-section max timings.

    Args:
        argv: Optional argument list; ``None`` parses ``sys.argv[1:]``.

    Raises:
        ValueError: If ``--executor vidal-mpi`` is used without ``--mpi-ct``,
            or ``--exact`` exceeds ``--exact-max-qubits``.
    """
    args = _parse_args(argv)
    # Fail fast, before torch/MPI init and the potentially expensive exact
    # reference computation.
    if args.executor == "vidal-mpi" and not args.mpi_ct:
        raise ValueError("--executor vidal-mpi requires --mpi-ct.")

    logging.getLogger("qibo.config").setLevel(logging.ERROR)
    logging.getLogger("qtealeaves").setLevel(logging.ERROR)
    import torch

    torch.set_num_threads(args.torch_threads)

    comm, rank, size = _init_mpi(args)
    if comm is not None and args.mpi_rank_map:
        _print_rank_map(comm, rank, size, args.torch_threads)

    circuit = build_circuit(args.nqubits, args.nlayers, args.seed)
    observable = build_observable(args.nqubits)
    exact = _load_exact(args, circuit)

    if rank == 0:
        _print_header(args, size, exact)

    if comm is not None:
        # Align all ranks so rank 0's wall clock does not absorb setup skew
        # (exact reference, file reads, header printing) from other ranks.
        comm.Barrier()
    start = time.perf_counter()
    value, timings = _run_executor(args, circuit, observable, comm, size)
    max_timings = _reduce_max_timings(comm, timings)

    if rank != 0:
        return
    value = float(np.real(value))
    elapsed = time.perf_counter() - start
    abs_error = float("nan") if exact is None else abs(value - exact)
    rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15)
    print(f"{value:.16e} {abs_error:.6e} {rel_error:.6e} {elapsed:.3f}")
    if max_timings:
        print("timing_section max_seconds")
        for key, max_value in max_timings.items():
            print(f"{key} {max_value:.6f}")
# Script entry point: run the benchmark CLI only when executed directly,
# not when imported.
if __name__ == "__main__":
    main()