完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
202
tools/validate_vidal_mpi_correctness.py
Normal file
202
tools/validate_vidal_mpi_correctness.py
Normal file
@@ -0,0 +1,202 @@
|
||||
"""Correctness checks for the Vidal/TEBD MPS fast path.
|
||||
|
||||
The cases here intentionally cover more than the benchmark ring-XZ observable:
|
||||
different nearest-neighbor gate orientations and several Pauli-sum observables.
|
||||
Run serially to compare qibojit/statevector vs Vidal, or under MPI to compare
|
||||
the segmented Vidal executor.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from qibo import Circuit, gates
|
||||
|
||||
from qibotn.backends.vidal_mpi_segment import SegmentVidalMPIExecutor
|
||||
from qibotn.backends.vidal_tebd import VidalTEBDExecutor
|
||||
|
||||
|
||||
def build_circuit(kind, nqubits, nlayers, seed):
|
||||
rng = np.random.default_rng(seed)
|
||||
circuit = Circuit(nqubits)
|
||||
for layer in range(nlayers):
|
||||
for q in range(nqubits):
|
||||
circuit.add(gates.RY(q, theta=rng.uniform(-math.pi, math.pi)))
|
||||
circuit.add(gates.RZ(q, theta=rng.uniform(-math.pi, math.pi)))
|
||||
if kind == "rx_ry_cz":
|
||||
circuit.add(gates.RX(q, theta=rng.uniform(-math.pi, math.pi)))
|
||||
|
||||
if kind in ("brickwall", "reversed_cnot"):
|
||||
for q in range(0, nqubits - 1, 2):
|
||||
if kind == "reversed_cnot" and (layer % 2):
|
||||
circuit.add(gates.CNOT(q + 1, q))
|
||||
else:
|
||||
circuit.add(gates.CNOT(q, q + 1))
|
||||
for q in range(1, nqubits - 1, 2):
|
||||
if kind == "reversed_cnot" and not (layer % 2):
|
||||
circuit.add(gates.CNOT(q + 1, q))
|
||||
else:
|
||||
circuit.add(gates.CNOT(q, q + 1))
|
||||
elif kind == "rx_ry_cz":
|
||||
for q in range(layer % 2, nqubits - 1, 2):
|
||||
circuit.add(gates.CZ(q, q + 1))
|
||||
else:
|
||||
raise ValueError(f"Unknown circuit kind {kind!r}.")
|
||||
return circuit
|
||||
|
||||
|
||||
def observable_terms(kind, nqubits):
|
||||
if kind == "ring_xz":
|
||||
return [
|
||||
(0.5, (("X", site), ("Z", (site + 1) % nqubits)))
|
||||
for site in range(nqubits)
|
||||
]
|
||||
if kind == "open_zz":
|
||||
return [
|
||||
(1.0 / (nqubits - 1), (("Z", site), ("Z", site + 1)))
|
||||
for site in range(nqubits - 1)
|
||||
]
|
||||
if kind == "mixed_local":
|
||||
terms = [(0.25, (("X", 0),)), (-0.5, (("Z", nqubits - 1),))]
|
||||
terms += [
|
||||
(0.125, (("Y", site), ("Y", site + 1)))
|
||||
for site in range(0, nqubits - 1, 3)
|
||||
]
|
||||
return terms
|
||||
raise ValueError(f"Unknown observable kind {kind!r}.")
|
||||
|
||||
|
||||
def exact_pauli_sum(circuit, terms, nqubits):
|
||||
state = circuit().state(numpy=True).reshape(-1)
|
||||
indices = np.arange(state.size, dtype=np.int64)
|
||||
value = 0.0 + 0.0j
|
||||
for coeff, ops in terms:
|
||||
flipped = indices.copy()
|
||||
phase = np.ones(state.size, dtype=np.complex128)
|
||||
for name, site in ops:
|
||||
shift = nqubits - 1 - site
|
||||
bit = (indices >> shift) & 1
|
||||
name = name.upper()
|
||||
if name == "X":
|
||||
flipped ^= 1 << shift
|
||||
elif name == "Y":
|
||||
flipped ^= 1 << shift
|
||||
phase *= 1j * (1 - 2 * bit)
|
||||
elif name == "Z":
|
||||
phase *= 1 - 2 * bit
|
||||
elif name != "I":
|
||||
raise ValueError(f"Unsupported Pauli {name!r}.")
|
||||
value += coeff * np.vdot(state[flipped], phase * state)
|
||||
return float(value.real)
|
||||
|
||||
|
||||
def run_vidal(circuit, terms, nqubits, bond, tensor_module):
|
||||
executor = VidalTEBDExecutor(
|
||||
nqubits=nqubits,
|
||||
max_bond=bond,
|
||||
cut_ratio=1e-12,
|
||||
tensor_module=tensor_module,
|
||||
)
|
||||
executor.run_circuit(circuit)
|
||||
return float(executor.expectation_pauli_sum(terms))
|
||||
|
||||
|
||||
def run_segment_mpi(circuit, terms, nqubits, bond, tensor_module, comm):
|
||||
executor = SegmentVidalMPIExecutor(
|
||||
nqubits=nqubits,
|
||||
max_bond=bond,
|
||||
cut_ratio=1e-12,
|
||||
tensor_module=tensor_module,
|
||||
comm=comm,
|
||||
)
|
||||
executor.run_circuit(circuit)
|
||||
return executor.expectation_pauli_sum_root(terms)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nqubits", type=int, default=16)
|
||||
parser.add_argument("--nlayers", type=int, default=6)
|
||||
parser.add_argument("--bond", "--bonds", dest="bond", type=int, default=512)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--tensor-module", choices=("torch", "numpy"), default="torch")
|
||||
parser.add_argument("--torch-threads", type=int, default=32)
|
||||
parser.add_argument("--mpi", action="store_true")
|
||||
parser.add_argument(
|
||||
"--circuits",
|
||||
nargs="+",
|
||||
default=("brickwall", "reversed_cnot", "rx_ry_cz"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--observables",
|
||||
nargs="+",
|
||||
default=("ring_xz", "open_zz", "mixed_local"),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.set_num_threads(args.torch_threads)
|
||||
comm = None
|
||||
rank = 0
|
||||
size = 1
|
||||
if args.mpi:
|
||||
from mpi4py import MPI
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
size = comm.Get_size()
|
||||
|
||||
if rank == 0:
|
||||
mode = f"vidal-segment-mpi/{size}" if args.mpi else "vidal"
|
||||
print(
|
||||
f"mode={mode} nqubits={args.nqubits} nlayers={args.nlayers} "
|
||||
f"bond={args.bond} tensor_module={args.tensor_module}"
|
||||
)
|
||||
print("circuit observable exact value abs_error seconds")
|
||||
|
||||
for circuit_kind in args.circuits:
|
||||
circuit = build_circuit(circuit_kind, args.nqubits, args.nlayers, args.seed)
|
||||
exact = None
|
||||
if rank == 0:
|
||||
exact_values = {
|
||||
obs: exact_pauli_sum(
|
||||
circuit, observable_terms(obs, args.nqubits), args.nqubits
|
||||
)
|
||||
for obs in args.observables
|
||||
}
|
||||
else:
|
||||
exact_values = None
|
||||
if comm is not None:
|
||||
exact_values = comm.bcast(exact_values, root=0)
|
||||
|
||||
for obs_kind in args.observables:
|
||||
terms = observable_terms(obs_kind, args.nqubits)
|
||||
start = time.perf_counter()
|
||||
if args.mpi:
|
||||
value = run_segment_mpi(
|
||||
circuit,
|
||||
terms,
|
||||
args.nqubits,
|
||||
args.bond,
|
||||
args.tensor_module,
|
||||
comm,
|
||||
)
|
||||
else:
|
||||
value = run_vidal(
|
||||
circuit, terms, args.nqubits, args.bond, args.tensor_module
|
||||
)
|
||||
if rank != 0:
|
||||
continue
|
||||
elapsed = time.perf_counter() - start
|
||||
exact = exact_values[obs_kind]
|
||||
print(
|
||||
f"{circuit_kind} {obs_kind} {exact:.16e} {value:.16e} "
|
||||
f"{abs(value - exact):.6e} {elapsed:.3f}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user