Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
91 lines
2.6 KiB
Python
91 lines
2.6 KiB
Python
import time
|
||
import numpy as np
|
||
import quimb.tensor as qtn
|
||
import cotengra as ctg
|
||
from mpi4py import MPI
|
||
|
||
# Module-wide MPI handles used by run_mpi: the world communicator, this
# process's rank, and the total number of ranks.
comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()


def build_qft(n_qubits):
    """Build the QFT rotation circuit on ``n_qubits`` qubits.

    For each qubit ``target`` in order, apply a Hadamard and then a
    controlled-phase rotation of angle ``pi / 2**(control - target)`` between
    ``target`` and every later qubit ``control``. No final qubit-reversal
    swaps are applied.

    Returns:
        The populated ``quimb.tensor.Circuit`` (complex128).
    """
    circuit = qtn.Circuit(n_qubits, dtype=np.complex128)
    for target in range(n_qubits):
        circuit.apply_gate('H', target)
        for control in range(target + 1, n_qubits):
            angle = np.pi / 2 ** (control - target)
            circuit.apply_gate('CPHASE', angle, target, control)
    return circuit


def _z0_expectation_network(psi):
    """Return the closed tensor network for <psi|Z_0|psi>."""
    z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
    # Only site 0 of the bra gets a fresh index so the Z tensor can sit
    # between ket and bra there; every other physical index k{i} is shared
    # by ket and bra and is therefore traced out. Renaming *all* sites would
    # leave 2 * (n_qubits - 1) open indices and the result would not be a
    # scalar.
    bra = psi.conj().reindex({'k0': 'b0'})
    obs = qtn.Tensor(z, inds=('k0', 'b0'))
    # NOTE(review): relies on quimb's `&` (TensorNetwork check_collisions
    # default) renaming the inner bond indices that psi and its conjugate
    # share, so only physical indices connect ket and bra.
    return psi & obs & bra


def _search_sliced_tree(net, target_size=2**27):
    """Search paths on every rank, pick the cheapest on rank 0, slice and broadcast it."""
    t0 = time.perf_counter()
    # Split a total budget of ~128 trials (and ~96 workers) across the ranks.
    repeats_per_rank = max(1, 128 // size)
    opt = ctg.HyperOptimizer(
        methods=['kahypar'],
        max_repeats=repeats_per_rank,
        minimize='flops',
        parallel=max(1, 96 // size),
    )
    local_tree = net.contraction_tree(optimize=opt)
    all_trees = comm.gather(local_tree, root=0)

    sliced_tree = None
    if rank == 0:
        tree = min(all_trees, key=lambda t: t.contraction_cost())
        t1 = time.perf_counter()
        print(f"[rank 0] Path search: {t1 - t0:.4f} s")
        # Slice on rank 0 only; the other ranks never need the unsliced tree,
        # so a single broadcast of the sliced tree is enough.
        sliced_tree = tree.slice(target_size=target_size)
    return comm.bcast(sliced_tree, root=0)


def run_mpi(n_qubits, depth=None):
    """Compute <psi|Z_0|psi> for a QFT circuit with MPI-parallel sliced contraction.

    Every rank searches contraction paths. Rank 0 keeps the cheapest tree,
    slices it and broadcasts it, and the slices are then distributed
    round-robin over the ranks and summed on rank 0.

    Args:
        n_qubits: Number of qubits in the QFT circuit.
        depth: Unused; the QFT circuit built by ``build_qft`` has a fixed
            structure. Accepted so the CLI's ``--depth`` value can be passed.

    Returns:
        The summed expectation value (a Python ``complex``) on rank 0, and
        ``None`` on every other rank.
    """
    if rank == 0:
        print(f"MPI size: {size} ranks")
        print(f"Circuit: QFT {n_qubits} qubits")

    circ = build_qft(n_qubits)
    net = _z0_expectation_network(circ.psi)

    sliced_tree = _search_sliced_tree(net)
    n_slices = sliced_tree.nslices
    if rank == 0:
        print(f"Total slices: {n_slices}, each rank handles ~{n_slices // size}")

    # NOTE(review): assumes the tree's inputs follow net.tensors order (the
    # tree was built from this same network object).
    arrays = [t.data for t in net.tensors]

    # Round-robin distribution: rank r handles slices r, r + size, ...
    t2 = time.perf_counter()
    local_result = 0.0 + 0.0j
    for i in range(rank, n_slices, size):
        # complex() also asserts each slice contracted to a scalar.
        local_result += complex(
            sliced_tree.contract_slice(arrays, i, backend='numpy')
        )
    t3 = time.perf_counter()

    total = comm.reduce(local_result, op=MPI.SUM, root=0)
    # The parallel contraction is only done when the slowest rank finishes.
    slowest = comm.reduce(t3 - t2, op=MPI.MAX, root=0)

    if rank == 0:
        print(f"[rank 0] Contract: {t3 - t2:.4f} s")
        print(f"[max over ranks] Contract: {slowest:.4f} s")
        print(f"Result: {total:.10f}")
    return total


if __name__ == "__main__":
    # argparse is only needed when the module runs as a script.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--n_qubits", type=int, default=20)
    # Forwarded to run_mpi, whose QFT circuit does not depend on it.
    cli.add_argument("--depth", type=int, default=30)
    cli_args = cli.parse_args()
    run_mpi(cli_args.n_qubits, cli_args.depth)