From 4b7fc931baafdcd5ae35561beda4182553e5aa71 Mon Sep 17 00:00:00 2001 From: jaunatisblue Date: Fri, 17 Apr 2026 23:22:50 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E8=BF=90=E8=A1=8C=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/quimb_mpi2.py | 33 +++++++++++++++++++++------------ 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/tests/quimb_mpi2.py b/tests/quimb_mpi2.py index 40c36d4..6229558 100644 --- a/tests/quimb_mpi2.py +++ b/tests/quimb_mpi2.py @@ -8,26 +8,35 @@ comm = MPI.COMM_WORLD rank = comm.Get_rank() size = comm.Get_size() -def run_mpi(n_qubits, depth): +def build_qft(n_qubits): + circ = qtn.Circuit(n_qubits, dtype=np.complex128) + for i in range(n_qubits): + circ.apply_gate('H', i) + for j in range(i + 1, n_qubits): + circ.apply_gate('CPHASE', np.pi / 2 ** (j - i), i, j) + return circ + +def run_mpi(n_qubits, depth=None): if rank == 0: print(f"MPI size: {size} ranks") - print(f"Circuit: {n_qubits} qubits, depth {depth}") + print(f"Circuit: QFT {n_qubits} qubits") - # 1. 所有 rank 独立构建电路(避免广播大对象) - circ = qtn.Circuit(n_qubits, dtype=np.complex128) - for _ in range(depth): - for i in range(n_qubits): - circ.apply_gate('H', i) - for i in range(0, n_qubits - 1, 2): - circ.apply_gate('CZ', i, i + 1) + circ = build_qft(n_qubits) psi = circ.psi - net = psi.conj() & psi + + # 期望值网络: + Z = np.array([[1, 0], [0, -1]], dtype=np.complex128) + bra = psi.conj().reindex({f'k{i}': f'b{i}' for i in range(n_qubits)}) + obs = qtn.Tensor(Z, inds=(f'k0', f'b0')) + net = psi & obs & bra # 2. 所有 rank 并行搜索路径,rank 0 选全局最优 t0 = time.perf_counter() repeats_per_rank = max(1, 128 // size) opt = ctg.HyperOptimizer( methods=['kahypar'], + #methods=['greedy'], + #max_repeats=repeats_per_rank, max_repeats=repeats_per_rank, minimize='flops', parallel=max(1, 96 // size), @@ -75,7 +84,7 @@ def run_mpi(n_qubits, depth): if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--n_qubits", type=int, default=50) - parser.add_argument("--depth", type=int, default=20) + parser.add_argument("--n_qubits", type=int, default=20) + parser.add_argument("--depth", type=int, default=30) args = parser.parse_args() run_mpi(args.n_qubits, args.depth)