diff --git a/benchmark_tn.py b/benchmark_tn.py new file mode 100644 index 0000000..4daa6fb --- /dev/null +++ b/benchmark_tn.py @@ -0,0 +1,113 @@ +"""Benchmark: qibotn/quimb generic TN — expectation values.""" +import time +import argparse +import numpy as np +import cotengra as ctg +import qibo +from qibo import Circuit, gates + +def make_circuit(circuit_type, nqubits, nlayers=1): + c = Circuit(nqubits) + if circuit_type == "qft": + from qibo.models import QFT + return QFT(nqubits) + elif circuit_type == "variational": + for layer in range(nlayers): + for q in range(nqubits): + c.add(gates.RY(q, theta=np.random.uniform(0, 2 * np.pi))) + offset = layer % 2 + for q in range(offset, nqubits - 1, 2): + c.add(gates.CZ(q, q + 1)) + elif circuit_type == "ghz": + c.add(gates.H(0)) + for q in range(nqubits - 1): + c.add(gates.CNOT(q, q + 1)) + else: + raise ValueError(f"Unknown circuit: {circuit_type}") + return c + + + +def make_z_observable(nqubits): + """Z on qubit 0 only — single contraction for benchmarking""" + return ["z"], [(0,)], [1.0] + + +def run_quimb_tn(circuit, nqubits): + qibo.set_backend("qibotn", platform="quimb") + b = qibo.get_backend() + b.configure_tn_simulation(ansatz="tn") # generic TN, no MPS + #if max_time is not None: + # opt = ctg.HyperOptimizer(max_repeats=128, max_time=max_time, minimize=minimize, parallel=True) + #else: + opt = ctg.HyperOptimizer( + max_repeats=16, + parallel=True, + slicing_opts={'target_size': 2**24}, # 限制单个张量最大 2^28 个元素 + progbar=True + ) + + b.contractions_optimizer = opt + operators, sites, coeffs = make_z_observable(nqubits) + t0 = time.time() + expval = b.exp_value_observable_symbolic(circuit, operators, sites, coeffs, nqubits) + elapsed = time.time() - t0 + return expval, elapsed + + +def qibojit_expval(circuit, nqubits): + """Compute sum_i via qibojit statevector.""" + qibo.set_backend("qibojit", platform="numba") + t0 = time.time() + result = circuit() + elapsed = time.time() - t0 + sv = np.array(result.state(), dtype=complex).flatten() + probs = np.abs(sv) ** 2 + expval = sum( + probs[idx] * (1 - 2 * ((idx >> (nqubits - 1 - i)) & 1)) + for i in range(nqubits) + for idx in range(len(probs)) + ) + return float(expval), elapsed + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--nqubits", type=int, default=10) + parser.add_argument("--circuit", type=str, default="qft", + choices=["qft", "variational", "ghz"]) + parser.add_argument("--nlayers", type=int, default=3) + parser.add_argument("--optimizer", type=str, default="auto-hq") + parser.add_argument("--max-time", type=float, default=None, + help="HyperOptimizer max search time (seconds); overrides --optimizer") + parser.add_argument("--minimize", type=str, default="flops", + choices=["flops", "size", "write"], + help="HyperOptimizer minimize target") + parser.add_argument("--no-compare", action="store_true", + help="Skip qibojit reference run") + args = parser.parse_args() + + print(f"Circuit: {args.circuit}, nqubits={args.nqubits}, nlayers={args.nlayers}") + print(f"TN config: optimizer={args.optimizer}, max_time={args.max_time}, minimize={args.minimize}") + + np.random.seed(42) + circuit_tn = make_circuit(args.circuit, args.nqubits, args.nlayers) + try: + expval_tn, t_tn = run_quimb_tn(circuit_tn, args.nqubits) + print(f"\n[quimb TN] time={t_tn:.4f}s ={expval_tn:.8f}") + except Exception as e: + print(f"[quimb TN] FAILED: {e}") + raise + + if not args.no_compare: + np.random.seed(42) + circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers) + expval_ref, t_ref = qibojit_expval(circuit_ref, args.nqubits) + print(f"[qibojit] time={t_ref:.4f}s ={expval_ref:.8f}") + print(f"\nDiff : {abs(expval_tn - expval_ref):.2e}") + if t_tn > 0: + print(f"Speedup : {t_ref/t_tn:.2f}x") + + +if __name__ == "__main__": + main()