Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
314 lines
10 KiB
Python
314 lines
10 KiB
Python
#!/usr/bin/env python
|
|
"""Contest-style multi-node Vidal/MPS expectation runner."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import math
|
|
import sys
|
|
import time
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
from mpi4py import MPI
|
|
from qibo import Circuit, gates, hamiltonians
|
|
from qibo.symbols import X, Y, Z
|
|
|
|
# Put the checked-out ``src`` directory at the front of the import path so the
# in-repository qibotn sources are picked up ahead of any installed copy.
ROOT = Path(__file__).resolve().parents[1]
SRC = ROOT.joinpath("src")

if str(SRC) not in sys.path:
    sys.path.insert(0, str(SRC))

# These imports must run after the sys.path adjustment above (hence E402).
from qibotn.backends.vidal import VidalBackend  # noqa: E402
from qibotn.expectation_runner import exact_for_observable  # noqa: E402
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class CaseSpec:
|
|
circuit_kind: str
|
|
observables: tuple[str, ...]
|
|
nqubits: int
|
|
nlayers: int
|
|
bond: int | None
|
|
seed: int
|
|
|
|
|
|
# Named benchmark presets selectable with ``--case``. Every name listed in
# ``observables`` must be a kind accepted by ``observable()``; otherwise the
# run aborts with "Unknown observable kind".
CASES = {
    "main1": CaseSpec(
        circuit_kind="reversed_cnot",
        observables=("ring_xz",),
        nqubits=128,
        nlayers=24,
        bond=512,
        seed=31001,
    ),
    "main2": CaseSpec(
        circuit_kind="rxx_rzz",
        observables=("open_zz", "range2_xx", "mixed_local"),
        nqubits=128,
        nlayers=32,
        bond=1024,
        seed=31002,
    ),
    "strong": CaseSpec(
        circuit_kind="scramble",
        # ``observable()`` defines no "long_z_string" kind; its long Z string
        # spread across the chain is "long_Z_5_sites".
        observables=("ring_xz", "long_Z_5_sites", "dense3_spread"),
        nqubits=256,
        nlayers=48,
        bond=2048,
        seed=41001,
    ),
}
|
|
|
|
|
|
def optional_int(text):
    """Convert *text* to ``int``, mapping "no limit" spellings to ``None``.

    The strings "none", "null", "inf" and "unlimited" (any letter case) give
    ``None``. Every other value, strings or not, goes through ``int()``, so
    invalid input raises exactly what ``int()`` raises.
    """
    unbounded = isinstance(text, str) and text.lower() in {"none", "null", "inf", "unlimited"}
    return None if unbounded else int(text)
|
|
|
|
|
|
def optional_float(text):
    """Convert *text* to ``float``, mapping "no limit" spellings to ``None``.

    Used as an argparse ``type=``. The strings "none", "null", "inf" and
    "unlimited" (any letter case) give ``None``. Everything else, including
    non-string values, is passed to ``float()``.
    """
    if not isinstance(text, str):
        return float(text)
    if text.lower() in {"none", "null", "inf", "unlimited"}:
        return None
    return float(text)
|
|
|
|
|
|
def format_optional(value, fmt="g"):
    """Render *value* with the format spec *fmt*, or the text "None" if it is ``None``."""
    if value is None:
        return "None"
    return f"{value:{fmt}}"
|
|
|
|
|
|
def set_torch_threads(nthreads):
    """Best-effort request that PyTorch use *nthreads* CPU threads.

    If torch is not installed (``ImportError``), nothing happens. Any other
    failure is reported as a ``RuntimeWarning`` instead of being silently
    swallowed, and execution continues. Such failures include torch failing
    to import for another reason, or ``torch.set_num_threads`` rejecting the
    value (e.g. a non-positive count).
    """
    import warnings

    try:
        import torch
    except ImportError:
        return  # torch absent: nothing to configure.
    except Exception as exc:  # presumably a broken torch install (e.g. native libs)
        warnings.warn(
            f"could not import torch to set its thread count: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return

    try:
        torch.set_num_threads(nthreads)
    except Exception as exc:
        warnings.warn(
            f"torch.set_num_threads({nthreads!r}) failed: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
|
|
|
|
|
|
def add_single_qubit_layer(circuit, nqubits, rng, include_rx=False):
    """Append one layer of random single-qubit rotations to *circuit*.

    Each qubit, in index order, gets an RY then an RZ, followed by an RX
    when *include_rx* is true. Each angle is one ``rng.uniform(-pi, pi)``
    draw taken in that same gate order, so a given generator state always
    yields the same layer.
    """
    low, high = -math.pi, math.pi
    rotation_gates = (gates.RY, gates.RZ, gates.RX) if include_rx else (gates.RY, gates.RZ)
    for target in range(nqubits):
        for rotation in rotation_gates:
            circuit.add(rotation(target, theta=rng.uniform(low, high)))
|
|
|
|
|
|
def build_circuit(kind, nqubits, nlayers, seed):
|
|
rng = np.random.default_rng(seed)
|
|
circuit = Circuit(nqubits)
|
|
|
|
for layer in range(nlayers):
|
|
if kind == "reversed_cnot":
|
|
add_single_qubit_layer(circuit, nqubits, rng)
|
|
for qubit in range(0, nqubits - 1, 2):
|
|
gate = gates.CNOT(qubit + 1, qubit) if layer % 2 else gates.CNOT(qubit, qubit + 1)
|
|
circuit.add(gate)
|
|
for qubit in range(1, nqubits - 1, 2):
|
|
gate = gates.CNOT(qubit + 1, qubit) if layer % 2 == 0 else gates.CNOT(qubit, qubit + 1)
|
|
circuit.add(gate)
|
|
|
|
elif kind == "rxx_rzz":
|
|
add_single_qubit_layer(circuit, nqubits, rng, include_rx=True)
|
|
for qubit in range(layer % 2, nqubits - 1, 2):
|
|
circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9)))
|
|
circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.9, 0.9)))
|
|
|
|
elif kind == "scramble":
|
|
add_single_qubit_layer(circuit, nqubits, rng, include_rx=True)
|
|
for qubit in range(layer % 2, nqubits - 1, 2):
|
|
circuit.add(gates.RXX(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8)))
|
|
circuit.add(gates.RZZ(qubit, qubit + 1, theta=rng.uniform(-0.8, 0.8)))
|
|
if layer % 5 == 4:
|
|
circuit.add(gates.SWAP(qubit, qubit + 1))
|
|
|
|
else:
|
|
raise ValueError(f"Unknown circuit kind {kind!r}.")
|
|
|
|
return circuit
|
|
|
|
|
|
def dense_observable(nqubits, qubits, seed, dim=None):
    """Return a seeded random Hermitian operator acting on *qubits*.

    The matrix is (A + A^H) / 2 for a complex Gaussian matrix A drawn from
    ``np.random.default_rng(seed)``. It is scaled to unit Frobenius norm.

    Args:
        nqubits: Circuit width, used to check that every target qubit exists.
        qubits: Distinct target qubit indices.
        seed: RNG seed; the same seed always yields the same matrix.
        dim: Matrix dimension. Must equal ``2 ** len(qubits)``; inferred
            when omitted.

    Returns:
        ``{"matrix": (dim, dim) complex ndarray, "qubits": list of indices}``.

    Raises:
        ValueError: if the qubits repeat or fall outside ``range(nqubits)``,
            or if *dim* does not match the number of qubits.
    """
    targets = list(qubits)
    if len(set(targets)) != len(targets):
        raise ValueError(f"dense observable qubits must be distinct, got {targets}.")
    if any(not 0 <= q < nqubits for q in targets):
        raise ValueError(f"dense observable qubits {targets} out of range for {nqubits} qubits.")

    expected_dim = 2 ** len(targets)
    if dim is None:
        dim = expected_dim
    elif dim != expected_dim:
        raise ValueError(
            f"dense observable on {len(targets)} qubits needs dim={expected_dim}, got {dim}."
        )

    rng = np.random.default_rng(seed)
    # Real part first, then imaginary part: fixed draw order keeps seeds reproducible.
    real = rng.normal(size=(dim, dim))
    imag = rng.normal(size=(dim, dim))
    raw = real + 1j * imag
    matrix = (raw + raw.conj().T) / 2.0  # Hermitian part
    matrix = matrix / np.linalg.norm(matrix)  # unit Frobenius norm
    return {"matrix": matrix, "qubits": targets}
|
|
|
|
|
|
def observable(kind, nqubits, seed):
    """Construct the named benchmark observable for an *nqubits*-qubit chain.

    Symbolic kinds return a ``hamiltonians.SymbolicHamiltonian``. The
    ``dense*`` kinds return the dict produced by ``dense_observable``, seeded
    from *seed*. Local terms are anchored at the quarter, half and
    three-quarter points of the chain and at its last qubit.

    Raises:
        ValueError: if *kind* is not a known observable name.
    """
    quarter = nqubits // 4
    half = nqubits // 2
    three_quarters = (3 * nqubits) // 4
    last_site = nqubits - 1

    def symbolic(form):
        return hamiltonians.SymbolicHamiltonian(form=form)

    def ring_xz():
        # Periodic: the final X term couples back to Z on qubit 0.
        return symbolic(
            sum(0.5 * X(site) * Z((site + 1) % nqubits) for site in range(nqubits))
        )

    def open_zz():
        weight = 1.0 / max(1, nqubits - 1)
        return symbolic(sum(weight * Z(site) * Z(site + 1) for site in range(nqubits - 1)))

    def range2_xx():
        weight = 1.0 / max(1, nqubits - 2)
        return symbolic(sum(weight * X(site) * X(site + 2) for site in range(nqubits - 2)))

    builders = {
        "boundary_ZZ_q1": lambda: symbolic(Z(quarter - 1) * Z(quarter)),
        "boundary_ZZ_q2": lambda: symbolic(Z(half - 1) * Z(half)),
        "boundary_ZZ_q3": lambda: symbolic(Z(three_quarters - 1) * Z(three_quarters)),
        "long_Z_5_sites": lambda: symbolic(
            Z(0) * Z(quarter) * Z(half) * Z(three_quarters) * Z(last_site)
        ),
        "mixed_XZYZX": lambda: symbolic(
            X(0) * Z(quarter) * Y(half) * Z(three_quarters) * X(last_site)
        ),
        "ring_xz": ring_xz,
        "open_zz": open_zz,
        "range2_xx": range2_xx,
        "mixed_local": lambda: symbolic(
            0.25 * X(0) - 0.5 * Z(last_site) + 0.125 * X(quarter) * Z(half) * Y(three_quarters)
        ),
        "complex_iZ0": lambda: symbolic(1.0j * Z(0)),
        "dense2_mid": lambda: dense_observable(nqubits, (half - 1, half), seed + 101, 4),
        "dense3_spread": lambda: dense_observable(
            nqubits, (quarter, half, three_quarters), seed + 202, 8
        ),
    }

    builder = builders.get(kind)
    if builder is None:
        raise ValueError(f"Unknown observable kind {kind!r}.")
    return builder()
|
|
|
|
|
|
def selected_observables(args, case):
    """Choose which observable names to evaluate.

    Precedence: an explicit ``--observables`` list, then a comma-separated
    ``--obs-filter`` (entries stripped, empty ones dropped), then the
    preset's own ``case.observables``. Always returns a tuple.
    """
    explicit = args.observables
    if explicit:
        return tuple(explicit)

    filter_text = args.obs_filter
    if not filter_text:
        return case.observables

    stripped = (name.strip() for name in filter_text.split(","))
    return tuple(name for name in stripped if name)
|
|
|
|
|
|
def apply_case_defaults(args):
    """Fill unset CLI options on *args* from the selected ``CASES`` preset.

    ``nqubits``, ``nlayers`` and ``seed`` are copied from the preset when
    they are ``None``. ``bond`` is copied when it still holds the
    "case-default" sentinel. ``args.observables`` is always replaced by the
    result of ``selected_observables``. Mutates *args* in place.
    """
    preset = CASES[args.case]
    for field_name in ("nqubits", "nlayers"):
        if getattr(args, field_name) is None:
            setattr(args, field_name, getattr(preset, field_name))
    if args.bond == "case-default":
        args.bond = preset.bond
    if args.seed is None:
        args.seed = preset.seed
    args.observables = selected_observables(args, preset)
|
|
|
|
|
|
def run_case(args):
|
|
set_torch_threads(args.torch_threads)
|
|
comm = MPI.COMM_WORLD
|
|
rank = comm.Get_rank()
|
|
size = comm.Get_size()
|
|
|
|
case = CASES[args.case]
|
|
circuit = build_circuit(case.circuit_kind, args.nqubits, args.nlayers, args.seed)
|
|
|
|
if rank == 0:
|
|
print("=" * 88, flush=True)
|
|
print(
|
|
"backend=vidal_mps "
|
|
f"case={args.case} circuit={case.circuit_kind} ranks={size} "
|
|
f"nqubits={args.nqubits} nlayers={args.nlayers} gates={len(circuit.queue)} "
|
|
f"bond={format_optional(args.bond)} cut_ratio={format_optional(args.cut_ratio)} "
|
|
f"torch_threads={args.torch_threads} seed={args.seed} "
|
|
f"observables={','.join(args.observables)}",
|
|
flush=True,
|
|
)
|
|
print("observable exact value abs_error rel_error seconds trunc_sum trunc_max status", flush=True)
|
|
|
|
for obs_name in args.observables:
|
|
obs = observable(obs_name, args.nqubits, args.seed)
|
|
exact = None
|
|
if args.exact and rank == 0:
|
|
if args.nqubits > args.exact_max_qubits:
|
|
raise ValueError(
|
|
f"--exact is limited to {args.exact_max_qubits} qubits by default."
|
|
)
|
|
exact = exact_for_observable(circuit, obs, args.nqubits)
|
|
|
|
backend = VidalBackend()
|
|
backend.configure_tn_simulation(
|
|
max_bond_dimension=args.bond,
|
|
cut_ratio=args.cut_ratio,
|
|
tensor_module="torch",
|
|
mpi_approach="CT",
|
|
mpi_num_procs=size,
|
|
fallback=False,
|
|
)
|
|
|
|
comm.Barrier()
|
|
start = time.perf_counter()
|
|
try:
|
|
value = backend.expectation(
|
|
circuit,
|
|
obs,
|
|
preprocess=True,
|
|
compile_circuit=False,
|
|
)
|
|
status = "ok"
|
|
except Exception as exc:
|
|
value = np.nan
|
|
status = type(exc).__name__ + ":" + str(exc).split("\n", 1)[0]
|
|
seconds = time.perf_counter() - start
|
|
|
|
if rank == 0:
|
|
abs_error = float("nan") if exact is None else abs(value - exact)
|
|
rel_error = float("nan") if exact is None else abs_error / max(abs(exact), 1e-15)
|
|
exact_text = "nan" if exact is None else f"{exact:.16e}"
|
|
print(
|
|
f"{obs_name} {exact_text} {value!r} "
|
|
f"{abs_error:.6e} {rel_error:.6e} {seconds:.3f} "
|
|
f"{backend.last_truncation_error:.6e} "
|
|
f"{backend.last_max_truncation_error:.6e} {status}",
|
|
flush=True,
|
|
)
|
|
|
|
|
|
def main():
    """Parse the command line and dispatch to the requested mode.

    Modes:
        list      print every preset in ``CASES`` and return.
        run       evaluate the preset (with any CLI overrides) via ``run_case``.
        validate  like ``run`` but forces ``--exact`` and caps ``--nqubits``
                  at ``--exact-max-qubits``.

    A malformed or non-positive ``--bond`` is reported through
    ``parser.error`` (usage message, exit status 2) rather than a traceback.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("mode", choices=("run", "validate", "list"))
    parser.add_argument("--case", choices=sorted(CASES), default="main1")
    parser.add_argument("--observables", nargs="+")
    parser.add_argument("--obs-filter", default="")
    parser.add_argument("--nqubits", type=int)
    parser.add_argument("--nlayers", type=int)
    # Parsed as a raw string so the "case-default" sentinel survives argparse
    # (a type= converter would also be applied to the string default). It is
    # converted with optional_int after the case defaults are applied.
    parser.add_argument("--bond", "--bonds", dest="bond", default="case-default")
    parser.add_argument("--cut-ratio", type=optional_float, default=1e-12)
    parser.add_argument("--seed", type=int)
    parser.add_argument("--torch-threads", type=int, default=8)
    parser.add_argument("--exact", action="store_true")
    parser.add_argument("--exact-max-qubits", type=int, default=24)
    args = parser.parse_args()

    if args.mode == "list":
        for name, case in CASES.items():
            print(
                f"{name}: circuit={case.circuit_kind} "
                f"observables={','.join(case.observables)} "
                f"nqubits={case.nqubits} nlayers={case.nlayers} "
                f"bond={case.bond} seed={case.seed}"
            )
        return

    apply_case_defaults(args)
    if isinstance(args.bond, str):
        try:
            args.bond = optional_int(args.bond)
        except ValueError:
            parser.error(
                "argument --bond/--bonds: expected an integer or one of "
                f"none/null/inf/unlimited, got {args.bond!r}"
            )
    if args.bond is not None and args.bond < 1:
        parser.error(f"argument --bond/--bonds: must be a positive integer, got {args.bond}")

    if args.mode == "validate":
        args.exact = True
        args.nqubits = min(args.nqubits, args.exact_max_qubits)

    run_case(args)


if __name__ == "__main__":
    main()
|