tn脚本更新
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
127
benchmark_mps.py
127
benchmark_mps.py
@@ -4,16 +4,17 @@ import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import qibo
|
||||
import quimb.tensor as qtn
|
||||
from qibo import Circuit, gates
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
def make_circuit(circuit_type, nqubits, nlayers=1):
|
||||
def make_circuit(circuit_type, nqubits, nlayers=1, add_measurements=False):
|
||||
c = Circuit(nqubits)
|
||||
if circuit_type == "qft":
|
||||
from qibo.models import QFT
|
||||
return QFT(nqubits)
|
||||
c = QFT(nqubits)
|
||||
elif circuit_type == "variational":
|
||||
for layer in range(nlayers):
|
||||
for q in range(nqubits):
|
||||
@@ -27,6 +28,8 @@ def make_circuit(circuit_type, nqubits, nlayers=1):
|
||||
c.add(gates.CNOT(q, q + 1))
|
||||
else:
|
||||
raise ValueError(f"Unknown circuit: {circuit_type}")
|
||||
if add_measurements:
|
||||
c.add(gates.M(*range(nqubits)))
|
||||
return c
|
||||
|
||||
|
||||
@@ -39,20 +42,58 @@ def run_qibojit(circuit):
|
||||
return sv, elapsed
|
||||
|
||||
|
||||
def run_quimb_mps(circuit, max_bond, svd_cutoff, optimizer):
|
||||
def run_quimb_mps(circuit, max_bond, svd_cutoff, optimizer, nshots=None):
|
||||
qibo.set_backend("qibotn", platform="quimb")
|
||||
b = qibo.get_backend()
|
||||
b.configure_tn_simulation(ansatz="mps", max_bond_dimension=max_bond, svd_cutoff=svd_cutoff)
|
||||
b.contractions_optimizer = optimizer
|
||||
|
||||
t0 = time.time()
|
||||
result = b.execute_circuit(circuit, return_array=True)
|
||||
elapsed = time.time() - t0
|
||||
sv = result.state()
|
||||
return sv, elapsed
|
||||
if nshots:
|
||||
result = b.execute_circuit(circuit, nshots=nshots)
|
||||
elapsed = time.time() - t0
|
||||
return result.frequencies(), elapsed, 0.0
|
||||
else:
|
||||
# MPS simulation
|
||||
circ_quimb = qtn.CircuitMPS.from_openqasm2_str(
|
||||
circuit.to_qasm(),
|
||||
gate_opts={"max_bond": max_bond, "cutoff": svd_cutoff},
|
||||
)
|
||||
t_mps = time.time() - t0
|
||||
# to_dense separately
|
||||
t1 = time.time()
|
||||
#sv = circ_quimb.psi.to_dense().reshape(-1)
|
||||
sv = None
|
||||
t_dense = time.time() - t1
|
||||
return sv, t_mps, t_dense
|
||||
|
||||
|
||||
def compare(sv_ref, sv_mps):
|
||||
def run_quimb_permmps(circuit, max_bond, svd_cutoff, nshots=None):
|
||||
gates_list = [
|
||||
qtn.Gate(g.name, params=list(g.parameters), qubits=list(g.qubits))
|
||||
for g in circuit.queue
|
||||
if g.name.lower() != "measure"
|
||||
]
|
||||
t0 = time.time()
|
||||
circ = qtn.CircuitPermMPS.from_gates(
|
||||
gates_list,
|
||||
N=circuit.nqubits,
|
||||
max_bond=max_bond,
|
||||
cutoff=svd_cutoff,
|
||||
)
|
||||
if nshots:
|
||||
from collections import Counter
|
||||
result = Counter(circ.sample(nshots))
|
||||
elapsed = time.time() - t0
|
||||
return dict(result), elapsed
|
||||
else:
|
||||
mps = circ.get_psi_unordered()
|
||||
sv = mps.to_dense().reshape(-1)
|
||||
elapsed = time.time() - t0
|
||||
return sv, elapsed
|
||||
|
||||
|
||||
def compare_statevector(sv_ref, sv_mps):
|
||||
sv_ref = np.array(sv_ref, dtype=complex).flatten()
|
||||
sv_mps = np.array(sv_mps, dtype=complex).flatten()
|
||||
fidelity = abs(np.dot(sv_ref.conj(), sv_mps)) ** 2
|
||||
@@ -60,6 +101,12 @@ def compare(sv_ref, sv_mps):
|
||||
return fidelity, l2_err
|
||||
|
||||
|
||||
def compare_frequencies(freq_ref, freq_mps, nshots):
|
||||
all_keys = set(freq_ref) | set(freq_mps)
|
||||
tvd = 0.5 * sum(abs(freq_ref.get(k, 0) - freq_mps.get(k, 0)) for k in all_keys) / nshots
|
||||
return tvd
|
||||
|
||||
|
||||
def jit_cache_path(circuit_type, nqubits, nlayers):
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
return os.path.join(DATA_DIR, f"jit_{circuit_type}_n{nqubits}_l{nlayers}.npy")
|
||||
@@ -74,37 +121,65 @@ def main():
|
||||
parser.add_argument("--max-bond", type=int, default=None,
|
||||
help="Max bond dimension for MPS (None = unlimited)")
|
||||
parser.add_argument("--svd-cutoff", type=float, default=1e-6)
|
||||
parser.add_argument("--optimizer", type=str, default="auto-hq")
|
||||
parser.add_argument("--optimizer", type=str, default="eager")
|
||||
parser.add_argument("--nshots", type=int, default=None,
|
||||
help="Use sampling mode with given number of shots instead of statevector")
|
||||
parser.add_argument("--permmps", action="store_true",
|
||||
help="Use CircuitPermMPS directly instead of qibotn backend")
|
||||
parser.add_argument("--skip-jit", action="store_true",
|
||||
help="Skip qibojit run, load cached statevector if available")
|
||||
parser.add_argument("--no-compare", action="store_true",
|
||||
help="Skip correctness comparison entirely")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, nlayers={args.nlayers}")
|
||||
print(f"MPS config: max_bond={args.max_bond}, svd_cutoff={args.svd_cutoff}, optimizer={args.optimizer}")
|
||||
|
||||
cache_path = jit_cache_path(args.circuit, args.nqubits, args.nlayers)
|
||||
ref = None
|
||||
t_ref = None
|
||||
|
||||
if args.skip_jit and os.path.exists(cache_path):
|
||||
sv_ref = np.load(cache_path)
|
||||
print(f"\n[qibojit] loaded from cache: {cache_path}")
|
||||
else:
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
sv_ref, t_ref = run_qibojit(circuit_ref)
|
||||
np.save(cache_path, sv_ref)
|
||||
print(f"\n[qibojit] time={t_ref:.4f}s (saved to {cache_path})")
|
||||
if not args.no_compare:
|
||||
cache_path = jit_cache_path(args.circuit, args.nqubits, args.nlayers)
|
||||
if args.nshots:
|
||||
# frequency mode: run qibojit with same nshots
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers, add_measurements=True)
|
||||
qibo.set_backend("qibojit", platform="numba")
|
||||
t0 = time.time()
|
||||
result_ref = circuit_ref(nshots=args.nshots)
|
||||
t_ref = time.time() - t0
|
||||
ref = dict(result_ref.frequencies())
|
||||
print(f"\n[qibojit] time={t_ref:.4f}s")
|
||||
elif args.skip_jit and os.path.exists(cache_path):
|
||||
ref = np.load(cache_path)
|
||||
print(f"\n[qibojit] loaded from cache: {cache_path}")
|
||||
else:
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
ref, t_ref = run_qibojit(circuit_ref)
|
||||
np.save(cache_path, ref)
|
||||
print(f"\n[qibojit] time={t_ref:.4f}s (saved to {cache_path})")
|
||||
|
||||
np.random.seed(42)
|
||||
circuit_mps = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
label = "quimb PermMPS" if args.permmps else "quimb MPS"
|
||||
try:
|
||||
sv_mps, t_mps = run_quimb_mps(circuit_mps, args.max_bond, args.svd_cutoff, args.optimizer)
|
||||
fidelity, l2_err = compare(sv_ref, sv_mps)
|
||||
print(f"[quimb MPS] time={t_mps:.4f}s")
|
||||
print(f"\nFidelity : {fidelity:.8f} (1=perfect)")
|
||||
print(f"L2 error : {l2_err:.2e}")
|
||||
if t_ref is not None and t_mps > 0:
|
||||
print(f"Speedup : {t_ref/t_mps:.2f}x")
|
||||
if args.permmps:
|
||||
out, t_mps = run_quimb_permmps(circuit_mps, args.max_bond, args.svd_cutoff, args.nshots)
|
||||
t_dense = 0.0
|
||||
else:
|
||||
out, t_mps, t_dense = run_quimb_mps(circuit_mps, args.max_bond, args.svd_cutoff, args.optimizer, args.nshots)
|
||||
print(f"[{label}] MPS sim={t_mps:.4f}s to_dense={t_dense:.4f}s total={t_mps+t_dense:.4f}s")
|
||||
if not args.no_compare:
|
||||
if args.nshots:
|
||||
tvd = compare_frequencies(ref, out, args.nshots)
|
||||
print(f"\nTVD : {tvd:.6f} (0=perfect)")
|
||||
else:
|
||||
fidelity, l2_err = compare_statevector(ref, out)
|
||||
print(f"\nFidelity : {fidelity:.8f} (1=perfect)")
|
||||
print(f"L2 error : {l2_err:.2e}")
|
||||
if t_ref is not None and t_mps > 0:
|
||||
print(f"Speedup : {t_ref/t_mps:.2f}x")
|
||||
except Exception as e:
|
||||
print(f"[quimb MPS] FAILED: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user