Files
qibotn/tests/quimb_mpi3.py
jaunatisblue e38fd02cf3
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
一个更为优秀的mpi运行代码,不同测试用例修改n_qubits与电路defination
2026-04-22 18:48:03 +08:00

104 lines
3.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import numpy as np
import quimb.tensor as qtn
import cotengra as ctg
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
def build_qft_circuit(n_qubits):
"""构建标准 QFT 电路"""
circ = qtn.Circuit(n_qubits, dtype=np.complex128)
for i in range(n_qubits):
# 1. 施加 H 门
circ.apply_gate('H', i)
# 2. 施加受控相位旋转
for j in range(i + 1, n_qubits):
theta = np.pi / (2**(j - i))
circ.apply_gate('CPHASE', theta, i, j)
return circ
def run_mpi(n_qubits):
if rank == 0:
print(f"MPI size: {size} ranks")
print(f"Circuit: QFT {n_qubits} qubits")
# 1. 所有 rank 独立构建 QFT 电路
circ = build_qft_circuit(n_qubits)
# 物理观测:计算 <psi|psi>,结果应为 1.0
# 注意QFT 是幺正变换,末态模长平方必为 1
psi = circ.psi
net = psi.conj() & psi
# 2. 路径搜索优化
t0 = time.perf_counter()
# 每个 rank 尝试不同的种子,增加找到全局最优路径的概率
repeats_per_rank = max(1, 256 // size)
opt = ctg.HyperOptimizer(
methods=['kahypar'],
max_repeats=repeats_per_rank,
minimize='flops',
parallel=max(1, 96 // size),
)
# 搜索收缩树
local_tree = net.contraction_tree(optimize=opt)
# 汇总所有 rank 找到的树,在 rank 0 选出 FLOPs 最低的那棵
all_trees = comm.gather(local_tree, root=0)
if rank == 0:
tree = min(all_trees, key=lambda t: t.contraction_cost())
t1 = time.perf_counter()
print(f"[rank 0] Path search: {t1 - t0:.4f} s")
print(f"[rank 0] Best path FLOPs: {tree.contraction_cost():.2e}")
else:
tree = None
# 将最优路径广播给所有进程
tree = comm.bcast(tree, root=0)
# 3. 切片处理(性能控制核心)
if rank == 0:
# 比赛建议:将 target_size 设为能填满单进程内存的 50%-70%
# 或者改用 target_slices=size * 4 以确保负载绝对平衡
sliced_tree = tree.slice(target_size=2**27)
else:
sliced_tree = None
sliced_tree = comm.bcast(sliced_tree, root=0)
n_slices = sliced_tree.nslices
if rank == 0:
print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size + 1}")
# 获取原始张量数据
arrays = [t.data for t in net.tensors]
# 4. 执行收缩计算
t2 = time.perf_counter()
local_result = 0.0 + 0.0j
# 简单的静态负载均衡:每个 rank 跳步处理切片
for i in range(rank, n_slices, size):
local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')
t3 = time.perf_counter()
# 5. 结果汇总
total = comm.reduce(local_result, op=MPI.SUM, root=0)
if rank == 0:
duration = t3 - t2
print(f"[rank 0] Contract: {duration:.4f} s")
# 对于 <psi|psi>QFT 的正确结果应无限接近 1.0
print(f"Result (Norm): {total.real:.10f} + {total.imag:.10f}j")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--n_qubits", type=int, default=20)
# QFT 的深度由比特数自动决定,所以删除了 --depth 参数
args = parser.parse_args()
run_mpi(args.n_qubits)