修改运行脚本
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
2026-04-17 23:22:50 +08:00
parent bcad2882fa
commit 4b7fc931ba

View File

@@ -8,26 +8,35 @@ comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()
def run_mpi(n_qubits, depth):
if rank == 0:
print(f"MPI size: {size} ranks")
print(f"Circuit: {n_qubits} qubits, depth {depth}")
# 1. 所有 rank 独立构建电路(避免广播大对象)
def build_qft(n_qubits):
circ = qtn.Circuit(n_qubits, dtype=np.complex128)
for _ in range(depth):
for i in range(n_qubits):
circ.apply_gate('H', i)
for i in range(0, n_qubits - 1, 2):
circ.apply_gate('CZ', i, i + 1)
for j in range(i + 1, n_qubits):
circ.apply_gate('CPHASE', np.pi / 2 ** (j - i), i, j)
return circ
def run_mpi(n_qubits, depth=None):
if rank == 0:
print(f"MPI size: {size} ranks")
print(f"Circuit: QFT {n_qubits} qubits")
circ = build_qft(n_qubits)
psi = circ.psi
net = psi.conj() & psi
# 期望值网络:<psi|Z_0|psi>
Z = np.array([[1, 0], [0, -1]], dtype=np.complex128)
bra = psi.conj().reindex({f'k{i}': f'b{i}' for i in range(n_qubits)})
obs = qtn.Tensor(Z, inds=(f'k0', f'b0'))
net = psi & obs & bra
# 2. 所有 rank 并行搜索路径rank 0 选全局最优
t0 = time.perf_counter()
repeats_per_rank = max(1, 128 // size)
opt = ctg.HyperOptimizer(
methods=['kahypar'],
#methods=['greedy'],
#max_repeats=repeats_per_rank,
max_repeats=repeats_per_rank,
minimize='flops',
parallel=max(1, 96 // size),
@@ -75,7 +84,7 @@ def run_mpi(n_qubits, depth):
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--n_qubits", type=int, default=50)
parser.add_argument("--depth", type=int, default=20)
parser.add_argument("--n_qubits", type=int, default=20)
parser.add_argument("--depth", type=int, default=30)
args = parser.parse_args()
run_mpi(args.n_qubits, args.depth)