tn脚本更新
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
126
benchmark_qmatchatea.py
Normal file
126
benchmark_qmatchatea.py
Normal file
@@ -0,0 +1,126 @@
|
||||
"""Benchmark: qibojit (reference) vs qibotn/qmatchatea MPS."""
|
||||
import time
|
||||
import argparse
|
||||
import os
|
||||
import numpy as np
|
||||
import qibo
|
||||
from qibo import Circuit, gates
|
||||
from qibo.backends import construct_backend
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
def make_circuit(circuit_type, nqubits, nlayers=1):
|
||||
c = Circuit(nqubits)
|
||||
if circuit_type == "qft":
|
||||
from qibo.models import QFT
|
||||
return QFT(nqubits)
|
||||
elif circuit_type == "variational":
|
||||
for layer in range(nlayers):
|
||||
for q in range(nqubits):
|
||||
c.add(gates.RY(q, theta=np.random.uniform(0, 2 * np.pi)))
|
||||
offset = layer % 2
|
||||
for q in range(offset, nqubits - 1, 2):
|
||||
c.add(gates.CZ(q, q + 1))
|
||||
elif circuit_type == "ghz":
|
||||
c.add(gates.H(0))
|
||||
for q in range(nqubits - 1):
|
||||
c.add(gates.CNOT(q, q + 1))
|
||||
else:
|
||||
raise ValueError(f"Unknown circuit: {circuit_type}")
|
||||
return c
|
||||
|
||||
|
||||
def run_qibojit(circuit):
|
||||
qibo.set_backend("qibojit", platform="numba")
|
||||
t0 = time.time()
|
||||
result = circuit()
|
||||
elapsed = time.time() - t0
|
||||
return result.state(), elapsed
|
||||
|
||||
|
||||
def run_qmatchatea(circuit, max_bond, cut_ratio):
|
||||
import qmatchatea, qtealeaves.observables
|
||||
from qibo.backends import construct_backend as _cb
|
||||
b = _cb(backend="qibotn", platform="qmatchatea")
|
||||
b.configure_tn_simulation(ansatz="MPS", max_bond_dimension=max_bond, cut_ratio=cut_ratio)
|
||||
|
||||
qk_circuit = b._qibocirc_to_qiskitcirc(circuit)
|
||||
run_qk_params = qmatchatea.preprocessing.qk_transpilation_params(False)
|
||||
observables = qtealeaves.observables.TNObservables()
|
||||
observables += qtealeaves.observables.TNState2File(name="temp", formatting="D")
|
||||
|
||||
t0 = time.time()
|
||||
results = qmatchatea.run_simulation(
|
||||
circ=qk_circuit,
|
||||
convergence_parameters=b.convergence_params,
|
||||
transpilation_parameters=run_qk_params,
|
||||
backend=b.qmatchatea_backend,
|
||||
observables=observables,
|
||||
)
|
||||
elapsed = time.time() - t0
|
||||
tn_state = results.observables.get("tn_state")
|
||||
if tn_state is None:
|
||||
results.load_state()
|
||||
tn_state = results.observables["tn_state"]
|
||||
sv_obj = tn_state.to_statevector(qiskit_order=False, max_qubit_equivalent=40)
|
||||
sv = np.array(sv_obj.elem, dtype=complex).flatten()
|
||||
return sv, elapsed
|
||||
|
||||
|
||||
def compare(sv_ref, sv_mps):
|
||||
sv_ref = np.array(sv_ref, dtype=complex).flatten()
|
||||
fidelity = abs(np.dot(sv_ref.conj(), sv_mps)) ** 2
|
||||
l2_err = np.linalg.norm(sv_ref - sv_mps)
|
||||
return fidelity, l2_err
|
||||
|
||||
|
||||
def jit_cache_path(circuit_type, nqubits, nlayers):
|
||||
os.makedirs(DATA_DIR, exist_ok=True)
|
||||
return os.path.join(DATA_DIR, f"jit_{circuit_type}_n{nqubits}_l{nlayers}.npy")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--nqubits", type=int, default=10)
|
||||
parser.add_argument("--circuit", type=str, default="ghz",
|
||||
choices=["qft", "variational", "ghz"])
|
||||
parser.add_argument("--nlayers", type=int, default=3)
|
||||
parser.add_argument("--max-bond", type=int, default=64)
|
||||
parser.add_argument("--cut-ratio", type=float, default=1e-6)
|
||||
parser.add_argument("--skip-jit", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, nlayers={args.nlayers}")
|
||||
print(f"MPS config: max_bond={args.max_bond}, cut_ratio={args.cut_ratio}")
|
||||
|
||||
cache_path = jit_cache_path(args.circuit, args.nqubits, args.nlayers)
|
||||
t_ref = None
|
||||
|
||||
if args.skip_jit and os.path.exists(cache_path):
|
||||
sv_ref = np.load(cache_path)
|
||||
print(f"\n[qibojit] loaded from cache: {cache_path}")
|
||||
else:
|
||||
np.random.seed(42)
|
||||
circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
sv_ref, t_ref = run_qibojit(circuit_ref)
|
||||
np.save(cache_path, sv_ref)
|
||||
print(f"\n[qibojit] time={t_ref:.4f}s (saved to {cache_path})")
|
||||
|
||||
np.random.seed(42)
|
||||
circuit_mps = make_circuit(args.circuit, args.nqubits, args.nlayers)
|
||||
try:
|
||||
sv_mps, t_mps = run_qmatchatea(circuit_mps, args.max_bond, args.cut_ratio)
|
||||
fidelity, l2_err = compare(sv_ref, sv_mps)
|
||||
print(f"[qmatchatea] time={t_mps:.4f}s")
|
||||
print(f"\nFidelity : {fidelity:.8f} (1=perfect)")
|
||||
print(f"L2 error : {l2_err:.2e}")
|
||||
if t_ref is not None and t_mps > 0:
|
||||
print(f"Speedup : {t_ref/t_mps:.2f}x")
|
||||
except Exception as e:
|
||||
print(f"[qmatchatea] FAILED: {e}")
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user