tn脚本更新
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
307
benchmark_tn.py
307
benchmark_tn.py
@@ -1,4 +1,5 @@
|
||||
"""Benchmark: qibotn/quimb generic TN — expectation values."""
|
||||
import pickle
|
||||
import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
@@ -22,6 +23,15 @@ def make_circuit(circuit_type, nqubits, nlayers=1):
|
||||
c.add(gates.H(0))
|
||||
for q in range(nqubits - 1):
|
||||
c.add(gates.CNOT(q, q + 1))
|
||||
elif circuit_type == "brickwork":
|
||||
for q in range(nqubits):
|
||||
c.add(gates.H(q))
|
||||
for layer in range(nlayers):
|
||||
offset = layer % 2
|
||||
for q in range(offset, nqubits - 1, 2):
|
||||
c.add(gates.CNOT(q, q + 1))
|
||||
c.add(gates.RZ(q, theta=np.random.uniform(0, 2 * np.pi)))
|
||||
c.add(gates.RZ(q + 1, theta=np.random.uniform(0, 2 * np.pi)))
|
||||
else:
|
||||
raise ValueError(f"Unknown circuit: {circuit_type}")
|
||||
return c
|
||||
@@ -33,80 +43,305 @@ def make_z_observable(nqubits):
|
||||
return ["z"], [(0,)], [1.0]
|
||||
|
||||
|
||||
def run_quimb_tn(circuit, nqubits):
|
||||
def run_quimb_tn(circuit, nqubits, num_slices, load_path=None, save_path=None):
|
||||
"""Mode: expval — compute <Z_0> via local_expectation (lightcone pruning)."""
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
b = qibo.get_backend()
|
||||
b.configure_tn_simulation(ansatz="tn") # generic TN, no MPS
|
||||
#if max_time is not None:
|
||||
# opt = ctg.HyperOptimizer(max_repeats=128, max_time=max_time, minimize=minimize, parallel=True)
|
||||
#else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
max_repeats=16,
|
||||
parallel=True,
|
||||
slicing_opts={'target_size': 2**24}, # 限制单个张量最大 2^28 个元素
|
||||
progbar=True
|
||||
)
|
||||
|
||||
b.contractions_optimizer = opt
|
||||
b.configure_tn_simulation(ansatz="tn")
|
||||
|
||||
operators, sites, coeffs = make_z_observable(nqubits)
|
||||
ops = b._string_to_quimb_operator(operators[0])
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
|
||||
if load_path:
|
||||
with open(load_path, "rb") as f:
|
||||
saved = pickle.load(f)
|
||||
tree = saved["tree"]
|
||||
t_search = 0.0
|
||||
print(f" [path loaded] {load_path}")
|
||||
else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=16,
|
||||
parallel=True,
|
||||
max_time=60,
|
||||
slicing_opts={'target_slices': num_slices},
|
||||
progbar=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
rehearsal = qc.local_expectation(ops, where=sites[0], optimize=opt,
|
||||
simplify_sequence="R", rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
tree = rehearsal['tree']
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f}")
|
||||
|
||||
if save_path:
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump({"tree": tree}, f)
|
||||
print(f" [path saved] {save_path}")
|
||||
|
||||
t0 = time.time()
|
||||
expval = b.exp_value_observable_symbolic(circuit, operators, sites, coeffs, nqubits)
|
||||
elapsed = time.time() - t0
|
||||
return expval, elapsed
|
||||
expval = qc.local_expectation(ops, where=sites[0], optimize=tree, simplify_sequence="R")
|
||||
t_contract = time.time() - t0
|
||||
print(f" [contraction] {t_contract:.3f}s")
|
||||
|
||||
return float(expval.real), t_search + t_contract
|
||||
|
||||
|
||||
def run_quimb_tn_statevector(circuit, nqubits, num_slices, load_path=None, save_path=None):
|
||||
"""Mode: statevector — contract full TN to dense vector."""
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
b = qibo.get_backend()
|
||||
b.configure_tn_simulation(ansatz="tn")
|
||||
|
||||
import torch
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
qc.to_backend = torch.from_numpy
|
||||
if load_path:
|
||||
with open(load_path, "rb") as f:
|
||||
saved = pickle.load(f)
|
||||
tree = saved["tree"]
|
||||
t_search = 0.0
|
||||
print(f" [path loaded] {load_path}")
|
||||
else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=128,
|
||||
parallel=64,
|
||||
max_time=100,
|
||||
minimize='size',
|
||||
slicing_opts={'target_slices': num_slices},
|
||||
#slicing_opts={'target_size': 2**30},
|
||||
progbar=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
rehearsal = qc.to_dense(optimize=opt, rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
tree = rehearsal['tree']
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f}")
|
||||
|
||||
if save_path:
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump({"tree": tree}, f)
|
||||
print(f" [path saved] {save_path}")
|
||||
|
||||
t0 = time.time()
|
||||
sv = qc.to_dense(optimize=tree,implementation="cotengra").reshape(-1)
|
||||
t_contract = time.time() - t0
|
||||
print(f" [contraction] {t_contract:.3f}s")
|
||||
sv_tn = np.array(sv)
|
||||
return sv_tn, t_search + t_contract
|
||||
|
||||
|
||||
def _contract_mpi(tree, arrays, comm, root=0):
|
||||
"""Contract slices via MPI, returning result as the same array type as input.
|
||||
|
||||
Unlike ``cotengra.ContractionTree.contract_mpi``, this works with any
|
||||
array backend (numpy, torch, etc.) — it only converts to numpy at the
|
||||
MPI-reduce boundary and converts back.
|
||||
"""
|
||||
size = comm.Get_size()
|
||||
rank = comm.Get_rank()
|
||||
|
||||
result_np = None
|
||||
is_torch = type(arrays[0]).__module__.startswith("torch")
|
||||
|
||||
for i in range(rank, tree.multiplicity, size):
|
||||
x = tree.contract_slice(arrays, i)
|
||||
x_np = np.asfortranarray(x.detach().cpu().numpy() if is_torch else np.asarray(x))
|
||||
|
||||
if result_np is None:
|
||||
result_np = x_np
|
||||
else:
|
||||
result_np += x_np
|
||||
|
||||
if result_np is None:
|
||||
result_np = np.zeros(1, dtype=np.complex64)
|
||||
|
||||
if rank == root:
|
||||
result = np.zeros_like(result_np)
|
||||
else:
|
||||
result = None
|
||||
comm.Reduce(result_np, result, root=root)
|
||||
|
||||
if rank == root:
|
||||
import torch
|
||||
return torch.from_numpy(np.asarray(result)) if is_torch else result
|
||||
return None
|
||||
|
||||
|
||||
def run_quimb_tn_statevector_mpi(circuit, nqubits, num_slices, load_path=None, save_path=None):
|
||||
"""MPI-parallel statevector via custom MPI contraction (supports torch backend)."""
|
||||
from mpi4py import MPI
|
||||
comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
size = comm.Get_size()
|
||||
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
b = qibo.get_backend()
|
||||
b.configure_tn_simulation(ansatz="tn")
|
||||
|
||||
import torch
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
qc.to_backend = torch.from_numpy
|
||||
|
||||
# path search on rank 0, broadcast to all
|
||||
if rank == 0:
|
||||
if load_path:
|
||||
with open(load_path, "rb") as f:
|
||||
saved = pickle.load(f)
|
||||
tree = saved["tree"]
|
||||
psi = saved["psi"]
|
||||
t_search = 0.0
|
||||
print(f" [path loaded] {load_path}")
|
||||
else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=128,
|
||||
parallel=64,
|
||||
#max_repeats=1,
|
||||
max_time=100,
|
||||
minimize='size',
|
||||
slicing_opts={'target_slices': max(num_slices, size), 'allow_outer': False},
|
||||
progbar=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
rehearsal = qc.to_dense(optimize=opt, rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
tree = rehearsal['tree']
|
||||
psi = rehearsal['tn']
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f} slices={tree.multiplicity}")
|
||||
if save_path:
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump({"tree": tree, "psi": psi}, f)
|
||||
print(f" [path saved] {save_path}")
|
||||
else:
|
||||
tree = None
|
||||
psi = None
|
||||
t_search = 0.0
|
||||
|
||||
tree = comm.bcast(tree, root=0)
|
||||
psi = comm.bcast(psi, root=0)
|
||||
t_search = comm.bcast(t_search, root=0)
|
||||
|
||||
arrays = psi.arrays
|
||||
t0 = time.time()
|
||||
sv = _contract_mpi(tree, arrays, comm, root=0)
|
||||
t_contract = time.time() - t0
|
||||
|
||||
if rank == 0:
|
||||
print(f" [contraction] {t_contract:.3f}s")
|
||||
return np.array(sv).reshape(-1), t_search + t_contract
|
||||
return None, t_search + t_contract
|
||||
|
||||
|
||||
def run_quimb_tn_samples(circuit, nshots=1024):
|
||||
"""Mode: samples — sample from circuit output distribution."""
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
b = qibo.get_backend()
|
||||
b.configure_tn_simulation(ansatz="tn")
|
||||
|
||||
t0 = time.time()
|
||||
result = b.execute_circuit(circuit, nshots=nshots)
|
||||
t_total = time.time() - t0
|
||||
print(f" [sampling] {t_total:.3f}s nshots={nshots}")
|
||||
print(f" top states: {dict(list(result.frequencies().items())[:5])}")
|
||||
return result, t_total
|
||||
|
||||
|
||||
def qibojit_expval(circuit, nqubits):
|
||||
"""Compute sum_i <Z_i> via qibojit statevector."""
|
||||
"""Compute <Z_0> via qibojit statevector."""
|
||||
qibo.set_backend("qibojit", platform="numba")
|
||||
t0 = time.time()
|
||||
result = circuit()
|
||||
elapsed = time.time() - t0
|
||||
sv = np.array(result.state(), dtype=complex).flatten()
|
||||
probs = np.abs(sv) ** 2
|
||||
expval = sum(
|
||||
probs[idx] * (1 - 2 * ((idx >> (nqubits - 1 - i)) & 1))
|
||||
for i in range(nqubits)
|
||||
for idx in range(len(probs))
|
||||
)
|
||||
return float(expval), elapsed
|
||||
bits = (np.arange(len(probs)) >> (nqubits - 1)) & 1
|
||||
expval = float(np.dot(probs, 1 - 2 * bits))
|
||||
return expval, elapsed
|
||||
|
||||
|
||||
def run_qibojit(circuit):
|
||||
"""Compute full statevector via qibojit."""
|
||||
qibo.set_backend("qibojit", platform="numba")
|
||||
t0 = time.time()
|
||||
result = circuit()
|
||||
elapsed = time.time() - t0
|
||||
sv = np.array(result.state(), dtype=complex).flatten()
|
||||
return sv, elapsed
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nqubits", type=int, default=10)
|
||||
parser.add_argument("--circuit", type=str, default="qft",
|
||||
choices=["qft", "variational", "ghz"])
|
||||
choices=["qft", "variational", "ghz", "brickwork"])
|
||||
parser.add_argument("--nlayers", type=int, default=3)
|
||||
parser.add_argument("--optimizer", type=str, default="auto-hq")
|
||||
parser.add_argument("--max-time", type=float, default=None,
|
||||
help="HyperOptimizer max search time (seconds); overrides --optimizer")
|
||||
parser.add_argument("--minimize", type=str, default="flops",
|
||||
choices=["flops", "size", "write"],
|
||||
help="HyperOptimizer minimize target")
|
||||
parser.add_argument("--num-slices", type=int, default=1)
|
||||
parser.add_argument("--nshots", type=int, default=1024)
|
||||
parser.add_argument("--mode", type=str, default="statevector",
|
||||
choices=["expval", "statevector", "samples"],
|
||||
help="expval: local_expectation; statevector: to_dense; samples: sampling")
|
||||
parser.add_argument("--mpi", action="store_true",
|
||||
help="Use MPI-parallel contraction (run with mpirun -n N)")
|
||||
parser.add_argument("--no-compare", action="store_true",
|
||||
help="Skip qibojit reference run")
|
||||
parser.add_argument("--save-path", type=str, default=None,
|
||||
help="Save contraction tree to a pickle file")
|
||||
parser.add_argument("--load-path", type=str, default=None,
|
||||
help="Load contraction tree from a pickle file (skip path search)")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, nlayers={args.nlayers}")
|
||||
print(f"TN config: optimizer={args.optimizer}, max_time={args.max_time}, minimize={args.minimize}")
|
||||
print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, nlayers={args.nlayers}, mode={args.mode}")
|
||||
|
||||
np.random.seed(42)
|
||||
circuit_tn = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
try:
|
||||
expval_tn, t_tn = run_quimb_tn(circuit_tn, args.nqubits)
|
||||
print(f"\n[quimb TN] time={t_tn:.4f}s <sum Z_i>={expval_tn:.8f}")
|
||||
if args.mode == "expval":
|
||||
expval_tn, t_tn = run_quimb_tn(circuit_tn, args.nqubits, args.num_slices,
|
||||
load_path=args.load_path, save_path=args.save_path)
|
||||
print(f"\n[quimb TN] time={t_tn:.4f}s <Z_0>={expval_tn:.8f}")
|
||||
elif args.mode == "statevector":
|
||||
if args.mpi:
|
||||
sv_tn, t_tn = run_quimb_tn_statevector_mpi(circuit_tn, args.nqubits, args.num_slices,
|
||||
load_path=args.load_path, save_path=args.save_path)
|
||||
else:
|
||||
sv_tn, t_tn = run_quimb_tn_statevector(circuit_tn, args.nqubits, args.num_slices,
|
||||
load_path=args.load_path, save_path=args.save_path)
|
||||
if sv_tn is not None:
|
||||
print(f"\n[quimb TN] time={t_tn:.4f}s statevector shape={sv_tn.shape}")
|
||||
np.save(f"data/sv_tn_{args.circuit}{args.nqubits}.npy", sv_tn)
|
||||
else:
|
||||
_, t_tn = run_quimb_tn_samples(circuit_tn, args.nqubits, args.nshots)
|
||||
print(f"\n[quimb TN] time={t_tn:.4f}s")
|
||||
args.no_compare = True # samples 模式无法和 qibojit 期望值对比
|
||||
except Exception as e:
|
||||
print(f"[quimb TN] FAILED: {e}")
|
||||
raise
|
||||
|
||||
if not args.no_compare:
|
||||
if not args.no_compare and args.mode != "statevector":
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
expval_ref, t_ref = qibojit_expval(circuit_ref, args.nqubits)
|
||||
print(f"[qibojit] time={t_ref:.4f}s <sum Z_i>={expval_ref:.8f}")
|
||||
print(f"[qibojit] time={t_ref:.4f}s <Z_0>={expval_ref:.8f}")
|
||||
print(f"\nDiff : {abs(expval_tn - expval_ref):.2e}")
|
||||
if t_tn > 0:
|
||||
print(f"Speedup : {t_ref/t_tn:.2f}x")
|
||||
elif not args.no_compare and args.mode == "statevector" and sv_tn is not None:
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
sv_ref, t_ref = run_qibojit(circuit_ref)
|
||||
fid = abs(np.dot(sv_ref.conj(), sv_tn)) ** 2
|
||||
l2_err = np.linalg.norm(sv_ref - sv_tn)
|
||||
print(f"[qibojit] time={t_ref:.4f}s")
|
||||
print(f"Fidelity : {fid:.8f} (1=perfect)")
|
||||
print(f"L2 error : {l2_err:.2e}")
|
||||
if t_tn > 0:
|
||||
print(f"Speedup : {t_ref/t_tn:.2f}x")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user