Files
qibotn/tests/quimb_mpi2.py
jaunatisblue bcad2882fa
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
构建基于oneapi的mpi4py,quimb支持mpi多机并行,缩短路径找寻时间
2026-04-15 21:15:10 +08:00

82 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import numpy as np
import quimb.tensor as qtn
import cotengra as ctg
from mpi4py import MPI
# World communicator plus this process's rank and the total rank count;
# read as module-level globals by run_mpi below.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
def run_mpi(n_qubits, depth, *, total_repeats=128, total_workers=96,
            target_size=2**27):
    """Search a contraction path and contract a circuit network across MPI ranks.

    Every rank builds the same layered H/CZ circuit, runs its own
    hyper-optimizer search, and sends its best tree to rank 0. Rank 0 keeps
    the tree with the lowest ``contraction_cost()``, slices it, and broadcasts
    the sliced tree. Slices are then contracted round-robin over the ranks and
    the partial results are summed onto rank 0, which prints the total.

    Args:
        n_qubits: Number of qubits in the circuit.
        depth: Number of H + CZ layers applied.
        total_repeats: Path-search repeats shared across all ranks; each rank
            runs ``max(1, total_repeats // size)`` of them.
        total_workers: Worker budget shared across all ranks; each rank passes
            ``max(1, total_workers // size)`` as the optimizer's ``parallel``.
        target_size: ``target_size`` handed to ``tree.slice``.

    Returns:
        None. Progress, timings and the result are printed on rank 0 only.
    """
    if rank == 0:
        print(f"MPI size: {size} ranks")
        print(f"Circuit: {n_qubits} qubits, depth {depth}")

    # 1. Every rank builds the circuit independently (avoids broadcasting a
    #    large object).
    circ = qtn.Circuit(n_qubits, dtype=np.complex128)
    for _ in range(depth):
        for i in range(n_qubits):
            circ.apply_gate('H', i)
        for i in range(0, n_qubits - 1, 2):
            circ.apply_gate('CZ', i, i + 1)
    psi = circ.psi
    # Network to contract: conjugated state combined with the state
    # (presumably the norm <psi|psi> -- confirm against quimb semantics).
    net = psi.conj() & psi

    # 2. All ranks search for a path in parallel; rank 0 picks the global best.
    t0 = time.perf_counter()
    repeats_per_rank = max(1, total_repeats // size)
    opt = ctg.HyperOptimizer(
        methods=['kahypar'],
        max_repeats=repeats_per_rank,
        minimize='flops',
        parallel=max(1, total_workers // size),
    )
    local_tree = net.contraction_tree(optimize=opt)
    all_trees = comm.gather(local_tree, root=0)

    # 3. Rank 0 selects and slices the tree, then broadcasts the sliced tree.
    #    The unsliced tree is only needed on rank 0, so it is not broadcast.
    if rank == 0:
        tree = min(all_trees, key=lambda t: t.contraction_cost())
        t1 = time.perf_counter()
        print(f"[rank 0] Path search: {t1 - t0:.4f} s")
        sliced_tree = tree.slice(target_size=target_size)
    else:
        sliced_tree = None
    sliced_tree = comm.bcast(sliced_tree, root=0)

    n_slices = sliced_tree.nslices
    if rank == 0:
        print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size}")

    arrays = [t.data for t in net.tensors]

    # Each rank contracts the slices assigned to it (round-robin by rank).
    t2 = time.perf_counter()
    local_result = 0.0 + 0.0j
    for i in range(rank, n_slices, size):
        local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')
    t3 = time.perf_counter()

    # 4. Sum the partial results onto rank 0.
    total = comm.reduce(local_result, op=MPI.SUM, root=0)
    if rank == 0:
        # Timing below covers rank 0's own slices only.
        print(f"[rank 0] Contract: {t3 - t2:.4f} s")
        print(f"Result: {total:.10f}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point: both options are integers with defaults.
    cli = argparse.ArgumentParser()
    for flag, fallback in (("--n_qubits", 50), ("--depth", 20)):
        cli.add_argument(flag, type=int, default=fallback)
    opts = cli.parse_args()
    run_mpi(opts.n_qubits, opts.depth)