完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
56
tools/benchmark_contract_sliced.py
Normal file
56
tools/benchmark_contract_sliced.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""MPI parallel sliced contraction using pre-sliced tree."""
|
||||
import time, pickle, os
|
||||
import numpy as np
|
||||
from mpi4py import MPI
|
||||
|
||||
NQUBITS, NLAYERS, NCORES = 25, 10, 48
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
rank, size = comm.Get_rank(), comm.Get_size()
|
||||
|
||||
os.environ['OMP_NUM_THREADS'] = str(NCORES)
|
||||
os.environ['MKL_NUM_THREADS'] = str(NCORES)
|
||||
|
||||
import torch
|
||||
import qibo, quimb as qu
|
||||
from qibotn.observables import build_random_circuit
|
||||
|
||||
torch.set_num_threads(NCORES)
|
||||
|
||||
circuit = build_random_circuit(NQUBITS, NLAYERS)
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
backend = qibo.get_backend()
|
||||
backend.configure_tn_simulation(ansatz="tn")
|
||||
qc = backend._qibo_circuit_to_quimb(circuit, backend.circuit_ansatz)
|
||||
tn = qc.local_expectation(qu.pauli('x') & qu.pauli('z'), (0, 1), rehearse='tn')
|
||||
|
||||
if rank == 0:
|
||||
with open(f"data/tree_q{NQUBITS}_l{NLAYERS}_sliced.pkl", 'rb') as f:
|
||||
tree = pickle.load(f)
|
||||
else:
|
||||
tree = None
|
||||
tree = comm.bcast(tree, root=0)
|
||||
|
||||
arrays = [torch.from_numpy(np.asarray(t._data)) for t in tn.tensors]
|
||||
n_slices = tree.multiplicity
|
||||
|
||||
if rank == 0:
|
||||
print(f"Slices: {n_slices}, Ranks: {size}, "
|
||||
f"Peak: {tree.max_size() * 16 / 1e9:.2f} GB, "
|
||||
f"Threads/rank: {NCORES}, Backend: torch")
|
||||
|
||||
t0 = time.time()
|
||||
result = None
|
||||
for i in range(rank, n_slices, size):
|
||||
val = tree.contract_slice(arrays, i, backend='torch')
|
||||
val_np = val.cpu().numpy().reshape(-1)
|
||||
result = val_np if result is None else result + val_np
|
||||
|
||||
if result is None:
|
||||
result = np.zeros(1, dtype=np.complex128)
|
||||
|
||||
total = np.zeros_like(result) if rank == 0 else None
|
||||
comm.Reduce(result, total, root=0)
|
||||
|
||||
if rank == 0:
|
||||
print(f"Contract: {time.time() - t0:.4f}s Expectation: {0.5 * total[0].real:.10f}")
|
||||
Reference in New Issue
Block a user