一个更为优秀的mpi运行代码,不同测试用例修改n_qubits与电路defination
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
103
tests/quimb_mpi3.py
Normal file
103
tests/quimb_mpi3.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import time
|
||||
import numpy as np
|
||||
import quimb.tensor as qtn
|
||||
import cotengra as ctg
|
||||
from mpi4py import MPI
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
rank = comm.Get_rank()
|
||||
size = comm.Get_size()
|
||||
|
||||
def build_qft_circuit(n_qubits):
|
||||
"""构建标准 QFT 电路"""
|
||||
circ = qtn.Circuit(n_qubits, dtype=np.complex128)
|
||||
for i in range(n_qubits):
|
||||
# 1. 施加 H 门
|
||||
circ.apply_gate('H', i)
|
||||
# 2. 施加受控相位旋转
|
||||
for j in range(i + 1, n_qubits):
|
||||
theta = np.pi / (2**(j - i))
|
||||
circ.apply_gate('CPHASE', theta, i, j)
|
||||
return circ
|
||||
|
||||
def run_mpi(n_qubits):
|
||||
if rank == 0:
|
||||
print(f"MPI size: {size} ranks")
|
||||
print(f"Circuit: QFT {n_qubits} qubits")
|
||||
|
||||
# 1. 所有 rank 独立构建 QFT 电路
|
||||
circ = build_qft_circuit(n_qubits)
|
||||
|
||||
# 物理观测:计算 <psi|psi>,结果应为 1.0
|
||||
# 注意:QFT 是幺正变换,末态模长平方必为 1
|
||||
psi = circ.psi
|
||||
net = psi.conj() & psi
|
||||
|
||||
# 2. 路径搜索优化
|
||||
t0 = time.perf_counter()
|
||||
# 每个 rank 尝试不同的种子,增加找到全局最优路径的概率
|
||||
repeats_per_rank = max(1, 256 // size)
|
||||
opt = ctg.HyperOptimizer(
|
||||
methods=['kahypar'],
|
||||
max_repeats=repeats_per_rank,
|
||||
minimize='flops',
|
||||
parallel=max(1, 96 // size),
|
||||
)
|
||||
# 搜索收缩树
|
||||
local_tree = net.contraction_tree(optimize=opt)
|
||||
|
||||
# 汇总所有 rank 找到的树,在 rank 0 选出 FLOPs 最低的那棵
|
||||
all_trees = comm.gather(local_tree, root=0)
|
||||
|
||||
if rank == 0:
|
||||
tree = min(all_trees, key=lambda t: t.contraction_cost())
|
||||
t1 = time.perf_counter()
|
||||
print(f"[rank 0] Path search: {t1 - t0:.4f} s")
|
||||
print(f"[rank 0] Best path FLOPs: {tree.contraction_cost():.2e}")
|
||||
else:
|
||||
tree = None
|
||||
|
||||
# 将最优路径广播给所有进程
|
||||
tree = comm.bcast(tree, root=0)
|
||||
|
||||
# 3. 切片处理(性能控制核心)
|
||||
if rank == 0:
|
||||
# 比赛建议:将 target_size 设为能填满单进程内存的 50%-70%
|
||||
# 或者改用 target_slices=size * 4 以确保负载绝对平衡
|
||||
sliced_tree = tree.slice(target_size=2**27)
|
||||
else:
|
||||
sliced_tree = None
|
||||
|
||||
sliced_tree = comm.bcast(sliced_tree, root=0)
|
||||
n_slices = sliced_tree.nslices
|
||||
|
||||
if rank == 0:
|
||||
print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size + 1}")
|
||||
|
||||
# 获取原始张量数据
|
||||
arrays = [t.data for t in net.tensors]
|
||||
|
||||
# 4. 执行收缩计算
|
||||
t2 = time.perf_counter()
|
||||
local_result = 0.0 + 0.0j
|
||||
# 简单的静态负载均衡:每个 rank 跳步处理切片
|
||||
for i in range(rank, n_slices, size):
|
||||
local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')
|
||||
t3 = time.perf_counter()
|
||||
|
||||
# 5. 结果汇总
|
||||
total = comm.reduce(local_result, op=MPI.SUM, root=0)
|
||||
|
||||
if rank == 0:
|
||||
duration = t3 - t2
|
||||
print(f"[rank 0] Contract: {duration:.4f} s")
|
||||
# 对于 <psi|psi>,QFT 的正确结果应无限接近 1.0
|
||||
print(f"Result (Norm): {total.real:.10f} + {total.imag:.10f}j")
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--n_qubits", type=int, default=20)
|
||||
# QFT 的深度由比特数自动决定,所以删除了 --depth 参数
|
||||
args = parser.parse_args()
|
||||
run_mpi(args.n_qubits)
|
||||
Reference in New Issue
Block a user