"""MPI-parallel TN benchmark: path search + contraction via MPI.""" import pickle import time import argparse import numpy as np import cotengra as ctg import qibo from qibo import Circuit, gates from mpi4py import MPI from concurrent.futures import ProcessPoolExecutor, as_completed def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks): """Run one serial HyperOptimizer in a subprocess, return (width, tree).""" import pickle, cotengra as ctg, random random.seed(seed) tn = pickle.loads(tn_bytes) opt = ctg.HyperOptimizer( methods=['kahypar', 'kahypar-agglom', 'spinglass'], max_repeats=repeats, parallel=False, minimize='flops', max_time=600, optlib="random", slicing_opts={'target_size': 2**30, 'allow_outer': False}, progbar=False, ) tree = tn.contraction_tree(optimize=opt, output_inds=output_inds) return tree.contraction_width(), tree def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks, timeout=None): """Launch n_workers subprocesses each running serial search, return best tree.""" import pickle, os, signal from concurrent.futures import ProcessPoolExecutor, as_completed tn_bytes = pickle.dumps(tn) repeats_per = max(1, total_repeats // n_workers) best_width, best_tree = float('inf'), None with ProcessPoolExecutor(max_workers=n_workers) as pool: futures = { pool.submit(_run_serial_search, tn_bytes, output_inds, repeats_per, seed, num_slices, n_ranks): seed for seed in range(n_workers) } pids = {f: p.pid for f, p in zip(futures, pool._processes.values())} try: for fut in as_completed(futures, timeout=timeout): try: width, tree = fut.result() if width < best_width: best_width, best_tree = width, tree except Exception as e: print(f" [worker failed] {e}") except TimeoutError: pass for fut, pid in pids.items(): if not fut.done(): try: os.kill(pid, signal.SIGKILL) except ProcessLookupError: pass return best_tree def make_circuit(circuit_type, nqubits, nlayers=1): c = Circuit(nqubits) if circuit_type == "qft": from qibo.models import QFT return QFT(nqubits) elif circuit_type == "variational": for layer in range(nlayers): for q in range(nqubits): c.add(gates.RY(q, theta=np.random.uniform(0, 2 * np.pi))) offset = layer % 2 for q in range(offset, nqubits - 1, 2): c.add(gates.CZ(q, q + 1)) elif circuit_type == "ghz": c.add(gates.H(0)) for q in range(nqubits - 1): c.add(gates.CNOT(q, q + 1)) elif circuit_type == "brickwork": for q in range(nqubits): c.add(gates.H(q)) for layer in range(nlayers): offset = layer % 2 for q in range(offset, nqubits - 1, 2): c.add(gates.CNOT(q, q + 1)) c.add(gates.RZ(q, theta=np.random.uniform(0, 2 * np.pi))) c.add(gates.RZ(q + 1, theta=np.random.uniform(0, 2 * np.pi))) else: raise ValueError(f"Unknown circuit: {circuit_type}") return c def _contract_mpi(tree, arrays, comm, root=0): rank = comm.Get_rank() size = comm.Get_size() is_torch = type(arrays[0]).__module__.startswith("torch") result_np = None for i in range(rank, tree.multiplicity, size): x = tree.contract_slice(arrays, i) x_np = np.asfortranarray(x.detach().cpu().numpy() if is_torch else np.asarray(x)) result_np = x_np if result_np is None else result_np + x_np if result_np is None: result_np = np.zeros(1, dtype=np.complex64) result = np.zeros_like(result_np) if rank == root else None comm.Reduce(result_np, result, root=root) if rank == root: import torch return torch.from_numpy(np.asarray(result)) if is_torch else result return None def run_mpi(circuit, nqubits, num_slices, total_repeats=1024, load_path=None, save_path=None): """Each MPI rank runs serial path search over total_repeats/size trials, rank 0 picks the global best, then all ranks contract in parallel.""" comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() qibo.set_backend("qibotn", platform="quimb") b = qibo.get_backend() b.configure_tn_simulation(ansatz="tn") import torch qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz, gate_opts={"max_bond": None, "cutoff": 1e-10}) qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64) # --- path search: each rank serial, gather best to rank 0 --- if load_path: if rank == 0: with open(load_path, "rb") as f: saved = pickle.load(f) tree, psi, t_search = saved["tree"], saved["psi"], 0.0 print(f" [path loaded] {load_path}") else: tree = psi = None t_search = 0.0 else: rank_repeats = max(1, total_repeats // size) t0 = time.time() # get TN object first (no contraction), then run parallel search psi_tn = qc.to_dense(rehearse="tn") local_tree = parallel_search( psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48, num_slices=num_slices, n_ranks=size, timeout=630, ) t_search = time.time() - t0 local_psi = psi_tn all_results = comm.gather((local_tree.contraction_width(), local_tree, local_psi), root=0) if rank == 0: _, tree, psi = min(all_results, key=lambda x: x[0]) print(f" [path search] {t_search:.3f}s " f"flops~2^{tree.contraction_cost():.2f} " f"size~2^{tree.contraction_width():.2f} " f"slices={tree.multiplicity}") if save_path: with open(save_path, "wb") as f: pickle.dump({"tree": tree, "psi": psi}, f) print(f" [path saved] {save_path}") else: tree = psi = None if save_path: t_search = comm.bcast(t_search, root=0) return None, t_search tree = comm.bcast(tree, root=0) psi = comm.bcast(psi, root=0) t_search = comm.bcast(t_search, root=0) # --- contraction: all ranks work in parallel --- import torch torch.set_num_threads(max(1, 48 // size)) arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex64) for a in psi.arrays] t0 = time.time() sv = _contract_mpi(tree, arrays, comm, root=0) t_contract = time.time() - t0 if rank == 0: print(f" [contraction] {t_contract:.3f}s") return np.array(sv).reshape(-1), t_search + t_contract return None, t_search + t_contract def main(): parser = argparse.ArgumentParser() parser.add_argument("--nqubits", type=int, default=30) parser.add_argument("--circuit", type=str, default="qft", choices=["qft", "variational", "ghz", "brickwork"]) parser.add_argument("--nlayers", type=int, default=3) parser.add_argument("--num-slices", type=int, default=1) parser.add_argument("--total-repeats", type=int, default=1024) parser.add_argument("--save-path", type=str, default=None) parser.add_argument("--load-path", type=str, default=None) parser.add_argument("--no-compare", action="store_true") args = parser.parse_args() comm = MPI.COMM_WORLD rank = comm.Get_rank() if rank == 0: print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, " f"nlayers={args.nlayers}, ranks={comm.Get_size()}") np.random.seed(42) circuit = make_circuit(args.circuit, args.nqubits, args.nlayers) try: sv, t_total = run_mpi(circuit, args.nqubits, args.num_slices, total_repeats=args.total_repeats, load_path=args.load_path, save_path=args.save_path) except Exception as e: if rank == 0: print(f"[FAILED] {e}") raise if rank == 0 and sv is not None: print(f"\n[quimb TN MPI] time={t_total:.4f}s shape={sv.shape}") np.save(f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy", sv) if not args.no_compare: from benchmark_tn import run_qibojit np.random.seed(42) circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers) sv_ref, t_ref = run_qibojit(circuit_ref) fid = abs(np.dot(sv_ref.conj(), sv)) ** 2 print(f"[qibojit] time={t_ref:.4f}s") print(f"Fidelity : {fid:.8f}") print(f"L2 error : {np.linalg.norm(sv_ref - sv):.2e}") if t_total > 0: print(f"Speedup : {t_ref/t_total:.2f}x") if __name__ == "__main__": main()