"""Hybrid MPI + ThreadPoolExecutor parallel tensor-network contraction.

Each MPI rank owns a strided subset of the slices of a sliced contraction
tree; inside a rank a ThreadPoolExecutor contracts those slices in parallel
(one slice per worker thread).

Usage:
    mpirun -n <ranks> python mpi_v.py --qasm circuit.qasm --target-slices 16 --threads 8
"""
import os
import time
import argparse
import numpy as np
from concurrent.futures import ThreadPoolExecutor, as_completed

from mpi4py import MPI

# Module-level communicator state shared by run() and main().
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Kept after the MPI setup, as in the original file ordering.
import quimb.tensor as qtn
import cotengra as ctg


def _contract_slice(sliced_tree, arrays, idx):
    """Contract slice ``idx`` of ``sliced_tree`` over ``arrays`` (numpy backend)."""
    return sliced_tree.contract_slice(arrays, idx, backend="numpy")


def _slice_tree(tree, target_slices=None, target_size=None):
    """Slice ``tree`` by slice count or by maximum tensor size.

    ``target_size`` takes precedence when both are given.  Output (outer)
    indices are not sliced in the primary attempts.  A ``RuntimeError`` from
    ``tree.slice`` is treated as "target not reachable":

    * slice-count mode halves the requested count and retries until it
      drops below 1;
    * size mode makes a single attempt (shrinking the size target would
      only make it harder to satisfy).

    If no attempt succeeds the tree is returned effectively unsliced via
    ``tree.slice(target_slices=1, allow_outer=True)``.
    """
    sliced = None
    if target_size is not None:
        try:
            sliced = tree.slice(target_size=target_size, allow_outer=False)
        except RuntimeError:
            sliced = None
    elif target_slices is not None:
        n = target_slices
        while n >= 1:
            try:
                sliced = tree.slice(target_slices=n, allow_outer=False)
                break
            except RuntimeError:
                n //= 2
    if sliced is None:
        sliced = tree.slice(target_slices=1, allow_outer=True)
    return sliced


def run(qasm_path, target_slices, n_threads, max_repeats, *, target_size=None):
    """Contract the circuit in ``qasm_path`` across all MPI ranks.

    Rank 0 reads the QASM file, builds the tensor network, searches for a
    contraction path and slices the tree; the arrays and the sliced tree are
    broadcast to every rank.  Each rank then contracts slices
    ``rank, rank + size, rank + 2*size, ...`` with a thread pool and the
    partial sums are reduced onto rank 0.

    Args:
        qasm_path: Path of the OpenQASM 2 file (only read on rank 0).
        target_slices: Desired number of slices, or ``None`` when slicing by
            ``target_size`` instead.
        n_threads: Worker threads per rank; values below 1 are clamped to 1.
        max_repeats: ``max_repeats`` passed to ``ctg.HyperOptimizer``.
        target_size: Optional maximum intermediate tensor size (number of
            elements).  When given it is used instead of ``target_slices``.

    Returns:
        On rank 0, the summed result as a flat complex128 array of length
        ``2**n_qubits``; ``None`` on every other rank.
    """
    # ThreadPoolExecutor and the batching range() both reject values < 1.
    n_threads = max(1, n_threads)

    # -- Build the tensor network on rank 0, then broadcast the arrays --
    if rank == 0:
        with open(qasm_path) as f:
            qasm_str = f.read()
        # No full_simplify here, so outer_inds stay intact.
        psi = qtn.Circuit.from_openqasm2_str(qasm_str).psi
        n_qubits = sum(1 for ix in psi.outer_inds() if ix.startswith("k"))
        output_inds = [f"k{i}" for i in range(n_qubits)]
        arrays = [t.data for t in psi.tensors]
    else:
        psi = None
        n_qubits = None
        arrays = None
        output_inds = None

    n_qubits = comm.bcast(n_qubits, root=0)
    arrays = comm.bcast(arrays, root=0)
    output_inds = comm.bcast(output_inds, root=0)

    # -- Path search on rank 0, then broadcast the sliced tree --
    t0 = time.perf_counter()
    if rank == 0:
        opt = ctg.HyperOptimizer(
            methods=["kahypar", "greedy"],
            max_repeats=max_repeats,
            minimize="flops",
            # os.cpu_count() may return None.
            parallel=min(96, os.cpu_count() or 1),
        )
        tree = psi.contraction_tree(optimize=opt, output_inds=output_inds)
        sliced_tree = _slice_tree(
            tree, target_slices=target_slices, target_size=target_size
        )
        print(f"[rank 0] path search: {time.perf_counter()-t0:.2f}s slices: {sliced_tree.nslices}", flush=True)
    else:
        sliced_tree = None

    sliced_tree = comm.bcast(sliced_tree, root=0)
    n_slices = sliced_tree.nslices

    # -- Distributed contraction: MPI stride + ThreadPoolExecutor --
    my_indices = list(range(rank, n_slices, size))
    local_result = np.zeros(2**n_qubits, dtype=np.complex128)

    comm.Barrier()
    t1 = time.perf_counter()

    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        # Submit one batch of n_threads slices at a time so that at most
        # n_threads slice results are pending at once.
        for batch_start in range(0, len(my_indices), n_threads):
            batch = my_indices[batch_start:batch_start + n_threads]
            futures = [
                pool.submit(_contract_slice, sliced_tree, arrays, i)
                for i in batch
            ]
            for fut in as_completed(futures):
                # asarray/reshape avoid the two copies of array(...).flatten().
                local_result += np.asarray(fut.result()).reshape(-1)

    t2 = time.perf_counter()
    if rank == 0:
        print(f"[rank 0] contract: {t2-t1:.2f}s", flush=True)

    # -- MPI reduce (buffer-based, avoids pickling the full vector) --
    total = np.empty_like(local_result) if rank == 0 else None
    comm.Reduce(local_result, total, op=MPI.SUM, root=0)

    if rank == 0:
        print(f"result norm: {np.linalg.norm(total):.10f}", flush=True)
        # NOTE: t2 is taken before the reduce, so this excludes reduce time.
        print(f"total time: {t2-t0:.2f}s", flush=True)
        return total
    return None


def main():
    """Parse command-line options and launch :func:`run` on every rank."""
    # os.cpu_count() may return None; fall back to a single CPU.
    cpu_count = os.cpu_count() or 1

    parser = argparse.ArgumentParser()
    parser.add_argument("--qasm", required=True, help="QASM 文件路径")
    parser.add_argument("--target-slices", type=int, default=None,
                        help="目标切片数量(优先于 target-size)")
    parser.add_argument("--target-size", type=int, default=28,
                        help="切片目标大小指数(2^N),默认 28")
    parser.add_argument("--threads", type=int,
                        default=max(1, cpu_count // size),
                        help="每个 rank 的线程数,默认 cpu_count/size")
    parser.add_argument("--max-repeats", type=int, default=256,
                        help="cotengra 路径搜索重复次数")
    args = parser.parse_args()

    # --target-slices wins when given (and non-zero); otherwise slice by size.
    use_slices = bool(args.target_slices)
    mode = "slices" if use_slices else f"size=2^{args.target_size}"

    if rank == 0:
        print(f"ranks={size} threads/rank={args.threads} target_{mode}", flush=True)

    # Pass each mode through its own parameter so a slice count is never
    # interpreted as a tensor size (or vice versa).
    run(
        args.qasm,
        args.target_slices if use_slices else None,
        args.threads,
        args.max_repeats,
        target_size=None if use_slices else 2**args.target_size,
    )


if __name__ == "__main__":
    main()