#!/usr/bin/env python """Benchmark qredtea/qtealeaves SVD control modes. This isolates the tensor split used by MPS updates: a rank-2 tensor is split with singular values contracted either left or right, then reconstructed to measure numerical error and timing. """ from __future__ import annotations import argparse import gc import statistics import time import torch import qmatchatea from qredtea.torchapi import QteaTorchTensor def _dtype(name: str): return { "complex64": torch.complex64, "complex128": torch.complex128, "float64": torch.float64, "float32": torch.float32, }[name] def _random_matrix(shape, dtype, seed): gen = torch.Generator(device="cpu") gen.manual_seed(seed) if dtype.is_complex: real_dtype = torch.float32 if dtype == torch.complex64 else torch.float64 real = torch.randn(shape, dtype=real_dtype, generator=gen) imag = torch.randn(shape, dtype=real_dtype, generator=gen) return torch.complex(real, imag).to(dtype) return torch.randn(shape, dtype=dtype, generator=gen) def _sync(): if torch.cuda.is_available(): torch.cuda.synchronize() def run_one(matrix, ctrl, max_bond, contract_singvals, repeats): conv = qmatchatea.QCConvergenceParameters( max_bond_dimension=max_bond, cut_ratio=0.0, svd_ctrl=ctrl, ) qtensor = QteaTorchTensor.from_elem_array(matrix, dtype=matrix.dtype, device="cpu") times = [] rel_error = None kept = None status = "ok" error = "" for i in range(repeats): gc.collect() _sync() t0 = time.perf_counter() try: left, right, singvals, _ = qtensor.split_svd( [0], [1], contract_singvals=contract_singvals, conv_params=conv, ) except Exception as exc: # noqa: BLE001 - benchmark should keep going status = "error" error = repr(exc) break _sync() times.append(time.perf_counter() - t0) if i == repeats - 1: left_matrix = left.elem.reshape(matrix.shape[0], -1) right_matrix = right.elem.reshape(-1, matrix.shape[1]) recon = left_matrix @ right_matrix rel_error = ( torch.linalg.vector_norm(matrix - recon) / torch.linalg.vector_norm(matrix) ).item() kept = int(singvals.numel()) return { "ctrl": ctrl, "contract_singvals": contract_singvals, "status": status, "median_ms": float("nan") if not times else statistics.median(times) * 1000, "min_ms": float("nan") if not times else min(times) * 1000, "rel_error": rel_error, "kept": kept, "error": error, } def main(): parser = argparse.ArgumentParser() parser.add_argument("--shapes", nargs="+", default=("256x1024", "1024x256", "512x512")) parser.add_argument("--max-bond", type=int, default=128) parser.add_argument("--dtype", choices=("complex64", "complex128", "float32", "float64"), default="complex128") parser.add_argument("--threads", type=int, default=8) parser.add_argument("--repeats", type=int, default=3) parser.add_argument( "--controls", nargs="+", default=("A", "D", "V", "R", "E", "E!", "X", "X!"), ) args = parser.parse_args() torch.set_num_threads(args.threads) dtype = _dtype(args.dtype) print( "svd_benchmark " f"dtype={args.dtype} threads={torch.get_num_threads()} " f"max_bond={args.max_bond} repeats={args.repeats}", flush=True, ) print( "columns shape contract ctrl status median_ms min_ms kept rel_error error", flush=True, ) for shape_text in args.shapes: m_text, n_text = shape_text.lower().split("x", 1) shape = (int(m_text), int(n_text)) matrix = _random_matrix(shape, dtype, seed=sum(shape)) for contract_singvals in ("L", "R"): for ctrl in args.controls: result = run_one( matrix, ctrl=ctrl, max_bond=args.max_bond, contract_singvals=contract_singvals, repeats=args.repeats, ) print( f"row shape={shape_text} " f"contract={contract_singvals} " f"ctrl={ctrl} " f"status={result['status']} " f"median_ms={result['median_ms']:.3f} " f"min_ms={result['min_ms']:.3f} " f"kept={result['kept']} " f"rel_error={result['rel_error']} " f"error={result['error']}", flush=True, ) if __name__ == "__main__": main()