修复时间剪枝功能
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -10,8 +10,7 @@ from mpi4py import MPI
|
|||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
|
||||||
|
|
||||||
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks):
|
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time=600):
|
||||||
"""Run one serial HyperOptimizer in a subprocess, return (width, tree)."""
|
|
||||||
import pickle, cotengra as ctg, random
|
import pickle, cotengra as ctg, random
|
||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
tn = pickle.loads(tn_bytes)
|
tn = pickle.loads(tn_bytes)
|
||||||
@@ -19,48 +18,49 @@ def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks
|
|||||||
methods=['kahypar', 'kahypar-agglom', 'spinglass'],
|
methods=['kahypar', 'kahypar-agglom', 'spinglass'],
|
||||||
max_repeats=repeats,
|
max_repeats=repeats,
|
||||||
parallel=False,
|
parallel=False,
|
||||||
minimize='flops',
|
minimize='combo-256',
|
||||||
max_time=600,
|
max_time=max_time,
|
||||||
optlib="random",
|
optlib="random",
|
||||||
slicing_opts={'target_size': 2**30, 'allow_outer': False},
|
slicing_opts={'target_size': 2**29, 'allow_outer': True},
|
||||||
progbar=False,
|
progbar=False,
|
||||||
)
|
)
|
||||||
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
|
tree = tn.contraction_tree(optimize=opt, output_inds=output_inds)
|
||||||
return tree.contraction_width(), tree
|
return tree.combo_cost(factor=256), tree
|
||||||
|
|
||||||
|
|
||||||
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
|
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
|
||||||
timeout=None):
|
timeout=60):
|
||||||
"""Launch n_workers subprocesses each running serial search, return best tree."""
|
|
||||||
import pickle, os, signal
|
import pickle, os, signal
|
||||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
tn_bytes = pickle.dumps(tn)
|
tn_bytes = pickle.dumps(tn)
|
||||||
repeats_per = max(1, total_repeats // n_workers)
|
repeats_per = max(1, total_repeats // n_workers)
|
||||||
best_width, best_tree = float('inf'), None
|
best_cost, best_tree = float('inf'), None
|
||||||
|
|
||||||
with ProcessPoolExecutor(max_workers=n_workers) as pool:
|
pool = ProcessPoolExecutor(max_workers=n_workers)
|
||||||
futures = {
|
futures = [
|
||||||
pool.submit(_run_serial_search, tn_bytes, output_inds,
|
pool.submit(_run_serial_search, tn_bytes, output_inds,
|
||||||
repeats_per, seed, num_slices, n_ranks): seed
|
repeats_per, seed, num_slices, n_ranks, timeout)
|
||||||
for seed in range(n_workers)
|
for seed in range(n_workers)
|
||||||
}
|
]
|
||||||
pids = {f: p.pid for f, p in zip(futures, pool._processes.values())}
|
try:
|
||||||
try:
|
for fut in as_completed(futures, timeout=timeout + 5):
|
||||||
for fut in as_completed(futures, timeout=timeout):
|
try:
|
||||||
try:
|
cost, tree = fut.result()
|
||||||
width, tree = fut.result()
|
if cost < best_cost:
|
||||||
if width < best_width:
|
best_cost, best_tree = cost, tree
|
||||||
best_width, best_tree = width, tree
|
except Exception as e:
|
||||||
except Exception as e:
|
print(f" [worker failed] {e}")
|
||||||
print(f" [worker failed] {e}")
|
except TimeoutError:
|
||||||
except TimeoutError:
|
pass
|
||||||
pass
|
finally:
|
||||||
for fut, pid in pids.items():
|
for fut in futures:
|
||||||
if not fut.done():
|
fut.cancel()
|
||||||
try:
|
for pid in list(pool._processes.keys()):
|
||||||
os.kill(pid, signal.SIGKILL)
|
try:
|
||||||
except ProcessLookupError:
|
os.kill(pid, signal.SIGKILL)
|
||||||
pass
|
except ProcessLookupError:
|
||||||
|
pass
|
||||||
|
pool.shutdown(wait=False)
|
||||||
|
|
||||||
return best_tree
|
return best_tree
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ def _contract_mpi(tree, arrays, comm, root=0):
|
|||||||
result_np = x_np if result_np is None else result_np + x_np
|
result_np = x_np if result_np is None else result_np + x_np
|
||||||
|
|
||||||
if result_np is None:
|
if result_np is None:
|
||||||
result_np = np.zeros(1, dtype=np.complex64)
|
result_np = np.zeros(1, dtype=np.complex128)
|
||||||
|
|
||||||
result = np.zeros_like(result_np) if rank == root else None
|
result = np.zeros_like(result_np) if rank == root else None
|
||||||
comm.Reduce(result_np, result, root=root)
|
comm.Reduce(result_np, result, root=root)
|
||||||
@@ -133,7 +133,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
|||||||
import torch
|
import torch
|
||||||
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
qc = b._qibo_circuit_to_quimb(circuit, quimb_circuit_type=b.circuit_ansatz,
|
||||||
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
gate_opts={"max_bond": None, "cutoff": 1e-10})
|
||||||
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex64)
|
qc.to_backend = lambda x: torch.from_numpy(x).to(torch.complex128)
|
||||||
|
|
||||||
# --- path search: each rank serial, gather best to rank 0 ---
|
# --- path search: each rank serial, gather best to rank 0 ---
|
||||||
if load_path:
|
if load_path:
|
||||||
@@ -152,16 +152,16 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
|||||||
psi_tn = qc.to_dense(rehearse="tn")
|
psi_tn = qc.to_dense(rehearse="tn")
|
||||||
local_tree = parallel_search(
|
local_tree = parallel_search(
|
||||||
psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48,
|
psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48,
|
||||||
num_slices=num_slices, n_ranks=size, timeout=630,
|
num_slices=num_slices, n_ranks=size, timeout=60,
|
||||||
)
|
)
|
||||||
t_search = time.time() - t0
|
t_search = time.time() - t0
|
||||||
local_psi = psi_tn
|
local_psi = psi_tn
|
||||||
|
|
||||||
all_results = comm.gather((local_tree.contraction_width(), local_tree, local_psi), root=0)
|
all_results = comm.gather((local_tree.combo_cost(factor=256), local_tree, local_psi), root=0)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
_, tree, psi = min(all_results, key=lambda x: x[0])
|
_, tree, psi = min(all_results, key=lambda x: x[0])
|
||||||
print(f" [path search] {t_search:.3f}s "
|
print(f" [path search] {t_search:.3f}s "
|
||||||
f"flops~2^{tree.contraction_cost():.2f} "
|
f"flops~2^{tree.contraction_cost(log=2):.2f} "
|
||||||
f"size~2^{tree.contraction_width():.2f} "
|
f"size~2^{tree.contraction_width():.2f} "
|
||||||
f"slices={tree.multiplicity}")
|
f"slices={tree.multiplicity}")
|
||||||
if save_path:
|
if save_path:
|
||||||
@@ -182,7 +182,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
|||||||
# --- contraction: all ranks work in parallel ---
|
# --- contraction: all ranks work in parallel ---
|
||||||
import torch
|
import torch
|
||||||
torch.set_num_threads(max(1, 48 // size))
|
torch.set_num_threads(max(1, 48 // size))
|
||||||
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex64) for a in psi.arrays]
|
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays]
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
sv = _contract_mpi(tree, arrays, comm, root=0)
|
sv = _contract_mpi(tree, arrays, comm, root=0)
|
||||||
t_contract = time.time() - t0
|
t_contract = time.time() - t0
|
||||||
|
|||||||
Reference in New Issue
Block a user