Files
qibotn/bench_profile.py
jaunatisblue dd222587b7
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
tn脚本更新
2026-05-03 18:54:05 +08:00

460 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Benchmark: qibotn/quimb generic TN — single-process torch profiling version."""
import argparse
import math
import os
import pickle
import time

import cotengra as ctg
import numpy as np
import qibo
from qibo import Circuit, gates
def make_circuit(circuit_type, nqubits, nlayers=1):
    """Build one of the benchmark circuits.

    Args:
        circuit_type: One of "qft", "variational", "ghz" or "brickwork".
        nqubits: Number of qubits of the circuit.
        nlayers: Number of layers; only read by the "variational" and
            "brickwork" circuits.

    Returns:
        The assembled circuit. For "qft" this is whatever
        ``qibo.models.QFT(nqubits)`` returns.

    Raises:
        ValueError: If ``circuit_type`` is not one of the names above.
    """

    def random_angle():
        # Uses NumPy's global RNG, so seeding np.random beforehand makes the
        # generated circuit reproducible.
        return np.random.uniform(0, 2 * np.pi)

    circuit = Circuit(nqubits)

    if circuit_type == "qft":
        from qibo.models import QFT

        return QFT(nqubits)

    if circuit_type == "variational":
        for layer in range(nlayers):
            # One RY rotation per qubit ...
            for qubit in range(nqubits):
                circuit.add(gates.RY(qubit, theta=random_angle()))
            # ... followed by CZ gates on alternating neighbouring pairs.
            for left in range(layer % 2, nqubits - 1, 2):
                circuit.add(gates.CZ(left, left + 1))
        return circuit

    if circuit_type == "ghz":
        circuit.add(gates.H(0))
        for control in range(nqubits - 1):
            circuit.add(gates.CNOT(control, control + 1))
        return circuit

    if circuit_type == "brickwork":
        for qubit in range(nqubits):
            circuit.add(gates.H(qubit))
        for layer in range(nlayers):
            # Even layers start at qubit 0, odd layers at qubit 1.
            for left in range(layer % 2, nqubits - 1, 2):
                right = left + 1
                circuit.add(gates.CNOT(left, right))
                circuit.add(gates.RZ(left, theta=random_angle()))
                circuit.add(gates.RZ(right, theta=random_angle()))
        return circuit

    raise ValueError(f"Unknown circuit: {circuit_type}")
def make_z_observable(nqubits):
    """Describe the benchmark observable: a single Z acting on qubit 0.

    ``nqubits`` is accepted for interface symmetry but is not read.

    Returns:
        A ``(operators, sites, coeffs)`` tuple of three one-element lists:
        the operator name, the tuple of qubits it acts on, and its
        coefficient.
    """
    operators = ["z"]
    sites = [(0,)]
    coeffs = [1.0]
    return operators, sites, coeffs
def export_profiler_outputs(prof, trace_path):
    """Export Chrome trace and text table.

    Args:
        prof: Profiler object providing ``export_chrome_trace(path)`` and
            ``key_averages().table(...)``.
        trace_path: Destination of the Chrome trace. The text table is
            written next to it with the extension swapped for ``.txt``.
    """
    prof.export_chrome_trace(trace_path)

    # Only swap the final extension. A plain str.replace(".json", ".txt")
    # would also rewrite ".json" inside directory names and, when the path
    # has no ".json" at all, would make the table overwrite the trace.
    root, _ext = os.path.splitext(trace_path)
    table_path = root + ".txt"
    if table_path == trace_path:
        # The trace itself already ends in ".txt"; keep both files.
        table_path = root + ".table.txt"

    summary = prof.key_averages().table(
        sort_by="self_cpu_time_total",
        row_limit=200,
    )
    with open(table_path, "w", encoding="utf-8") as f:
        f.write(summary)

    print(f" [torch profiler trace] {trace_path}")
    print(f" [torch profiler table] {table_path}")
def run_quimb_tn(
    circuit,
    nqubits,
    num_slices,
    load_path=None,
    save_path=None,
):
    """Mode: expval — compute <Z_0> via local_expectation.

    Switches qibo to the qibotn/quimb backend, converts ``circuit`` into a
    quimb circuit and evaluates the observable described by
    ``make_z_observable``. The contraction tree is either read from
    ``load_path`` or searched with a cotengra ``HyperOptimizer``.

    Args:
        circuit: qibo circuit to evaluate.
        nqubits: Forwarded to ``make_z_observable``.
        num_slices: ``target_slices`` handed to the optimizer's slicing
            options; not used when a tree is loaded.
        load_path: Optional pickle file holding ``{"tree": ...}``. When
            given, the path search is skipped and its time counts as 0.
        save_path: Optional pickle file the tree is written to.

    Returns:
        ``(expval, seconds)`` where ``seconds`` is path-search time plus
        contraction time.
    """
    qibo.set_backend("qibotn", platform="quimb")
    backend = qibo.get_backend()
    backend.configure_tn_simulation(ansatz="tn")

    operators, sites, coeffs = make_z_observable(nqubits)
    op = backend._string_to_quimb_operator(operators[0])
    qc = backend._qibo_circuit_to_quimb(
        circuit,
        quimb_circuit_type=backend.circuit_ansatz,
        gate_opts={"max_bond": None, "cutoff": 1e-10},
    )

    if load_path:
        # NOTE(security): pickle.load can execute arbitrary code contained
        # in the file. Only load trees that were produced by a trusted run.
        with open(load_path, "rb") as f:
            saved = pickle.load(f)
        tree = saved["tree"]
        t_search = 0.0
        print(f" [path loaded] {load_path}")
    else:
        opt = ctg.HyperOptimizer(
            methods=["kahypar", "random-greedy", "spinglass"],
            max_repeats=16,
            parallel=True,
            max_time=60,
            slicing_opts={"target_slices": num_slices},
            progbar=True,
        )
        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments during the run.
        t0 = time.perf_counter()
        rehearsal = qc.local_expectation(
            op,
            where=sites[0],
            optimize=opt,
            simplify_sequence="R",
            rehearse=True,
        )
        t_search = time.perf_counter() - t0
        tree = rehearsal["tree"]
        # contraction_cost() is a raw operation count by default, whereas
        # contraction_width() is already a log2 value; convert the cost so
        # that the "2^" label is correct for both numbers.
        log2_flops = math.log2(max(tree.contraction_cost(), 1))
        print(
            f" [path search] {t_search:.3f}s "
            f"flops~2^{log2_flops:.2f} "
            f"size~2^{tree.contraction_width():.2f} "
            f"slices={tree.multiplicity}"
        )

    if save_path:
        with open(save_path, "wb") as f:
            pickle.dump({"tree": tree}, f)
        print(f" [path saved] {save_path}")

    t0 = time.perf_counter()
    expval = qc.local_expectation(
        op,
        where=sites[0],
        optimize=tree,
        simplify_sequence="R",
    )
    t_contract = time.perf_counter() - t0
    print(f" [contraction] {t_contract:.3f}s")

    # Apply the observable's coefficient instead of silently dropping it
    # (it is 1.0 for the current observable, so the value is unchanged).
    return float(expval.real) * coeffs[0], t_search + t_contract
def _as_numpy_statevector(sv):
    """Return ``sv`` as a NumPy array, converting from a torch tensor if needed."""
    # Checked by module name so this helper does not need torch in scope.
    if type(sv).__module__.startswith("torch"):
        return sv.detach().cpu().numpy()
    return np.asarray(sv)


def run_quimb_tn_statevector(
    circuit,
    nqubits,
    num_slices,
    load_path=None,
    save_path=None,
    profile=False,
    profile_dir="profiles",
):
    """Mode: statevector — contract full TN to dense vector, single process.

    Args:
        circuit: qibo circuit to simulate.
        nqubits: Not read by this function; kept for a signature parallel
            to ``run_quimb_tn``.
        num_slices: ``target_slices`` handed to the optimizer's slicing
            options; not used when a tree is loaded.
        load_path: Optional pickle file holding ``{"tree": ...}``. When
            given, the path search is skipped and its time counts as 0.
        save_path: Optional pickle file the tree is written to.
        profile: When true, run the contraction under ``torch.profiler``
            (CPU activity) and export a trace plus a text table.
        profile_dir: Directory for the profiler outputs; created only when
            ``profile`` is true.

    Returns:
        ``(statevector, seconds)`` where ``statevector`` is a flat NumPy
        array and ``seconds`` is path-search time plus contraction time.
        With profiling enabled the contraction time also covers the
        conversion to NumPy and the profiler's own overhead.
    """
    qibo.set_backend("qibotn", platform="quimb")
    backend = qibo.get_backend()
    backend.configure_tn_simulation(ansatz="tn")

    import torch

    qc = backend._qibo_circuit_to_quimb(
        circuit,
        quimb_circuit_type=backend.circuit_ansatz,
        gate_opts={"max_bond": None, "cutoff": 1e-10},
    )
    # Have quimb produce torch tensors so that torch.profiler can capture
    # the aten ops.
    qc.to_backend = torch.from_numpy

    if load_path:
        # NOTE(security): pickle.load can execute arbitrary code contained
        # in the file. Only load trees that were produced by a trusted run.
        with open(load_path, "rb") as f:
            saved = pickle.load(f)
        tree = saved["tree"]
        t_search = 0.0
        print(f" [path loaded] {load_path}")
    else:
        opt = ctg.HyperOptimizer(
            methods=["kahypar", "random-greedy", "spinglass"],
            max_repeats=500,
            parallel=48,  # NOTE(review): hard-coded worker count — confirm it suits the host
            max_time=100,
            minimize="size",
            slicing_opts={"target_slices": num_slices},
            progbar=True,
        )
        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments during the run.
        t0 = time.perf_counter()
        rehearsal = qc.to_dense(optimize=opt, rehearse=True)
        t_search = time.perf_counter() - t0
        tree = rehearsal["tree"]
        # contraction_cost() is a raw operation count by default, whereas
        # contraction_width() is already a log2 value; convert the cost so
        # that the "2^" label is correct for both numbers.
        log2_flops = math.log2(max(tree.contraction_cost(), 1))
        print(
            f" [path search] {t_search:.3f}s "
            f"flops~2^{log2_flops:.2f} "
            f"size~2^{tree.contraction_width():.2f} "
            f"slices={tree.multiplicity}"
        )

    if save_path:
        with open(save_path, "wb") as f:
            pickle.dump({"tree": tree}, f)
        print(f" [path saved] {save_path}")

    if profile:
        from torch.profiler import profile as torch_profile
        from torch.profiler import ProfilerActivity, record_function

        # Create the output directory only when something will be written
        # to it, so plain runs do not leave an empty directory behind.
        os.makedirs(profile_dir, exist_ok=True)
        trace_path = os.path.join(
            profile_dir,
            (
                f"trace_statevector_"
                f"{circuit.nqubits}q_"
                f"slices{tree.multiplicity}_"
                f"{int(time.time())}.json"  # wall-clock stamp keeps file names unique per run
            ),
        )
        t0 = time.perf_counter()
        with torch_profile(
            activities=[ProfilerActivity.CPU],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
        ) as prof:
            with record_function("qibotn_to_dense_contraction"):
                sv = qc.to_dense(optimize=tree).reshape(-1)
            with record_function("torch_to_numpy_view_or_copy"):
                sv_tn = _as_numpy_statevector(sv)
        t_contract = time.perf_counter() - t0
        export_profiler_outputs(prof, trace_path)
    else:
        t0 = time.perf_counter()
        sv = qc.to_dense(optimize=tree).reshape(-1)
        t_contract = time.perf_counter() - t0
        # Conversion happens after the timer stops in the unprofiled path.
        sv_tn = _as_numpy_statevector(sv)

    print(f" [contraction] {t_contract:.3f}s")
    return sv_tn, t_search + t_contract
def run_quimb_tn_samples(circuit, nshots=1024):
    """Mode: samples — sample from circuit output distribution.

    Args:
        circuit: qibo circuit to execute on the qibotn/quimb backend.
        nshots: Number of shots passed to ``execute_circuit``.

    Returns:
        ``(result, seconds)`` where ``result`` is whatever the backend's
        ``execute_circuit`` returns and ``seconds`` is the execution time.
    """
    qibo.set_backend("qibotn", platform="quimb")
    backend = qibo.get_backend()
    backend.configure_tn_simulation(ansatz="tn")

    # perf_counter is monotonic, so the duration is immune to wall-clock
    # adjustments during the run.
    t0 = time.perf_counter()
    result = backend.execute_circuit(circuit, nshots=nshots)
    t_total = time.perf_counter() - t0
    print(f" [sampling] {t_total:.3f}s nshots={nshots}")

    # The summary is best-effort: a result without usable frequencies must
    # not fail the benchmark, but the reason is reported instead of being
    # swallowed silently.
    try:
        freqs = result.frequencies()
        # Shows the first five entries in the order frequencies() yields them.
        print(f" top states: {dict(list(freqs.items())[:5])}")
    except Exception as exc:
        print(f" [sampling] frequency summary unavailable: {exc}")

    return result, t_total
def qibojit_expval(circuit, nqubits):
    """Reference value of <Z_0> computed from the qibojit statevector.

    Args:
        circuit: qibo circuit; executed by calling it on the qibojit/numba
            backend.
        nqubits: Number of qubits, used to locate qubit 0's bit in each
            basis-state index.

    Returns:
        ``(expval, seconds)`` where ``seconds`` covers only the circuit
        execution, not the post-processing.
    """
    qibo.set_backend("qibojit", platform="numba")

    start = time.time()
    outcome = circuit()
    elapsed = time.time() - start

    amplitudes = np.array(outcome.state(), dtype=complex).flatten()
    probabilities = np.abs(amplitudes) ** 2

    # Bit (nqubits - 1) of the basis index is taken as qubit 0's value
    # (assumes qubit 0 is the most significant bit — TODO confirm).
    basis_index = np.arange(len(probabilities))
    qubit0_bit = (basis_index >> (nqubits - 1)) & 1
    # Z eigenvalue is +1 for bit 0 and -1 for bit 1.
    z_eigenvalues = 1 - 2 * qubit0_bit

    return float(np.dot(probabilities, z_eigenvalues)), elapsed
def run_qibojit(circuit):
    """Reference statevector computed with the qibojit/numba backend.

    Args:
        circuit: qibo circuit; executed by calling it.

    Returns:
        ``(statevector, seconds)`` where ``statevector`` is a flat complex
        NumPy array and ``seconds`` covers only the circuit execution.
    """
    qibo.set_backend("qibojit", platform="numba")

    start = time.time()
    outcome = circuit()
    elapsed = time.time() - start

    state = np.array(outcome.state(), dtype=complex)
    return state.flatten(), elapsed
def _build_arg_parser():
    """Create the command-line parser for the benchmark script."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--nqubits", type=int, default=10)
    cli.add_argument(
        "--circuit", type=str, default="qft",
        choices=["qft", "variational", "ghz", "brickwork"],
    )
    cli.add_argument("--nlayers", type=int, default=3)
    cli.add_argument("--num-slices", type=int, default=1)
    cli.add_argument("--nshots", type=int, default=1024)
    cli.add_argument(
        "--mode", type=str, default="statevector",
        choices=["expval", "statevector", "samples"],
        help="expval: local_expectation; statevector: to_dense; samples: sampling",
    )
    cli.add_argument(
        "--no-compare", action="store_true",
        help="Skip qibojit reference run",
    )
    cli.add_argument(
        "--save-path", type=str, default=None,
        help="Save contraction tree to a pickle file",
    )
    cli.add_argument(
        "--load-path", type=str, default=None,
        help="Load contraction tree from a pickle file and skip path search",
    )
    cli.add_argument(
        "--profile", action="store_true",
        help="Enable torch profiler for statevector contraction stage",
    )
    cli.add_argument(
        "--profile-dir", type=str, default="profiles",
        help="Directory to save torch profiler traces",
    )
    cli.add_argument(
        "--save-statevector", action="store_true",
        help="Save TN statevector to data/sv_tn_*.npy",
    )
    return cli


def main():
    """Command-line entry point: run the TN benchmark and compare with qibojit.

    Builds the requested circuit, runs it in the selected mode on the
    qibotn/quimb backend and, unless disabled or in samples mode, rebuilds
    the same circuit and checks the result against qibojit.
    """
    args = _build_arg_parser().parse_args()
    mode = args.mode
    print(
        f"Circuit: {args.circuit}, "
        f"nqubits={args.nqubits}, "
        f"nlayers={args.nlayers}, "
        f"mode={mode}, "
        f"profile={args.profile}"
    )

    # Fixed seed: the reference circuit built later re-seeds with the same
    # value so both runs draw identical random angles.
    np.random.seed(42)
    circuit_tn = make_circuit(args.circuit, args.nqubits, args.nlayers)

    skip_compare = args.no_compare
    try:
        if mode == "expval":
            expval_tn, t_tn = run_quimb_tn(
                circuit_tn,
                args.nqubits,
                args.num_slices,
                load_path=args.load_path,
                save_path=args.save_path,
            )
            print(f"\n[quimb TN] time={t_tn:.4f}s <Z_0>={expval_tn:.8f}")
        elif mode == "statevector":
            sv_tn, t_tn = run_quimb_tn_statevector(
                circuit_tn,
                args.nqubits,
                args.num_slices,
                load_path=args.load_path,
                save_path=args.save_path,
                profile=args.profile,
                profile_dir=args.profile_dir,
            )
            print(
                f"\n[quimb TN] time={t_tn:.4f}s "
                f"statevector shape={sv_tn.shape}"
            )
            if args.save_statevector:
                os.makedirs("data", exist_ok=True)
                out_path = f"data/sv_tn_{args.circuit}{args.nqubits}.npy"
                np.save(out_path, sv_tn)
                print(f"[saved statevector] {out_path}")
        else:
            _, t_tn = run_quimb_tn_samples(circuit_tn, nshots=args.nshots)
            print(f"\n[quimb TN] time={t_tn:.4f}s")
            # Samples mode has no reference comparison.
            skip_compare = True
    except Exception as e:
        print(f"[quimb TN] FAILED: {e}")
        raise

    if skip_compare:
        return

    if mode != "statevector":
        np.random.seed(42)
        circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
        expval_ref, t_ref = qibojit_expval(circuit_ref, args.nqubits)
        print(f"[qibojit] time={t_ref:.4f}s <Z_0>={expval_ref:.8f}")
        print(f"\nDiff : {abs(expval_tn - expval_ref):.2e}")
        if t_tn > 0:
            print(f"Speedup : {t_ref / t_tn:.2f}x")
    elif sv_tn is not None:
        np.random.seed(42)
        circuit_ref = make_circuit(args.circuit, args.nqubits, args.nlayers)
        sv_ref, t_ref = run_qibojit(circuit_ref)
        overlap = np.dot(sv_ref.conj(), sv_tn)
        fid = abs(overlap) ** 2
        l2_err = np.linalg.norm(sv_ref - sv_tn)
        print(f"[qibojit] time={t_ref:.4f}s")
        print(f"Fidelity : {fid:.8f} (1=perfect)")
        print(f"L2 error : {l2_err:.2e}")
        if t_tn > 0:
            print(f"Speedup : {t_ref / t_tn:.2f}x")
# Run the benchmark CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()