添加MPI并行TN benchmark及辅助脚本,移除旧benchmark
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
100
benchmark_tn.py
100
benchmark_tn.py
@@ -1,12 +1,39 @@
|
||||
"""Benchmark: qibotn/quimb generic TN — expectation values."""
|
||||
import multiprocessing
|
||||
multiprocessing.set_start_method("spawn", force=True)
|
||||
import pickle
|
||||
import time
|
||||
import threading
|
||||
import argparse
|
||||
import numpy as np
|
||||
import cotengra as ctg
|
||||
import qibo
|
||||
from qibo import Circuit, gates
|
||||
|
||||
class TimedTrialFn:
|
||||
def __init__(self, trial_fn, timeout=15):
|
||||
self.trial_fn = trial_fn
|
||||
self.timeout = timeout
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
result = [None]
|
||||
exc = [None]
|
||||
|
||||
def _run():
|
||||
try:
|
||||
result[0] = self.trial_fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
exc[0] = e
|
||||
|
||||
t = threading.Thread(target=_run, daemon=True)
|
||||
t.start()
|
||||
t.join(self.timeout)
|
||||
if t.is_alive():
|
||||
raise TimeoutError(f"trial exceeded {self.timeout}s")
|
||||
if exc[0] is not None:
|
||||
raise exc[0]
|
||||
return result[0]
|
||||
|
||||
def make_circuit(circuit_type, nqubits, nlayers=1):
|
||||
c = Circuit(nqubits)
|
||||
if circuit_type == "qft":
|
||||
@@ -98,7 +125,7 @@ def run_quimb_tn_statevector(circuit, nqubits, num_slices, load_path=None, save_
|
||||
import torch
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
qc.to_backend = torch.from_numpy
|
||||
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64)
|
||||
if load_path:
|
||||
with open(load_path, "rb") as f:
|
||||
saved = pickle.load(f)
|
||||
@@ -107,25 +134,30 @@ def run_quimb_tn_statevector(circuit, nqubits, num_slices, load_path=None, save_
|
||||
print(f" [path loaded] {load_path}")
|
||||
else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=128,
|
||||
#methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=1024,
|
||||
#parallel="concurrent.futures",
|
||||
parallel=64,
|
||||
max_time=100,
|
||||
max_time=60,
|
||||
minimize='size',
|
||||
slicing_opts={'target_slices': num_slices},
|
||||
#slicing_opts={'target_size': 2**30},
|
||||
progbar=True,
|
||||
on_trial_error='ignore'
|
||||
)
|
||||
|
||||
t0 = time.time()
|
||||
rehearsal = qc.to_dense(optimize=opt, rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
tree = rehearsal['tree']
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f}")
|
||||
#print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f}")
|
||||
|
||||
if save_path:
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump({"tree": tree}, f)
|
||||
print(f" [path saved] {save_path}")
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f}")
|
||||
return None, t_search
|
||||
|
||||
t0 = time.time()
|
||||
sv = qc.to_dense(optimize=tree).reshape(-1)
|
||||
@@ -186,42 +218,48 @@ def run_quimb_tn_statevector_mpi(circuit, nqubits, num_slices, load_path=None, s
|
||||
import torch
|
||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||
qc.to_backend = torch.from_numpy
|
||||
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64)
|
||||
|
||||
# path search on rank 0, broadcast to all
|
||||
if rank == 0:
|
||||
if load_path:
|
||||
if load_path:
|
||||
if rank == 0:
|
||||
with open(load_path, "rb") as f:
|
||||
saved = pickle.load(f)
|
||||
tree = saved["tree"]
|
||||
psi = saved["psi"]
|
||||
t_search = 0.0
|
||||
tree, psi, t_search = saved["tree"], saved["psi"], 0.0
|
||||
print(f" [path loaded] {load_path}")
|
||||
else:
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=128,
|
||||
parallel=64,
|
||||
#max_repeats=1,
|
||||
max_time=100,
|
||||
minimize='size',
|
||||
slicing_opts={'target_slices': max(num_slices, size), 'allow_outer': False},
|
||||
progbar=True,
|
||||
)
|
||||
t0 = time.time()
|
||||
rehearsal = qc.to_dense(optimize=opt, rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
tree = rehearsal['tree']
|
||||
psi = rehearsal['tn']
|
||||
tree = psi = None
|
||||
t_search = 0.0
|
||||
else:
|
||||
# each rank runs serial search over its share of trials
|
||||
total_repeats = 1024
|
||||
rank_repeats = max(1, total_repeats // size)
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar', 'random-greedy', 'spinglass'],
|
||||
max_repeats=rank_repeats,
|
||||
parallel=False,
|
||||
max_time=100,
|
||||
minimize='size',
|
||||
slicing_opts={'target_slices': max(num_slices, size), 'allow_outer': False},
|
||||
progbar=(rank == 0),
|
||||
)
|
||||
t0 = time.time()
|
||||
rehearsal = qc.to_dense(optimize=opt, rehearse=True)
|
||||
t_search = time.time() - t0
|
||||
local_tree = rehearsal['tree']
|
||||
local_psi = rehearsal['tn']
|
||||
local_size = local_tree.contraction_width()
|
||||
|
||||
# gather all trees to rank 0, pick best by contraction_width
|
||||
all_results = comm.gather((local_size, local_tree, local_psi), root=0)
|
||||
if rank == 0:
|
||||
_, tree, psi = min(all_results, key=lambda x: x[0])
|
||||
print(f" [path search] {t_search:.3f}s flops~2^{tree.contraction_cost():.2f} size~2^{tree.contraction_width():.2f} slices={tree.multiplicity}")
|
||||
if save_path:
|
||||
with open(save_path, "wb") as f:
|
||||
pickle.dump({"tree": tree, "psi": psi}, f)
|
||||
print(f" [path saved] {save_path}")
|
||||
else:
|
||||
tree = None
|
||||
psi = None
|
||||
t_search = 0.0
|
||||
else:
|
||||
tree = psi = None
|
||||
|
||||
tree = comm.bcast(tree, root=0)
|
||||
psi = comm.bcast(psi, root=0)
|
||||
|
||||
Reference in New Issue
Block a user