Files
qibotn/tests/quimb_mpi2.py
jaunatisblue 4b7fc931ba
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
修改运行脚本
2026-04-17 23:22:50 +08:00

91 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import time
import numpy as np
import quimb.tensor as qtn
import cotengra as ctg
from mpi4py import MPI
# Process-wide MPI handles read by run_mpi: the world communicator, the rank
# of this process within it, and the total number of ranks.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()
def build_qft(n_qubits):
    """Build a quimb circuit for the quantum Fourier transform.

    For every qubit ``q`` (in ascending order) a Hadamard is applied, followed
    by a ``CPHASE`` gate with angle ``pi / 2**(other - q)`` on each later
    qubit ``other``.  No final qubit-reversal SWAPs are appended.

    Args:
        n_qubits: Number of qubits in the circuit.

    Returns:
        The ``qtn.Circuit`` (complex128) with the QFT gates applied.
    """
    circuit = qtn.Circuit(n_qubits, dtype=np.complex128)
    for q in range(n_qubits):
        circuit.apply_gate('H', q)
        for other in range(q + 1, n_qubits):
            angle = np.pi / 2 ** (other - q)
            circuit.apply_gate('CPHASE', angle, q, other)
    return circuit
def _build_z0_expectation_network(n_qubits):
    """Return the tensor network for <psi|Z_0|psi>, psi being the QFT state.

    Assumes quimb's default ket site-index naming 'k0', 'k1', ... for the
    circuit state.
    """
    psi = build_qft(n_qubits).psi
    Z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
    # Rename ONLY qubit 0's site index on the bra so Z can sit between the
    # ket ('k0') and the bra ('b0').  All other site indices must stay shared
    # between ket and bra so they are traced over; renaming every site index
    # would leave 2*(n_qubits - 1) open indices and the contraction would
    # produce an operator instead of the scalar expectation value.
    # NOTE(review): ket and bra share inner bond names after conj(); this
    # relies on quimb's `&` mangling colliding inner indices
    # (TensorNetwork check_collisions=True by default) -- confirm on upgrade.
    bra = psi.conj().reindex({'k0': 'b0'})
    obs = qtn.Tensor(Z, inds=('k0', 'b0'))
    return psi & obs & bra


def _make_path_optimizer():
    """Return this rank's cotengra hyper-optimiser.

    A total budget of 128 path trials and 96 parallel workers is divided by
    the number of MPI ranks (at least 1 of each per rank).
    """
    return ctg.HyperOptimizer(
        methods=['kahypar'],
        max_repeats=max(1, 128 // size),
        minimize='flops',
        parallel=max(1, 96 // size),
    )


def _contract_assigned_slices(sliced_tree, arrays):
    """Contract the slices owned by this rank and return their sum.

    Slices are assigned round-robin: rank r handles r, r + size, r + 2*size...
    A rank that owns no slices returns 0j.
    """
    local_result = 0.0 + 0.0j
    for i in range(rank, sliced_tree.nslices, size):
        local_result += sliced_tree.contract_slice(arrays, i, backend='numpy')
    return local_result


def run_mpi(n_qubits, depth=None):
    """Compute <psi|Z_0|psi> for the QFT state, distributed over MPI ranks.

    Steps:
      1. Every rank builds the same expectation-value network.
      2. Every rank runs its own path search; rank 0 keeps the tree with the
         lowest ``contraction_cost()``.
      3. Rank 0 slices that tree (target size 2**27) and broadcasts it.
      4. Ranks contract their slices and the partial sums are reduced onto
         rank 0.

    Args:
        n_qubits: Number of qubits of the QFT circuit.
        depth: Accepted for CLI compatibility; currently unused.

    Returns:
        None.  Timings and the result are printed on rank 0 only.
    """
    if rank == 0:
        print(f"MPI size: {size} ranks")
        print(f"Circuit: QFT {n_qubits} qubits")

    # 1. Build the network; it is identical on every rank.
    net = _build_z0_expectation_network(n_qubits)

    # 2. All ranks search for a contraction path in parallel; rank 0 keeps
    #    the globally cheapest tree.
    t0 = time.perf_counter()
    local_tree = net.contraction_tree(optimize=_make_path_optimizer())
    all_trees = comm.gather(local_tree, root=0)

    # 3. Rank 0 slices the best tree and broadcasts only the sliced tree;
    #    the other ranks never use the unsliced one.
    if rank == 0:
        tree = min(all_trees, key=lambda t: t.contraction_cost())
        t1 = time.perf_counter()
        print(f"[rank 0] Path search: {t1 - t0:.4f} s")
        sliced_tree = tree.slice(target_size=2**27)
    else:
        sliced_tree = None
    sliced_tree = comm.bcast(sliced_tree, root=0)

    n_slices = sliced_tree.nslices
    if rank == 0:
        print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size}")

    # Arrays in net.tensors order -- assumed to be the input order used by
    # net.contraction_tree above.
    arrays = [t.data for t in net.tensors]
    t2 = time.perf_counter()
    local_result = _contract_assigned_slices(sliced_tree, arrays)
    contract_time = time.perf_counter() - t2

    # 4. Sum partial results on rank 0.  The slowest rank bounds the
    #    wall-clock time of the parallel phase, so report the maximum.
    total = comm.reduce(local_result, op=MPI.SUM, root=0)
    slowest = comm.reduce(contract_time, op=MPI.MAX, root=0)
    if rank == 0:
        print(f"[max over ranks] Contract: {slowest:.4f} s")
        print(f"Result: {total:.10f}")
if __name__ == "__main__":
    import argparse

    # Command-line entry point; launch under mpirun/mpiexec to use >1 rank.
    parser = argparse.ArgumentParser(
        description="Distributed <psi|Z_0|psi> contraction for a QFT circuit "
                    "using quimb, cotengra and mpi4py.",
    )
    parser.add_argument(
        "--n_qubits", type=int, default=20,
        help="number of qubits in the QFT circuit (default: 20)",
    )
    # Kept so existing launch scripts keep working; run_mpi ignores it.
    parser.add_argument(
        "--depth", type=int, default=30,
        help="accepted for compatibility; currently ignored",
    )
    args = parser.parse_args()
    run_mpi(args.n_qubits, args.depth)