"""MPI-parallel TN benchmark: path search + contraction via MPI.""" import json import pickle import time import argparse import numpy as np import cotengra as ctg import qibo from qibo import Circuit, gates from mpi4py import MPI from concurrent.futures import ProcessPoolExecutor, as_completed from qibotn.observables import check_observable, extract_gates_and_qubits def _load_observable(observable_file=None, observable_json=None): if observable_file: with open(observable_file, "r", encoding="utf8") as f: return json.load(f) if observable_json: return json.loads(observable_json) return None def _term_to_quimb_operator(term): """Convert one extracted Hamiltonian term to a quimb operator.""" import quimb as qu coeff = complex(term[0][2]) if term else 1.0 op = None where = [] for qubit, gate_name, _ in term: qubit = int(qubit) gate_name = str(gate_name).upper() if gate_name == "I": continue where.append(qubit) op = qu.pauli(gate_name.lower()) if op is None else op & qu.pauli(gate_name.lower()) return complex(coeff), op, tuple(where) def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time): import pickle, cotengra as ctg, random random.seed(seed) tn = pickle.loads(tn_bytes) opt = ctg.HyperOptimizer( methods=['kahypar', 'kahypar-agglom', 'spinglass'], max_repeats=repeats, parallel=False, minimize='combo-256', max_time=max_time, optlib="random", slicing_opts={'target_size': 2**29, 'allow_outer': True}, progbar=False, ) tree = tn.contraction_tree(optimize=opt, output_inds=output_inds) return tree.combo_cost(factor=256), tree def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks, timeout): import pickle, os, signal from concurrent.futures import ProcessPoolExecutor, as_completed tn_bytes = pickle.dumps(tn) if n_workers <= 1: return _run_serial_search( tn_bytes, output_inds, total_repeats, 0, num_slices, n_ranks, timeout )[1] repeats_per = max(1, total_repeats // n_workers) best_cost, best_tree = float('inf'), None pool = ProcessPoolExecutor(max_workers=n_workers) futures = [ pool.submit(_run_serial_search, tn_bytes, output_inds, repeats_per, seed, num_slices, n_ranks, timeout) for seed in range(n_workers) ] try: for fut in as_completed(futures, timeout=timeout + 5): try: cost, tree = fut.result() if cost < best_cost: best_cost, best_tree = cost, tree except Exception as e: print(f" [worker failed] {e}") except TimeoutError: pass finally: for fut in futures: fut.cancel() for pid in list(pool._processes.keys()): try: os.kill(pid, signal.SIGKILL) except ProcessLookupError: pass pool.shutdown(wait=False) return best_tree def make_circuit(circuit_type, nqubits, nlayers=1): c = Circuit(nqubits) if circuit_type == "qft": from qibo.models import QFT return QFT(nqubits) elif circuit_type == "variational": for layer in range(nlayers): for q in range(nqubits): c.add(gates.RY(q, theta=np.random.uniform(0, 2 * np.pi))) offset = layer % 2 for q in range(offset, nqubits - 1, 2): c.add(gates.CZ(q, q + 1)) elif circuit_type == "ghz": c.add(gates.H(0)) for q in range(nqubits - 1): c.add(gates.CNOT(q, q + 1)) elif circuit_type == "brickwork": for q in range(nqubits): c.add(gates.H(q)) for layer in range(nlayers): offset = layer % 2 for q in range(offset, nqubits - 1, 2): c.add(gates.CNOT(q, q + 1)) c.add(gates.RZ(q, theta=np.random.uniform(0, 2 * np.pi))) c.add(gates.RZ(q + 1, theta=np.random.uniform(0, 2 * np.pi))) else: raise ValueError(f"Unknown circuit: {circuit_type}") return c def _contract_mpi(tree, arrays, comm, root=0): rank = comm.Get_rank() size = comm.Get_size() is_torch = type(arrays[0]).__module__.startswith("torch") result_np = None for i in range(rank, tree.multiplicity, size): x = tree.contract_slice(arrays, i) x_np = np.asfortranarray(x.detach().cpu().numpy() if is_torch else np.asarray(x)) result_np = x_np if result_np is None else result_np + x_np if result_np is None: result_np = np.zeros(1, dtype=np.complex128) result = np.zeros_like(result_np) if rank == root else None comm.Reduce(result_np, result, root=root) if rank == root: import torch return torch.from_numpy(np.asarray(result)) if is_torch else result return None def run_mpi(circuit, nqubits, num_slices, total_repeats=1024, load_path=None, save_path=None): """Each MPI rank runs serial path search over total_repeats/size trials, rank 0 picks the global best, then all ranks contract in parallel.""" comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() qibo.set_backend("qibotn", platform="quimb") b = qibo.get_backend() b.configure_tn_simulation(ansatz="tn") import torch qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz, gate_opts={"max_bond": None, "cutoff": 1e-10}) qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128) # --- path search: each rank serial, gather best to rank 0 --- if load_path: if rank == 0: with open(load_path, "rb") as f: saved = pickle.load(f) tree, psi, t_search = saved["tree"], saved["psi"], 0.0 print(f" [path loaded] {load_path}") else: tree = psi = None t_search = 0.0 else: rank_repeats = max(1, total_repeats // size) t0 = time.time() # get TN object first (no contraction), then run parallel search psi_tn = qc.to_dense(rehearse="tn") local_tree = parallel_search( psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48, num_slices=num_slices, n_ranks=size, timeout=600, ) t_search = time.time() - t0 local_psi = psi_tn all_results = comm.gather((local_tree.combo_cost(factor=256), local_tree, local_psi), root=0) if rank == 0: _, tree, psi = min(all_results, key=lambda x: x[0]) print(f" [path search] {t_search:.3f}s " f"flops~2^{tree.contraction_cost(log=2):.2f} " f"size~2^{tree.contraction_width():.2f} " f"slices={tree.multiplicity}") if save_path: with open(save_path, "wb") as f: pickle.dump({"tree": tree, "psi": psi}, f) print(f" [path saved] {save_path}") else: tree = psi = None if save_path: t_search = comm.bcast(t_search, root=0) return None, t_search tree = comm.bcast(tree, root=0) psi = comm.bcast(psi, root=0) t_search = comm.bcast(t_search, root=0) # --- contraction: all ranks work in parallel --- import torch torch.set_num_threads(max(1, 96 // size)) arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays] t0 = time.time() sv = _contract_mpi(tree, arrays, comm, root=0) t_contract = time.time() - t0 if rank == 0: print(f" [contraction] {t_contract:.3f}s") return np.array(sv).reshape(-1), t_search + t_contract return None, t_search + t_contract def run_mpi_expval( circuit, nqubits, observable=None, total_repeats=1024, search_workers=1, search_timeout=300, ): """Compute a Hamiltonian expectation value directly from TN via MPI. MPI parallelizes over Hamiltonian terms; ProcessPool optionally helps path search for each term.""" import torch comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() qibo.set_backend("qibotn", platform="quimb") b = qibo.get_backend() b.configure_tn_simulation(ansatz="tn") observable = check_observable(observable, nqubits) ham_gate_map = extract_gates_and_qubits(observable) qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz, gate_opts={"max_bond": None, "cutoff": 1e-10}) my_terms = ham_gate_map[rank::size] torch.set_num_threads(max(1, 96 // size)) t0 = time.time() my_exp = 0.0 + 0.0j for term in my_terms: coeff, op, where = _term_to_quimb_operator(term) if op is None: my_exp += coeff continue tn = qc.local_expectation_tn(op, where=where) if len(tn.outer_inds()) == 0: val = complex(tn.contract()) else: tree = parallel_search( tn, tn.outer_inds(), total_repeats, n_workers=search_workers, num_slices=1, n_ranks=size, timeout=search_timeout, ) if tree is None: raise RuntimeError("Failed to find a contraction tree for expectation TN.") arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in tn.arrays] acc = sum(tree.contract_slice(arrays, i) for i in range(tree.multiplicity)) val = complex(acc.item() if hasattr(acc, 'item') else acc) my_exp += coeff * val t_total = time.time() - t0 all_results = comm.gather(my_exp, root=0) if rank == 0: total_exp = sum(all_results) print(f"\n[TN expval] time={t_total:.4f}s expval={total_exp.real:.12f}") return np.real_if_close(total_exp), t_total return None, t_total def main(): parser = argparse.ArgumentParser() parser.add_argument("--nqubits", type=int, default=30) parser.add_argument("--circuit", type=str, default="qft", choices=["qft", "variational", "ghz", "brickwork"]) parser.add_argument("--nlayers", type=int, default=3) parser.add_argument("--num-slices", type=int, default=1) parser.add_argument("--total-repeats", type=int, default=1024) parser.add_argument("--search-workers", type=int, default=1) parser.add_argument("--search-timeout", type=int, default=300) parser.add_argument("--observable-file", type=str, default=None) parser.add_argument("--observable-json", type=str, default=None) parser.add_argument("--save-path", type=str, default=None) parser.add_argument("--load-path", type=str, default=None) parser.add_argument("--no-compare", action="store_true") parser.add_argument("--mode", type=str, default="sv", choices=["sv", "expval"]) args = parser.parse_args() comm = MPI.COMM_WORLD rank = comm.Get_rank() if rank == 0: print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, " f"nlayers={args.nlayers}, ranks={comm.Get_size()}") np.random.seed(42) circuit = make_circuit(args.circuit, args.nqubits, args.nlayers) observable = _load_observable(args.observable_file, args.observable_json) if args.mode == "expval": try: expval, t_total = run_mpi_expval( circuit, args.nqubits, observable=observable, total_repeats=args.total_repeats, search_workers=args.search_workers, search_timeout=args.search_timeout, ) except Exception as e: if rank == 0: print(f"[FAILED] {e}") raise if rank == 0: np.save(f"data/expval_tn_{args.circuit}{args.nqubits}.npy", np.asarray(expval)) if not args.no_compare: print("No built-in reference comparison for arbitrary observables.") return try: sv, t_total = run_mpi(circuit, args.nqubits, args.num_slices, total_repeats=args.total_repeats, load_path=args.load_path, save_path=args.save_path) except Exception as e: if rank == 0: print(f"[FAILED] {e}") raise if rank == 0 and sv is not None: print(f"\n[quimb TN MPI] time={t_total:.4f}s shape={sv.shape}") np.save(f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy", sv) if not args.no_compare: from qibotn.bak.benchmark_tn import run_qibojit import gc np.random.seed(42) circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers) sv_ref, t_ref = run_qibojit(circuit_ref) np.save(f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy", sv_ref) print(f"[qibojit] time={t_ref:.4f}s") # free memory before loading via mmap for expval comparison del sv, sv_ref gc.collect() from compare_jit_tn_quimb import check_results ref_path = f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy" tn_path = f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy" check_results(ref_path, tn_path, args.nqubits) if t_total > 0: print(f"Speedup : {t_ref/t_total:.2f}x") if __name__ == "__main__": main()