期望值计算支持;更新后端调用
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""MPI-parallel TN benchmark: path search + contraction via MPI."""
|
||||
import json
|
||||
import pickle
|
||||
import time
|
||||
import argparse
|
||||
@@ -8,9 +9,38 @@ import qibo
|
||||
from qibo import Circuit, gates
|
||||
from mpi4py import MPI
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from qibotn.observables import check_observable, extract_gates_and_qubits
|
||||
|
||||
|
||||
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time=600):
|
||||
def _load_observable(observable_file=None, observable_json=None):
|
||||
if observable_file:
|
||||
with open(observable_file, "r", encoding="utf8") as f:
|
||||
return json.load(f)
|
||||
if observable_json:
|
||||
return json.loads(observable_json)
|
||||
return None
|
||||
|
||||
|
||||
def _term_to_quimb_operator(term):
    """Translate one extracted Hamiltonian term into quimb form.

    Args:
        term: Sequence of ``(qubit, gate_name, coefficient)`` triples.

    Returns:
        A ``(coefficient, operator, sites)`` tuple. ``coefficient`` is a
        ``complex``, ``operator`` is the ``&``-combination of the
        non-identity Pauli operators in term order (``None`` when the term
        has no non-identity factor), and ``sites`` is the tuple of qubit
        indices those factors act on.
    """
    import quimb as qu

    # The weight is read from the first factor only; an empty term gets 1.
    weight = complex(term[0][2]) if term else complex(1.0)

    sites = []
    factors = []
    for site, name, _ignored in term:
        site = int(site)
        label = str(name).upper()
        if label != "I":
            # Identity factors contribute neither a site nor an operator.
            sites.append(site)
            factors.append(qu.pauli(label.lower()))

    operator = None
    for factor in factors:
        operator = factor if operator is None else operator & factor

    return weight, operator, tuple(sites)
|
||||
|
||||
|
||||
def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks, max_time):
|
||||
import pickle, cotengra as ctg, random
|
||||
random.seed(seed)
|
||||
tn = pickle.loads(tn_bytes)
|
||||
@@ -29,10 +59,14 @@ def _run_serial_search(tn_bytes, output_inds, repeats, seed, num_slices, n_ranks
|
||||
|
||||
|
||||
def parallel_search(tn, output_inds, total_repeats, n_workers, num_slices, n_ranks,
|
||||
timeout=60):
|
||||
timeout):
|
||||
import pickle, os, signal
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
tn_bytes = pickle.dumps(tn)
|
||||
if n_workers <= 1:
|
||||
return _run_serial_search(
|
||||
tn_bytes, output_inds, total_repeats, 0, num_slices, n_ranks, timeout
|
||||
)[1]
|
||||
repeats_per = max(1, total_repeats // n_workers)
|
||||
best_cost, best_tree = float('inf'), None
|
||||
|
||||
@@ -152,7 +186,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
psi_tn = qc.to_dense(rehearse="tn")
|
||||
local_tree = parallel_search(
|
||||
psi_tn, psi_tn.outer_inds(), rank_repeats, n_workers=48,
|
||||
num_slices=num_slices, n_ranks=size, timeout=60,
|
||||
num_slices=num_slices, n_ranks=size, timeout=600,
|
||||
)
|
||||
t_search = time.time() - t0
|
||||
local_psi = psi_tn
|
||||
@@ -181,7 +215,7 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
|
||||
# --- contraction: all ranks work in parallel ---
|
||||
import torch
|
||||
torch.set_num_threads(max(1, 48 // size))
|
||||
torch.set_num_threads(max(1, 96 // size))
|
||||
arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in psi.arrays]
|
||||
t0 = time.time()
|
||||
sv = _contract_mpi(tree, arrays, comm, root=0)
|
||||
@@ -193,6 +227,72 @@ def run_mpi(circuit, nqubits, num_slices, total_repeats=1024,
|
||||
return None, t_search + t_contract
|
||||
|
||||
|
||||
def _expval_term_contribution(qc, term, total_repeats, search_workers,
                              search_timeout, n_ranks):
    """Return the weighted contribution of one Hamiltonian term.

    Builds the local-expectation tensor network of ``term`` on the quimb
    circuit ``qc`` and contracts it. ``parallel_search`` is only called when
    that network still has open (outer) indices.

    Raises:
        RuntimeError: If ``parallel_search`` returns no contraction tree.
    """
    import torch

    coeff, op, where = _term_to_quimb_operator(term)
    if op is None:
        # No non-identity factor: the coefficient is added as-is.
        return coeff

    tn = qc.local_expectation_tn(op, where=where)
    if len(tn.outer_inds()) == 0:
        return coeff * complex(tn.contract())

    tree = parallel_search(
        tn,
        tn.outer_inds(),
        total_repeats,
        n_workers=search_workers,
        num_slices=1,
        n_ranks=n_ranks,
        timeout=search_timeout,
    )
    if tree is None:
        raise RuntimeError("Failed to find a contraction tree for expectation TN.")

    arrays = [torch.from_numpy(np.asarray(a)).to(torch.complex128) for a in tn.arrays]
    acc = sum(tree.contract_slice(arrays, i) for i in range(tree.multiplicity))
    return coeff * complex(acc.item() if hasattr(acc, "item") else acc)


def run_mpi_expval(
    circuit,
    nqubits,
    observable=None,
    total_repeats=1024,
    search_workers=1,
    search_timeout=300,
):
    """Compute a Hamiltonian expectation value directly from the TN via MPI.

    Hamiltonian terms are dealt round-robin to the MPI ranks. Each rank
    contracts the expectation networks of its own terms and the partial sums
    are gathered on rank 0. ``parallel_search`` optionally helps with the
    path search of a term.

    Args:
        circuit: Qibo circuit, converted through the qibotn backend.
        nqubits: Number of qubits, forwarded to ``check_observable``.
        observable: Observable specification, forwarded to
            ``check_observable``.
        total_repeats: Path-search repeats handed to ``parallel_search``.
        search_workers: Worker count handed to ``parallel_search``.
        search_timeout: Timeout handed to ``parallel_search``.

    Returns:
        ``(expval, elapsed)`` on rank 0 and ``(None, elapsed)`` on every
        other rank. ``elapsed`` is the term-loop time of the slowest rank.

    Raises:
        Exception: On a rank whose term loop failed, the original error is
            re-raised after all ranks have synchronised.
        RuntimeError: On the remaining ranks when some other rank failed.
    """
    import torch

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    size = comm.Get_size()

    qibo.set_backend("qibotn", platform="quimb")
    b = qibo.get_backend()
    b.configure_tn_simulation(ansatz="tn")

    observable = check_observable(observable, nqubits)
    ham_gate_map = extract_gates_and_qubits(observable)

    qc = b._qibo_circuit_to_quimb(
        circuit,
        quimb_circuit_type=b.circuit_ansatz,
        gate_opts={"max_bond": None, "cutoff": 1e-10},
    )

    my_terms = ham_gate_map[rank::size]
    torch.set_num_threads(max(1, 96 // size))
    t0 = time.time()

    my_exp = 0.0 + 0.0j
    failure = None
    try:
        for term in my_terms:
            my_exp += _expval_term_contribution(
                qc, term, total_repeats, search_workers, search_timeout, size
            )
    except Exception as exc:
        # Deferred on purpose: raising here would skip the collectives below
        # and leave the other ranks blocked in them.
        failure = exc
    elapsed = time.time() - t0

    # Every rank reaches this collective, so all of them learn about a failure.
    failure_flags = comm.allgather(failure is not None)
    if failure is not None:
        raise failure
    failed_ranks = [r for r, failed in enumerate(failure_flags) if failed]
    if failed_ranks:
        raise RuntimeError(
            f"Expectation-value contraction failed on rank(s) {failed_ranks}."
        )

    # Terms are spread over ranks, so the wall time is set by the slowest one.
    t_total = comm.allreduce(elapsed, op=MPI.MAX)

    all_results = comm.gather(my_exp, root=0)
    if rank == 0:
        total_exp = sum(all_results)
        print(f"\n[TN expval] time={t_total:.4f}s expval={total_exp.real:.12f}")
        return np.real_if_close(total_exp), t_total
    return None, t_total
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nqubits", type=int, default=30)
|
||||
@@ -201,9 +301,14 @@ def main():
|
||||
parser.add_argument("--nlayers", type=int, default=3)
|
||||
parser.add_argument("--num-slices", type=int, default=1)
|
||||
parser.add_argument("--total-repeats", type=int, default=1024)
|
||||
parser.add_argument("--search-workers", type=int, default=1)
|
||||
parser.add_argument("--search-timeout", type=int, default=300)
|
||||
parser.add_argument("--observable-file", type=str, default=None)
|
||||
parser.add_argument("--observable-json", type=str, default=None)
|
||||
parser.add_argument("--save-path", type=str, default=None)
|
||||
parser.add_argument("--load-path", type=str, default=None)
|
||||
parser.add_argument("--no-compare", action="store_true")
|
||||
parser.add_argument("--mode", type=str, default="sv", choices=["sv", "expval"])
|
||||
args = parser.parse_args()
|
||||
|
||||
comm = MPI.COMM_WORLD
|
||||
@@ -215,6 +320,27 @@ def main():
|
||||
|
||||
np.random.seed(42)
|
||||
circuit = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
observable = _load_observable(args.observable_file, args.observable_json)
|
||||
|
||||
if args.mode == "expval":
|
||||
try:
|
||||
expval, t_total = run_mpi_expval(
|
||||
circuit,
|
||||
args.nqubits,
|
||||
observable=observable,
|
||||
total_repeats=args.total_repeats,
|
||||
search_workers=args.search_workers,
|
||||
search_timeout=args.search_timeout,
|
||||
)
|
||||
except Exception as e:
|
||||
if rank == 0:
|
||||
print(f"[FAILED] {e}")
|
||||
raise
|
||||
if rank == 0:
|
||||
np.save(f"data/expval_tn_{args.circuit}{args.nqubits}.npy", np.asarray(expval))
|
||||
if not args.no_compare:
|
||||
print("No built-in reference comparison for arbitrary observables.")
|
||||
return
|
||||
|
||||
try:
|
||||
sv, t_total = run_mpi(circuit, args.nqubits, args.num_slices,
|
||||
@@ -230,14 +356,20 @@ def main():
|
||||
np.save(f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy", sv)
|
||||
|
||||
if not args.no_compare:
|
||||
from benchmark_tn import run_qibojit
|
||||
from qibotn.bak.benchmark_tn import run_qibojit
|
||||
import gc
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
sv_ref, t_ref = run_qibojit(circuit_ref)
|
||||
fid = abs(np.dot(sv_ref.conj(), sv)) ** 2
|
||||
np.save(f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy", sv_ref)
|
||||
print(f"[qibojit] time={t_ref:.4f}s")
|
||||
print(f"Fidelity : {fid:.8f}")
|
||||
print(f"L2 error : {np.linalg.norm(sv_ref - sv):.2e}")
|
||||
# free memory before loading via mmap for expval comparison
|
||||
del sv, sv_ref
|
||||
gc.collect()
|
||||
from compare_jit_tn_quimb import check_results
|
||||
ref_path = f"data/sv_qibojit_{args.circuit}{args.nqubits}.npy"
|
||||
tn_path = f"data/sv_tn_{args.circuit}{args.nqubits}_mpi.npy"
|
||||
check_results(ref_path, tn_path, args.nqubits)
|
||||
if t_total > 0:
|
||||
print(f"Speedup : {t_ref/t_total:.2f}x")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user