决赛现场脚本
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
157
tools/benchmark_qredtea_svd_controls.py
Normal file
157
tools/benchmark_qredtea_svd_controls.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
"""Benchmark qredtea/qtealeaves SVD control modes.
|
||||
|
||||
This isolates the tensor split used by MPS updates: a rank-2 tensor is split
|
||||
with singular values contracted either left or right, then reconstructed to
|
||||
measure numerical error and timing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import statistics
|
||||
import time
|
||||
|
||||
import torch
|
||||
|
||||
import qmatchatea
|
||||
from qredtea.torchapi import QteaTorchTensor
|
||||
|
||||
|
||||
def _dtype(name: str):
|
||||
return {
|
||||
"complex64": torch.complex64,
|
||||
"complex128": torch.complex128,
|
||||
"float64": torch.float64,
|
||||
"float32": torch.float32,
|
||||
}[name]
|
||||
|
||||
|
||||
def _random_matrix(shape, dtype, seed):
|
||||
gen = torch.Generator(device="cpu")
|
||||
gen.manual_seed(seed)
|
||||
if dtype.is_complex:
|
||||
real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64
|
||||
real = torch.randn(shape, dtype=real_dtype, generator=gen)
|
||||
imag = torch.randn(shape, dtype=real_dtype, generator=gen)
|
||||
return torch.complex(real, imag).to(dtype)
|
||||
return torch.randn(shape, dtype=dtype, generator=gen)
|
||||
|
||||
|
||||
def _sync():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def run_one(matrix, ctrl, max_bond, contract_singvals, repeats):
|
||||
conv = qmatchatea.QCConvergenceParameters(
|
||||
max_bond_dimension=max_bond,
|
||||
cut_ratio=0.0,
|
||||
svd_ctrl=ctrl,
|
||||
)
|
||||
qtensor = QteaTorchTensor.from_elem_array(matrix, dtype=matrix.dtype, device="cpu")
|
||||
|
||||
times = []
|
||||
rel_error = None
|
||||
kept = None
|
||||
status = "ok"
|
||||
error = ""
|
||||
|
||||
for i in range(repeats):
|
||||
gc.collect()
|
||||
_sync()
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
left, right, singvals, _ = qtensor.split_svd(
|
||||
[0],
|
||||
[1],
|
||||
contract_singvals=contract_singvals,
|
||||
conv_params=conv,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001 - benchmark should keep going
|
||||
status = "error"
|
||||
error = repr(exc)
|
||||
break
|
||||
_sync()
|
||||
times.append(time.perf_counter() - t0)
|
||||
|
||||
if i == repeats - 1:
|
||||
left_matrix = left.elem.reshape(matrix.shape[0], -1)
|
||||
right_matrix = right.elem.reshape(-1, matrix.shape[1])
|
||||
recon = left_matrix @ right_matrix
|
||||
rel_error = (
|
||||
torch.linalg.vector_norm(matrix - recon)
|
||||
/ torch.linalg.vector_norm(matrix)
|
||||
).item()
|
||||
kept = int(singvals.numel())
|
||||
|
||||
return {
|
||||
"ctrl": ctrl,
|
||||
"contract_singvals": contract_singvals,
|
||||
"status": status,
|
||||
"median_ms": float("nan") if not times else statistics.median(times) * 1000,
|
||||
"min_ms": float("nan") if not times else min(times) * 1000,
|
||||
"rel_error": rel_error,
|
||||
"kept": kept,
|
||||
"error": error,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512"))
|
||||
parser.add_argument("--max-bond", type=int, default=128)
|
||||
parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128")
|
||||
parser.add_argument("--threads", type=int, default=8)
|
||||
parser.add_argument("--repeats", type=int, default=3)
|
||||
parser.add_argument(
|
||||
"--controls",
|
||||
nargs="+",
|
||||
default=("A", "D", "V", "R", "E", "E!", "X", "X!"),
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
torch.set_num_threads(args.threads)
|
||||
dtype = _dtype(args.dtype)
|
||||
|
||||
print(
|
||||
"svd_benchmark "
|
||||
f"dtype={args.dtype} threads={torch.get_num_threads()} "
|
||||
f"max_bond={args.max_bond} repeats={args.repeats}",
|
||||
flush=True,
|
||||
)
|
||||
print(
|
||||
"columns shape contract ctrl status median_ms min_ms kept rel_error error",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
for shape_text in args.shapes:
|
||||
m_text, n_text = shape_text.lower().split("x", 1)
|
||||
shape = (int(m_text), int(n_text))
|
||||
matrix = _random_matrix(shape, dtype, seed=sum(shape))
|
||||
for contract_singvals in ("L", "R"):
|
||||
for ctrl in args.controls:
|
||||
result = run_one(
|
||||
matrix,
|
||||
ctrl=ctrl,
|
||||
max_bond=args.max_bond,
|
||||
contract_singvals=contract_singvals,
|
||||
repeats=args.repeats,
|
||||
)
|
||||
print(
|
||||
f"row shape={shape_text} "
|
||||
f"contract={contract_singvals} "
|
||||
f"ctrl={ctrl} "
|
||||
f"status={result['status']} "
|
||||
f"median_ms={result['median_ms']:.3f} "
|
||||
f"min_ms={result['min_ms']:.3f} "
|
||||
f"kept={result['kept']} "
|
||||
f"rel_error={result['rel_error']} "
|
||||
f"error={result['error']}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user