Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
460 lines
13 KiB
Python
460 lines
13 KiB
Python
"""Benchmark: qibotn/quimb generic TN — single-process torch profiling version."""
|
||
|
||
import os
|
||
import pickle
|
||
import time
|
||
import argparse
|
||
import numpy as np
|
||
import cotengra as ctg
|
||
import qibo
|
||
from qibo import Circuit, gates
|
||
|
||
|
||
def make_circuit(circuit_type, nqubits, nlayers=1):
    """Build one of the benchmark circuits on ``nqubits`` qubits.

    ``circuit_type`` selects the circuit family:

    * ``"qft"``: the ``qibo.models.QFT`` circuit (``nlayers`` is not used).
    * ``"variational"``: per layer, an RY with a random angle on every
      qubit, then CZ gates on neighbouring pairs; the first paired qubit
      alternates between 0 and 1 from one layer to the next.
    * ``"ghz"``: H on qubit 0 followed by a CNOT chain (``nlayers`` is
      not used).
    * ``"brickwork"``: H on every qubit, then per layer a CNOT followed by
      a random-angle RZ on each qubit of alternating neighbouring pairs.

    Random angles come from NumPy's global RNG, so seeding it before the
    call makes the generated circuit reproducible.

    Raises:
        ValueError: if ``circuit_type`` is none of the names above.
    """
    circ = Circuit(nqubits)
    two_pi = 2 * np.pi

    if circuit_type == "qft":
        from qibo.models import QFT

        return QFT(nqubits)

    if circuit_type == "variational":
        for depth in range(nlayers):
            for wire in range(nqubits):
                circ.add(gates.RY(wire, theta=np.random.uniform(0, two_pi)))

            # Even layers pair (0,1),(2,3)...; odd layers pair (1,2),(3,4)...
            for left in range(depth % 2, nqubits - 1, 2):
                circ.add(gates.CZ(left, left + 1))
        return circ

    if circuit_type == "ghz":
        circ.add(gates.H(0))
        for control in range(nqubits - 1):
            circ.add(gates.CNOT(control, control + 1))
        return circ

    if circuit_type == "brickwork":
        for wire in range(nqubits):
            circ.add(gates.H(wire))

        for depth in range(nlayers):
            for left in range(depth % 2, nqubits - 1, 2):
                right = left + 1
                circ.add(gates.CNOT(left, right))
                # Draw order (left qubit first, then right) is kept so a
                # seeded RNG yields the same angles as before.
                circ.add(gates.RZ(left, theta=np.random.uniform(0, two_pi)))
                circ.add(gates.RZ(right, theta=np.random.uniform(0, two_pi)))
        return circ

    raise ValueError(f"Unknown circuit: {circuit_type}")
|
||
|
||
|
||
def make_z_observable(nqubits):
    """Describe the observable Z on qubit 0 as three parallel lists.

    Returns a tuple ``(operators, sites, coeffs)``: the operator names, the
    qubit indices each operator acts on, and the coefficients. Only a single
    term is produced so the benchmark performs one contraction. ``nqubits``
    is accepted but not used.
    """
    operator_names = ["z"]
    operator_sites = [(0,)]
    coefficients = [1.0]
    return operator_names, operator_sites, coefficients
|
||
|
||
|
||
def export_profiler_outputs(prof, trace_path):
    """Export a Chrome trace and a text summary table from a profiler.

    Args:
        prof: Profiler object exposing ``export_chrome_trace(path)`` and
            ``key_averages().table(...)``.
        trace_path: Destination file for the Chrome trace.

    The table is written next to the trace. A trailing ``.json`` extension
    is swapped for ``.txt``; for any other file name ``.txt`` is appended,
    so the table never overwrites the trace itself and directory names that
    happen to contain ``.json`` are left alone.
    """
    prof.export_chrome_trace(trace_path)

    # Only the final extension is considered; str.replace would also rewrite
    # ".json" inside directory components and do nothing when it is absent.
    root, ext = os.path.splitext(trace_path)
    table_path = f"{root}.txt" if ext == ".json" else f"{trace_path}.txt"

    with open(table_path, "w", encoding="utf-8") as f:
        f.write(
            prof.key_averages().table(
                sort_by="self_cpu_time_total",
                row_limit=200,
            )
        )

    print(f" [torch profiler trace] {trace_path}")
    print(f" [torch profiler table] {table_path}")
|
||
|
||
|
||
def run_quimb_tn(
    circuit,
    nqubits,
    num_slices,
    load_path=None,
    save_path=None,
):
    """Mode: expval - compute <Z_0> via quimb's ``local_expectation``.

    Args:
        circuit: qibo circuit to simulate.
        nqubits: Forwarded to ``make_z_observable``.
        num_slices: Passed as ``target_slices`` to the path optimizer.
        load_path: Optional pickle file holding ``{"tree": ...}``. When
            given, the contraction tree is loaded and no search is run.
        save_path: Optional pickle file that receives a freshly searched
            tree as ``{"tree": tree}``.

    Returns:
        ``(expval, seconds)``: the real part of the expectation value as a
        float, and path-search time plus contraction time (search time is
        0 when the tree was loaded).
    """
    qibo.set_backend("qibotn", platform="quimb")
    b = qibo.get_backend()
    b.configure_tn_simulation(ansatz="tn")

    operators, sites, _ = make_z_observable(nqubits)
    ops = b._string_to_quimb_operator(operators[0])

    qc = b._qibo_circuit_to_quimb(
        circuit,
        quimb_circuit_type=b.circuit_ansatz,
        gate_opts={"max_bond": None, "cutoff": 1e-10},
    )

    if load_path:
        # SECURITY: unpickling runs arbitrary code; only load files that
        # were produced by a trusted run of this script.
        with open(load_path, "rb") as f:
            saved = pickle.load(f)
        tree = saved["tree"]
        t_search = 0.0
        print(f" [path loaded] {load_path}")
    else:
        opt = ctg.HyperOptimizer(
            methods=["kahypar", "random-greedy", "spinglass"],
            max_repeats=16,
            parallel=True,
            max_time=60,
            slicing_opts={"target_slices": num_slices},
            progbar=True,
        )

        # perf_counter is monotonic, so the interval is not distorted by
        # system clock adjustments the way time.time() can be.
        t0 = time.perf_counter()
        rehearsal = qc.local_expectation(
            ops,
            where=sites[0],
            optimize=opt,
            simplify_sequence="R",
            rehearse=True,
        )
        t_search = time.perf_counter() - t0
        tree = rehearsal["tree"]

        print(
            f" [path search] {t_search:.3f}s "
            f"flops~2^{tree.contraction_cost():.2f} "
            f"size~2^{tree.contraction_width():.2f} "
            f"slices={tree.multiplicity}"
        )

        if save_path:
            with open(save_path, "wb") as f:
                pickle.dump({"tree": tree}, f)
            print(f" [path saved] {save_path}")

    t0 = time.perf_counter()
    expval = qc.local_expectation(
        ops,
        where=sites[0],
        optimize=tree,
        simplify_sequence="R",
    )
    t_contract = time.perf_counter() - t0

    print(f" [contraction] {t_contract:.3f}s")

    return float(expval.real), t_search + t_contract
|
||
|
||
|
||
def run_quimb_tn_statevector(
    circuit,
    nqubits,
    num_slices,
    load_path=None,
    save_path=None,
    profile=False,
    profile_dir="profiles",
):
    """Mode: statevector - contract the full TN to a dense vector.

    Runs in a single process and optionally records a torch CPU profile of
    the contraction.

    Args:
        circuit: qibo circuit to simulate.
        nqubits: Accepted for signature parity with the other modes; the
            trace file name uses ``circuit.nqubits``.
        num_slices: Passed as ``target_slices`` to the path optimizer.
        load_path: Optional pickle file holding ``{"tree": ...}``. When
            given, the contraction tree is loaded and no search is run.
        save_path: Optional pickle file that receives a freshly searched
            tree as ``{"tree": tree}``.
        profile: When true, wrap the contraction in ``torch.profiler`` and
            export a Chrome trace plus a text table.
        profile_dir: Directory for profiler output; created only when
            ``profile`` is true.

    Returns:
        ``(statevector, seconds)``: the flattened statevector as a NumPy
        array, and path-search time plus contraction time. In both the
        profiled and unprofiled case the contraction time covers only the
        ``to_dense`` call, not the conversion to NumPy.
    """
    qibo.set_backend("qibotn", platform="quimb")
    b = qibo.get_backend()
    b.configure_tn_simulation(ansatz="tn")

    import torch

    def _as_numpy(array):
        """Return ``array`` as a NumPy array, unwrapping torch tensors."""
        if isinstance(array, torch.Tensor):
            return array.detach().cpu().numpy()
        return np.asarray(array)

    qc = b._qibo_circuit_to_quimb(
        circuit,
        quimb_circuit_type=b.circuit_ansatz,
        gate_opts={"max_bond": None, "cutoff": 1e-10},
    )

    # Make quimb produce torch tensors so that torch.profiler can capture
    # the aten ops.
    qc.to_backend = torch.from_numpy

    if load_path:
        # SECURITY: unpickling runs arbitrary code; only load files that
        # were produced by a trusted run of this script.
        with open(load_path, "rb") as f:
            saved = pickle.load(f)
        tree = saved["tree"]
        t_search = 0.0
        print(f" [path loaded] {load_path}")
    else:
        # NOTE(review): parallel=48 is a fixed worker count; confirm it
        # suits the machines this benchmark is run on.
        opt = ctg.HyperOptimizer(
            methods=["kahypar", "random-greedy", "spinglass"],
            max_repeats=500,
            parallel=48,
            max_time=100,
            minimize="size",
            slicing_opts={"target_slices": num_slices},
            progbar=True,
        )

        # perf_counter is monotonic, so the interval is not distorted by
        # system clock adjustments the way time.time() can be.
        t0 = time.perf_counter()
        rehearsal = qc.to_dense(optimize=opt, rehearse=True)
        t_search = time.perf_counter() - t0
        tree = rehearsal["tree"]

        print(
            f" [path search] {t_search:.3f}s "
            f"flops~2^{tree.contraction_cost():.2f} "
            f"size~2^{tree.contraction_width():.2f} "
            f"slices={tree.multiplicity}"
        )

        if save_path:
            with open(save_path, "wb") as f:
                pickle.dump({"tree": tree}, f)
            print(f" [path saved] {save_path}")

    if profile:
        from torch.profiler import profile as torch_profile
        from torch.profiler import ProfilerActivity, record_function

        # Only create the output directory when a trace will be written.
        os.makedirs(profile_dir, exist_ok=True)

        # time.time() is intentional here: it is a wall-clock stamp that
        # keeps trace file names unique, not an interval measurement.
        trace_path = os.path.join(
            profile_dir,
            (
                f"trace_statevector_"
                f"{circuit.nqubits}q_"
                f"slices{tree.multiplicity}_"
                f"{int(time.time())}.json"
            ),
        )

        with torch_profile(
            activities=[ProfilerActivity.CPU],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with record_function("qibotn_to_dense_contraction"):
                t0 = time.perf_counter()
                sv = qc.to_dense(optimize=tree).reshape(-1)
                t_contract = time.perf_counter() - t0

            with record_function("torch_to_numpy_view_or_copy"):
                sv_tn = _as_numpy(sv)

        export_profiler_outputs(prof, trace_path)
    else:
        t0 = time.perf_counter()
        sv = qc.to_dense(optimize=tree).reshape(-1)
        t_contract = time.perf_counter() - t0

        sv_tn = _as_numpy(sv)

    print(f" [contraction] {t_contract:.3f}s")

    return sv_tn, t_search + t_contract
|
||
|
||
|
||
def run_quimb_tn_samples(circuit, nshots=1024):
    """Mode: samples - sample from the circuit's output distribution.

    Args:
        circuit: qibo circuit to execute on the qibotn/quimb backend.
        nshots: Number of shots passed to ``execute_circuit``.

    Returns:
        ``(result, seconds)``: the backend's result object and the time
        spent in ``execute_circuit``.
    """
    qibo.set_backend("qibotn", platform="quimb")
    b = qibo.get_backend()
    b.configure_tn_simulation(ansatz="tn")

    # perf_counter is monotonic, so the interval is not distorted by
    # system clock adjustments the way time.time() can be.
    t0 = time.perf_counter()
    result = b.execute_circuit(circuit, nshots=nshots)
    t_total = time.perf_counter() - t0

    print(f" [sampling] {t_total:.3f}s nshots={nshots}")

    # The frequency summary is best-effort: a result object without usable
    # frequencies must not fail the benchmark, but the reason is reported
    # instead of being silently discarded.
    try:
        freqs = result.frequencies()
        # Sort by count so the printed entries really are the most
        # frequent states, not just the first five in iteration order.
        top = sorted(freqs.items(), key=lambda kv: kv[1], reverse=True)[:5]
    except Exception as exc:
        print(f" [sampling] frequencies unavailable: {exc!r}")
    else:
        print(f" top states: {dict(top)}")

    return result, t_total
|
||
|
||
|
||
def qibojit_expval(circuit, nqubits):
    """Compute <Z_0> from the qibojit (numba) statevector.

    Args:
        circuit: qibo circuit; it is executed by calling it.
        nqubits: Number of qubits, used to locate qubit 0 in the basis
            index.

    Returns:
        ``(expval, seconds)``: the expectation value as a float and the
        time spent executing the circuit (post-processing is not timed).
    """
    qibo.set_backend("qibojit", platform="numba")

    # perf_counter is monotonic, so the interval is not distorted by
    # system clock adjustments the way time.time() can be.
    t0 = time.perf_counter()
    result = circuit()
    elapsed = time.perf_counter() - t0

    # asarray/ravel avoid the two full copies that array(...).flatten()
    # makes; the state is only read below.
    sv = np.asarray(result.state(), dtype=complex).ravel()
    probs = np.abs(sv) ** 2

    # Bit (nqubits - 1) of the basis index is taken as the value of qubit 0;
    # it contributes +1 when 0 and -1 when 1.
    bits = (np.arange(len(probs)) >> (nqubits - 1)) & 1
    expval = float(np.dot(probs, 1 - 2 * bits))

    return expval, elapsed
|
||
|
||
|
||
def run_qibojit(circuit):
    """Compute the full statevector with the qibojit (numba) backend.

    Args:
        circuit: qibo circuit; it is executed by calling it.

    Returns:
        ``(statevector, seconds)``: a flat complex NumPy array owned by the
        caller, and the time spent executing the circuit (extracting the
        state is not timed).
    """
    qibo.set_backend("qibojit", platform="numba")

    # perf_counter is monotonic, so the interval is not distorted by
    # system clock adjustments the way time.time() can be.
    t0 = time.perf_counter()
    result = circuit()
    elapsed = time.perf_counter() - t0

    # np.array already makes an independent copy; reshape(-1) flattens it
    # without the second copy that flatten() would add.
    sv = np.array(result.state(), dtype=complex).reshape(-1)

    return sv, elapsed
|
||
|
||
|
||
def _build_arg_parser():
    """Create the command-line parser for the benchmark script."""
    cli = argparse.ArgumentParser()

    cli.add_argument("--nqubits", type=int, default=10)
    cli.add_argument(
        "--circuit",
        type=str,
        default="qft",
        choices=["qft", "variational", "ghz", "brickwork"],
    )
    cli.add_argument("--nlayers", type=int, default=3)
    cli.add_argument("--num-slices", type=int, default=1)
    cli.add_argument("--nshots", type=int, default=1024)
    cli.add_argument(
        "--mode",
        type=str,
        default="statevector",
        choices=["expval", "statevector", "samples"],
        help="expval: local_expectation; statevector: to_dense; samples: sampling",
    )
    cli.add_argument(
        "--no-compare", action="store_true", help="Skip qibojit reference run"
    )
    cli.add_argument(
        "--save-path",
        type=str,
        default=None,
        help="Save contraction tree to a pickle file",
    )
    cli.add_argument(
        "--load-path",
        type=str,
        default=None,
        help="Load contraction tree from a pickle file and skip path search",
    )
    cli.add_argument(
        "--profile",
        action="store_true",
        help="Enable torch profiler for statevector contraction stage",
    )
    cli.add_argument(
        "--profile-dir",
        type=str,
        default="profiles",
        help="Directory to save torch profiler traces",
    )
    cli.add_argument(
        "--save-statevector",
        action="store_true",
        help="Save TN statevector to data/sv_tn_*.npy",
    )
    return cli


def main():
    """Run the selected tensor-network benchmark and compare with qibojit.

    The circuit is built after seeding NumPy's RNG with 42 and run in the
    requested mode. Unless ``--no-compare`` is given (or the mode is
    ``samples``), the same circuit is rebuilt with the same seed, run on
    qibojit, and the difference and timing ratio are printed.
    """
    args = _build_arg_parser().parse_args()

    print(
        f"Circuit: {args.circuit}, "
        f"nqubits={args.nqubits}, "
        f"nlayers={args.nlayers}, "
        f"mode={args.mode}, "
        f"profile={args.profile}"
    )

    np.random.seed(42)
    circuit_tn = make_circuit(args.circuit, args.nqubits, args.nlayers)

    skip_reference = args.no_compare

    try:
        if args.mode == "statevector":
            sv_tn, t_tn = run_quimb_tn_statevector(
                circuit_tn,
                args.nqubits,
                args.num_slices,
                load_path=args.load_path,
                save_path=args.save_path,
                profile=args.profile,
                profile_dir=args.profile_dir,
            )

            print(
                f"\n[quimb TN] time={t_tn:.4f}s "
                f"statevector shape={sv_tn.shape}"
            )

            if args.save_statevector:
                os.makedirs("data", exist_ok=True)
                out_path = f"data/sv_tn_{args.circuit}{args.nqubits}.npy"
                np.save(out_path, sv_tn)
                print(f"[saved statevector] {out_path}")

        elif args.mode == "expval":
            expval_tn, t_tn = run_quimb_tn(
                circuit_tn,
                args.nqubits,
                args.num_slices,
                load_path=args.load_path,
                save_path=args.save_path,
            )

            print(f"\n[quimb TN] time={t_tn:.4f}s <Z_0>={expval_tn:.8f}")

        else:
            _, t_tn = run_quimb_tn_samples(circuit_tn, nshots=args.nshots)

            print(f"\n[quimb TN] time={t_tn:.4f}s")
            # Sampling has no reference comparison.
            skip_reference = True

    except Exception as e:
        print(f"[quimb TN] FAILED: {e}")
        raise

    if skip_reference:
        return

    if args.mode != "statevector":
        # Same seed as above so the reference circuit has identical angles.
        np.random.seed(42)
        circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)

        expval_ref, t_ref = qibojit_expval(circuit_ref, args.nqubits)

        print(f"[qibojit] time={t_ref:.4f}s <Z_0>={expval_ref:.8f}")
        print(f"\nDiff : {abs(expval_tn - expval_ref):.2e}")

        if t_tn > 0:
            print(f"Speedup : {t_ref / t_tn:.2f}x")
        return

    if sv_tn is None:
        return

    # Same seed as above so the reference circuit has identical angles.
    np.random.seed(42)
    circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)

    sv_ref, t_ref = run_qibojit(circuit_ref)

    overlap = np.dot(sv_ref.conj(), sv_tn)
    fid = abs(overlap) ** 2
    l2_err = np.linalg.norm(sv_ref - sv_tn)

    print(f"[qibojit] time={t_ref:.4f}s")
    print(f"Fidelity : {fid:.8f} (1=perfect)")
    print(f"L2 error : {l2_err:.2e}")

    if t_tn > 0:
        print(f"Speedup : {t_ref / t_tn:.2f}x")
|
||
|
||
|
||
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()