"""MPI-parallel evaluation of <psi|Z_0|psi> for a QFT circuit.

Each MPI rank runs its own contraction-path search, rank 0 keeps the
cheapest tree and slices it, and the slices are then contracted
round-robin across the ranks and summed on rank 0.
"""
import time

import numpy as np
import quimb.tensor as qtn
import cotengra as ctg
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


def build_qft(n_qubits):
    """Build a QFT-style circuit on ``n_qubits`` qubits.

    For every qubit ``i`` an ``'H'`` gate is applied, followed by a
    ``'CPHASE'`` gate with angle ``pi / 2**(j - i)`` between ``i`` and each
    later qubit ``j``.

    Args:
        n_qubits: Number of qubits in the circuit.

    Returns:
        The populated ``qtn.Circuit``.
    """
    circ = qtn.Circuit(n_qubits, dtype=np.complex128)
    for i in range(n_qubits):
        circ.apply_gate('H', i)
        for j in range(i + 1, n_qubits):
            circ.apply_gate('CPHASE', np.pi / 2 ** (j - i), i, j)
    return circ


def _z0_expectation_network(psi):
    """Return the closed tensor network for <psi|Z_0|psi>."""
    Z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
    # Only qubit 0's physical index is relabelled on the bra; it is the one
    # leg that has to pass through the operator.  The remaining k{i} indices
    # keep their names so they contract directly with the ket.  Relabelling
    # all of them would leave 2 * (n - 1) open legs and the "expectation
    # value" would come out as a huge tensor instead of a scalar.
    bra = psi.conj().reindex({'k0': 'b0'})
    obs = qtn.Tensor(Z, inds=('k0', 'b0'))
    return psi & obs & bra


def run_mpi(n_qubits, depth=None):
    """Compute <Z_0> of the QFT circuit state using all MPI ranks.

    Rank 0 prints the configuration, the path-search time, the slice count,
    the contraction time and the final value.  Nothing is returned.

    Args:
        n_qubits: Number of qubits of the QFT circuit.
        depth: Accepted for command-line compatibility; currently unused.
    """
    if rank == 0:
        print(f"MPI size: {size} ranks")
        print(f"Circuit: QFT {n_qubits} qubits")

    # 1. Expectation-value network (built identically on every rank).
    circ = build_qft(n_qubits)
    net = _z0_expectation_network(circ.psi)

    # 2. Every rank searches for a contraction path; rank 0 picks the
    #    globally cheapest tree.
    t0 = time.perf_counter()
    repeats_per_rank = max(1, 128 // size)
    opt = ctg.HyperOptimizer(
        methods=['kahypar'],
        max_repeats=repeats_per_rank,
        minimize='flops',
        parallel=max(1, 96 // size),
    )
    local_tree = net.contraction_tree(optimize=opt)
    all_trees = comm.gather(local_tree, root=0)

    # 3. Rank 0 slices the winning tree and broadcasts the sliced tree.
    #    The unsliced tree is only needed on rank 0, so it is not broadcast.
    sliced_tree = None
    if rank == 0:
        tree = min(all_trees, key=lambda t: t.contraction_cost())
        t1 = time.perf_counter()
        print(f"[rank 0] Path search: {t1 - t0:.4f} s")
        sliced_tree = tree.slice(target_size=2**27)
    sliced_tree = comm.bcast(sliced_tree, root=0)

    n_slices = sliced_tree.nslices
    if rank == 0:
        print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size}")

    arrays = [t.data for t in net.tensors]

    # Each rank contracts the slices assigned to it (round-robin by rank).
    t2 = time.perf_counter()
    local_result = 0.0 + 0.0j
    for i in range(rank, n_slices, size):
        local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')
    t3 = time.perf_counter()

    # 4. Sum the partial results onto rank 0.
    total = comm.reduce(local_result, op=MPI.SUM, root=0)

    if rank == 0:
        print(f"[rank 0] Contract: {t3 - t2:.4f} s")
        print(f"Result: {total:.10f}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_qubits", type=int, default=20)
    parser.add_argument("--depth", type=int, default=30)
    args = parser.parse_args()
    run_mpi(args.n_qubits, args.depth)