"""MPI-parallel benchmark: norm of a QFT state via sliced tensor contraction.

Every rank builds the same QFT circuit with quimb, searches for a contraction
tree with cotengra, and the cheapest tree found across ranks is sliced so the
slices can be contracted round-robin by all ranks and summed on rank 0.
"""
import time

import numpy as np
import quimb.tensor as qtn
import cotengra as ctg
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()


def build_qft_circuit(n_qubits):
    """Build a quantum Fourier transform (QFT) circuit on ``n_qubits`` qubits.

    For each qubit ``i`` a Hadamard is applied, followed by a controlled
    phase gate of angle ``pi / 2**(j - i)`` between ``i`` and every later
    qubit ``j``. No final qubit-reversal swaps are appended.

    Args:
        n_qubits: Number of qubits in the circuit.

    Returns:
        The populated ``quimb.tensor.Circuit``.
    """
    circ = qtn.Circuit(n_qubits, dtype=np.complex128)
    for i in range(n_qubits):
        # 1. Hadamard on qubit i.
        circ.apply_gate('H', i)
        # 2. Controlled phase rotations against every later qubit.
        for j in range(i + 1, n_qubits):
            theta = np.pi / (2**(j - i))
            circ.apply_gate('CPHASE', theta, i, j)
    return circ


def run_mpi(n_qubits):
    """Compute <psi|psi> of the QFT output state across all MPI ranks.

    Steps: build the circuit on every rank, search for a contraction tree on
    every rank, pick the lowest-cost tree on rank 0, slice it there, broadcast
    the sliced tree, contract slices round-robin, and reduce the partial sums
    onto rank 0, which prints timings and the result.

    Args:
        n_qubits: Number of qubits of the QFT circuit.
    """
    if rank == 0:
        print(f"MPI size: {size} ranks")
        print(f"Circuit: QFT {n_qubits} qubits")

    # 1. Every rank builds the QFT circuit independently.
    circ = build_qft_circuit(n_qubits)

    # Observable: <psi|psi>. The QFT is unitary, so the squared norm of the
    # final state should be 1.0.
    psi = circ.psi
    net = psi.conj() & psi

    # 2. Contraction path search.
    t0 = time.perf_counter()
    # The search budget is split evenly across ranks.
    # NOTE(review): the intent is for each rank to explore different random
    # seeds, but no seed is set explicitly here -- confirm the optimizer's
    # default seeding actually differs between ranks.
    repeats_per_rank = max(1, 256 // size)
    opt = ctg.HyperOptimizer(
        methods=['kahypar'],
        max_repeats=repeats_per_rank,
        minimize='flops',
        parallel=max(1, 96 // size),
    )
    local_tree = net.contraction_tree(optimize=opt)

    # Collect every rank's tree on rank 0.
    all_trees = comm.gather(local_tree, root=0)

    # 3. Pick the cheapest tree and slice it (the main performance knob).
    # Only rank 0 needs the unsliced tree, so selection and slicing both
    # happen there and just the sliced tree is broadcast.
    if rank == 0:
        tree = min(all_trees, key=lambda t: t.contraction_cost())
        t1 = time.perf_counter()
        print(f"[rank 0] Path search: {t1 - t0:.4f} s")
        print(f"[rank 0] Best path FLOPs: {tree.contraction_cost():.2e}")

        # Tuning advice: choose target_size so a slice fills roughly 50%-70%
        # of one process's memory, or use target_slices=size * 4 instead to
        # favour an even load across ranks.
        sliced_tree = tree.slice(target_size=2**27)
    else:
        sliced_tree = None

    sliced_tree = comm.bcast(sliced_tree, root=0)
    n_slices = sliced_tree.nslices

    if rank == 0:
        # Ceiling division: the largest number of slices any rank receives.
        print(f"Total slices: {n_slices}, each rank handles ~{-(-n_slices // size)}")

    # Raw tensor data, in the same order as the network's tensors.
    arrays = [t.data for t in net.tensors]

    # 4. Contract the slices.
    t2 = time.perf_counter()
    local_result = 0.0 + 0.0j
    # Simple static load balancing: rank r takes slices r, r + size, ...
    for i in range(rank, n_slices, size):
        local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')

    # 5. Sum the partial results on rank 0. The timer stops after the
    # reduction so the reported time covers the slowest rank, matching how
    # the path-search time is measured after the gather.
    total = comm.reduce(local_result, op=MPI.SUM, root=0)
    t3 = time.perf_counter()

    if rank == 0:
        duration = t3 - t2
        print(f"[rank 0] Contract: {duration:.4f} s")
        # For <psi|psi> the correct result is 1.0 up to rounding error.
        print(f"Result (Norm): {total.real:.10f} + {total.imag:.10f}j")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--n_qubits", type=int, default=20)
    # The QFT depth is fixed by the qubit count, so there is no --depth option.
    args = parser.parse_args()
    run_mpi(args.n_qubits)