numpy换为torch;修复torch后端计算方式不支持的问题,加速svd;发现并行化err,尝试修复;当前加速比11x,但是最新的并行方式不稳定,可能有精度问题
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
2026-05-10 22:33:46 +08:00
parent fea8e5abc0
commit aa122964b4
4 changed files with 638 additions and 58 deletions

View File

@@ -0,0 +1,419 @@
"""Vidal/TEBD MPS executor for layer-parallel circuit simulation.
This module is intentionally small and focused on the circuit family used by the
MPS benchmarks: one-qubit gates and adjacent two-qubit gates on a 1D chain. It
keeps the state in Vidal form, so gates acting on disjoint bonds can be applied
in parallel without moving a global mixed-canonical center.
"""
from __future__ import annotations
from dataclasses import dataclass
import numpy as np
def _backend_module(tensor_module):
    """Resolve a backend name to the array module that implements it.

    Args:
        tensor_module: Backend name, either ``"numpy"`` or ``"torch"``.

    Returns:
        The ``numpy`` module, or the ``torch`` module.

    Raises:
        ValueError: If ``tensor_module`` is neither supported name.
    """
    if tensor_module == "numpy":
        return np
    if tensor_module == "torch":
        # torch is only imported when that backend is actually requested.
        import torch

        return torch
    raise ValueError(f"Unsupported tensor module {tensor_module!r}.")
def _asarray(xp, value, dtype):
    """Convert ``value`` into a backend array/tensor with the given dtype.

    Uses ``np.asarray`` for the numpy backend and ``xp.as_tensor`` otherwise.
    """
    convert = np.asarray if xp is np else xp.as_tensor
    return convert(value, dtype=dtype)
def _ones(xp, size, dtype, device=None):
    """Return a real-valued vector of ones matching a complex dtype.

    The result is float64 when ``dtype`` is the backend's complex128 and
    float32 for anything else.  ``device`` is only forwarded to non-numpy
    backends.
    """
    double_precision = dtype == xp.complex128
    if xp is np:
        real_dtype = np.float64 if double_precision else np.float32
        return np.ones(size, dtype=real_dtype)
    real_dtype = xp.float64 if double_precision else xp.float32
    return xp.ones(size, dtype=real_dtype, device=device)
def _eye(xp, size, dtype, device=None):
    """Return a ``size`` x ``size`` identity matrix in the backend's format.

    ``device`` is only forwarded to non-numpy backends.
    """
    options = {"dtype": dtype}
    if xp is not np:
        options["device"] = device
    return xp.eye(size, **options)
def _conj(xp, tensor):
    """Return the element-wise complex conjugate of ``tensor``."""
    if xp is np:
        return np.conjugate(tensor)
    return tensor.conj()
def _transpose(xp, tensor, axes):
    """Permute the axes of ``tensor`` into the order given by ``axes``."""
    if xp is not np:
        return tensor.permute(*axes)
    return np.transpose(tensor, axes)
def _vdot_real(xp, left, right):
    """Return the real part of ``<left|right>`` over all elements.

    Both inputs are flattened first; ``vdot`` conjugates its first argument.
    """
    flat_left = left.reshape(-1)
    flat_right = right.reshape(-1)
    return xp.vdot(flat_left, flat_right).real
def _to_float(x):
    """Convert a scalar (plain number or tensor-like object) to ``float``.

    Objects exposing ``detach`` are detached, moved to the CPU and unwrapped
    with ``item()`` before conversion.
    """
    if not hasattr(x, "detach"):
        return float(x)
    return float(x.detach().cpu().item())
def _svd(xp, matrix):
    """Factor ``matrix`` into ``(U, singular values, Vh)``.

    Single dispatch point for the SVD strategy; it currently always uses the
    eigendecomposition-based routine.
    """
    factors = _svd_eigh(xp, matrix)
    return factors
def _svd_eigh(xp, matrix):
    """Compute an SVD of a 2-D matrix through a Hermitian eigendecomposition.

    The smaller of the two Gram matrices (``M M^H`` or ``M^H M``) is
    diagonalised; its eigenvectors give one set of singular vectors and the
    other set is recovered by projecting ``matrix`` and dividing by the
    singular values.  The original author chose this because it is faster
    than torch's general-purpose SVD on the benchmark's small/medium splits.

    NOTE(review): forming the Gram matrix squares the singular values before
    the eigensolver sees them, so small singular values are presumably
    resolved less accurately than with a direct SVD.

    Returns:
        ``(u, singular_values, vh)`` with singular values in descending order.
    """
    rows, cols = matrix.shape
    adjoint = _conj(xp, matrix).T
    if rows <= cols:
        # Wide or square: diagonalise M M^H, whose eigenvectors are U.
        vals, vecs = _eigh(xp, matrix @ adjoint)
        vals, vecs = _sort_eigh_desc(xp, vals, vecs)
        sigma = _sqrt_clamped(xp, vals)
        row_scale = _safe_inverse(xp, sigma).reshape(-1, 1)
        return vecs, sigma, (_conj(xp, vecs).T @ matrix) * row_scale
    # Tall: diagonalise M^H M, whose eigenvectors are V.
    vals, vecs = _eigh(xp, adjoint @ matrix)
    vals, vecs = _sort_eigh_desc(xp, vals, vecs)
    sigma = _sqrt_clamped(xp, vals)
    col_scale = _safe_inverse(xp, sigma).reshape(1, -1)
    return (matrix @ vecs) * col_scale, sigma, _conj(xp, vecs).T
def _batched_svd_eigh(xp, matrices):
    """Eigendecomposition-based SVD for a stack of same-shape torch matrices.

    Applies the Gram-matrix approach of the single-matrix routine to every
    matrix along the leading batch axis in one call.

    Returns:
        ``(u, singular_values, vh)`` stacked along the batch axis, with each
        matrix's singular values in descending order.
    """
    rows, cols = matrices.shape[-2:]
    wide = rows <= cols
    adjoints = _conj(xp, matrices).transpose(-1, -2)
    grams = matrices @ adjoints if wide else adjoints @ matrices

    vals, vecs = xp.linalg.eigh(grams)
    # eigh returns ascending eigenvalues; reorder values and the matching
    # eigenvector columns so the largest comes first.
    order = xp.argsort(vals, dim=-1, descending=True)
    vals = xp.gather(vals, -1, order)
    vecs = xp.gather(vecs, -1, order.unsqueeze(-2).expand_as(vecs))

    sigma = _sqrt_clamped(xp, vals)
    inverse_sigma = _safe_inverse(xp, sigma)
    vecs_h = vecs.conj().transpose(-1, -2)
    if wide:
        return vecs, sigma, (vecs_h @ matrices) * inverse_sigma.unsqueeze(-1)
    return (matrices @ vecs) * inverse_sigma.unsqueeze(-2), sigma, vecs_h
def _eigh(xp, matrix):
    """Return ``(eigenvalues, eigenvectors)`` of a Hermitian matrix.

    Both supported backends expose the routine as ``linalg.eigh``.
    """
    return xp.linalg.eigh(matrix)
def _sort_eigh_desc(xp, eigvals, eigvecs):
    """Reorder an eigendecomposition so eigenvalues are descending.

    Eigenvector columns are permuted together with their eigenvalues.
    """
    if xp is np:
        # Copy so the index array is not a negatively-strided view.
        order = np.argsort(eigvals)[::-1].copy()
    else:
        order = xp.argsort(eigvals, descending=True)
    return eigvals[order], eigvecs[:, order]
def _sqrt_clamped(xp, eigvals):
    """Return the square root of the real part of ``eigvals``.

    Negative values are clamped to zero before the square root is taken.
    """
    real_part = eigvals.real
    if xp is np:
        return np.sqrt(np.maximum(real_part, 0.0))
    return xp.sqrt(xp.clamp(real_part, min=0.0))
def _safe_inverse(xp, values):
    """Return ``1 / values`` element-wise, with 0 where a value is too small.

    Entries that are not greater than ``1e-300`` map to ``0`` rather than to
    ``inf``.

    Args:
        xp: Backend module (``numpy`` or ``torch``).
        values: Real-valued array/tensor, e.g. singular values.

    Returns:
        An array/tensor of the same shape holding the guarded reciprocals.
    """
    if xp is np:
        # np.where evaluates 1.0 / values for every entry, including the ones
        # that are masked out, so silence the divide/overflow warnings those
        # discarded entries would otherwise raise.
        with np.errstate(divide="ignore", over="ignore"):
            return np.where(values > 1e-300, 1.0 / values, 0.0)
    return xp.where(values > 1e-300, 1.0 / values, xp.zeros_like(values))
@dataclass
class VidalTEBDExecutor:
    """Vidal-form MPS simulator for one-qubit and adjacent two-qubit gates.

    The state is held as site tensors ``gammas[i]`` with axes
    ``(left bond, physical, right bond)`` and bond weights ``lambdas[i]`` /
    ``lambdas[i + 1]`` on the left / right of site ``i``.  It starts as the
    product state with every qubit in basis state 0 and all bonds of size 1.
    """

    # Number of sites in the chain.
    nqubits: int
    # Hard upper limit on the bond dimension kept after a two-site split.
    max_bond: int
    # Singular values not greater than ``cut_ratio * largest`` are discarded.
    cut_ratio: float = 1e-12
    # Backend name: "torch" or "numpy".
    tensor_module: str = "torch"
    # Only compared against 1 in run_circuit to enable the batched path; this
    # class does not create a worker pool.
    workers: int = 1
    # Opt-in switch for the batched two-site path.
    use_batched: bool = False

    def __post_init__(self):
        """Validate the configuration and build the initial product state.

        Raises:
            ValueError: If ``nqubits`` or ``max_bond`` is smaller than 1, or
                the backend name is not supported.
        """
        if self.nqubits < 1:
            raise ValueError(f"nqubits must be at least 1, got {self.nqubits!r}.")
        if self.max_bond < 1:
            # A bond limit below 1 would truncate every split to nothing.
            raise ValueError(f"max_bond must be at least 1, got {self.max_bond!r}.")
        self.xp = _backend_module(self.tensor_module)
        if self.xp is np:
            self.dtype = np.complex128
            self.device = None
        else:
            self.dtype = self.xp.complex128
            self.device = self.xp.device("cpu")
        # Shape (1, 2, 1): trivial bonds, amplitude 1 on physical index 0.
        basis_zero = [[[1.0 + 0.0j], [0.0 + 0.0j]]]
        self.gammas = [
            _asarray(self.xp, basis_zero, self.dtype) for _ in range(self.nqubits)
        ]
        self.lambdas = [
            _ones(self.xp, 1, self.dtype, self.device) for _ in range(self.nqubits + 1)
        ]

    def run_circuit(self, circuit):
        """Apply every gate in ``circuit.queue`` to the state, in order.

        Gates are grouped into runs acting on pairwise disjoint sites.  A run
        goes through the batched path only when batching is enabled,
        ``workers > 1``, and the run holds more than one gate, all of them
        two-qubit gates; otherwise its gates are applied one at a time.
        """
        for batch in _disjoint_batches(circuit.queue):
            batched = (
                self.use_batched
                and self.workers > 1
                and len(batch) > 1
                and _is_two_qubit_batch(batch)
            )
            if batched:
                self.apply_two_site_batch(batch)
                continue
            for gate in batch:
                self._apply_gate(gate)

    def _apply_gate(self, gate):
        """Apply a single gate, dispatching on the number of sites it touches.

        Raises:
            NotImplementedError: For gates on more than two sites or on two
                sites that are not neighbours in the chain.
        """
        sites = _gate_sites(gate)
        matrix = _asarray(self.xp, gate.matrix(), self.dtype)
        if len(sites) == 1:
            self.apply_one_site(matrix, sites[0])
        elif len(sites) == 2:
            # Adjacency is validated by _build_two_site_matrix.
            self.apply_two_site(matrix, sites[0], sites[1])
        else:
            raise NotImplementedError("Only one- and two-qubit gates are supported.")

    def apply_one_site(self, op, pos):
        """Apply the 2x2 operator ``op`` to the physical index of site ``pos``."""
        # op[out_phys, in_phys] * gamma[left, in_phys, right]
        self.gammas[pos] = self.xp.einsum("st,atb->asb", op, self.gammas[pos])

    def apply_two_site(self, op, left_pos, right_pos):
        """Apply the 4x4 operator ``op`` to two neighbouring sites.

        Raises:
            NotImplementedError: If the two sites are not neighbours.
            IndexError: If a site lies outside the chain.
        """
        item = self._build_two_site_matrix(op, left_pos, right_pos)
        umat, singvals, vh = _svd(self.xp, item["matrix"])
        self._install_two_site_split(item, umat, singvals, vh)

    def apply_two_site_batch(self, batch):
        """Apply several two-qubit gates that act on pairwise disjoint sites.

        All two-site matrices are built from the current state first, then
        factorised, then written back.  The caller must guarantee that the
        gates do not share sites; this method does not check that.  With the
        torch backend, matrices of equal shape are factorised in one batched
        call.

        Raises:
            NotImplementedError: If a gate is not a two-qubit gate or acts on
                sites that are not neighbours.
            IndexError: If a site lies outside the chain.
        """
        items = []
        for gate in batch:
            sites = _gate_sites(gate)
            if len(sites) != 2:
                raise NotImplementedError(
                    "apply_two_site_batch supports two-qubit gates only."
                )
            op = _asarray(self.xp, gate.matrix(), self.dtype)
            items.append(self._build_two_site_matrix(op, sites[0], sites[1]))

        if self.xp is np:
            for item in items:
                item["split"] = _svd(self.xp, item["matrix"])
        else:
            # Group by matrix shape so each group can be stacked.
            by_shape = {}
            for idx, item in enumerate(items):
                by_shape.setdefault(tuple(item["matrix"].shape), []).append(idx)
            for indices in by_shape.values():
                if len(indices) == 1:
                    only = items[indices[0]]
                    only["split"] = tuple(_svd(self.xp, only["matrix"]))
                    continue
                stacked = self.xp.stack([items[idx]["matrix"] for idx in indices])
                umats, singvals, vhs = _batched_svd_eigh(self.xp, stacked)
                for out_idx, item_idx in enumerate(indices):
                    items[item_idx]["split"] = (
                        umats[out_idx],
                        singvals[out_idx],
                        vhs[out_idx],
                    )

        for item in items:
            self._install_two_site_split(item, *item["split"])

    def _build_two_site_matrix(self, op, left_pos, right_pos):
        """Contract two neighbouring sites with ``op`` into a matrix to split.

        Args:
            op: Two-qubit operator reshapeable to ``(2, 2, 2, 2)``, ordered
                for the sites as given (first site, second site).
            left_pos: First site the gate acts on.
            right_pos: Second site the gate acts on.

        Returns:
            A dict with the left site index, the outer bond sizes, the outer
            bond weights and the ``(chi_left * 2, 2 * chi_right)`` matrix.

        Raises:
            NotImplementedError: If the two sites are not neighbours.
            IndexError: If a site lies outside the chain.
        """
        # Checked here so the sequential and the batched path share the
        # guard; without it a non-adjacent pair would silently be applied to
        # sites (i, i + 1).
        if abs(left_pos - right_pos) != 1:
            raise NotImplementedError("VidalTEBDExecutor supports adjacent gates only.")
        if left_pos > right_pos:
            left_pos, right_pos = right_pos, left_pos
            # Swap the roles of the two qubits in the operator to match.
            op = _transpose(self.xp, op.reshape(2, 2, 2, 2), (1, 0, 3, 2)).reshape(
                4, 4
            )
        if left_pos < 0 or right_pos >= self.nqubits:
            # Negative indices would otherwise wrap around silently.
            raise IndexError(
                f"Gate sites ({left_pos}, {right_pos}) are outside the "
                f"{self.nqubits}-site chain."
            )
        i = left_pos
        lam_left = self.lambdas[i]
        lam_right = self.lambdas[i + 2]
        theta = self.xp.einsum(
            "a,asb,b,btc,c->astc",
            lam_left,
            self.gammas[i],
            self.lambdas[i + 1],
            self.gammas[i + 1],
            lam_right,
        )
        theta = self.xp.einsum("uvst,astc->auvc", op.reshape(2, 2, 2, 2), theta)
        chi_left = theta.shape[0]
        chi_right = theta.shape[3]
        return {
            "site": i,
            "chi_left": chi_left,
            "chi_right": chi_right,
            "lam_left": lam_left,
            "lam_right": lam_right,
            "matrix": theta.reshape(chi_left * 2, 2 * chi_right),
        }

    def _install_two_site_split(self, item, umat, singvals, vh):
        """Truncate a factorised two-site matrix and store it in the state.

        ``item`` is the dict produced by ``_build_two_site_matrix``;
        ``umat``, ``singvals`` and ``vh`` are its SVD factors with singular
        values in descending order.
        """
        keep = self._choose_bond(singvals)
        kept = singvals[:keep]
        discarded = singvals[keep:]
        if discarded.shape[0] > 0:
            # Rescale the kept values so their squared sum equals the squared
            # sum of the full spectrum before truncation.
            norm_kept = (kept * kept).sum()
            norm_cut = (discarded * discarded).sum()
            kept = kept / self.xp.sqrt(norm_kept / (norm_kept + norm_cut))
        new_left = umat[:, :keep].reshape(item["chi_left"], 2, keep)
        new_right = vh[:keep, :].reshape(keep, 2, item["chi_right"])
        # Strip the outer bond weights that were multiplied in when building.
        new_left = self._divide_left_lambda(new_left, item["lam_left"])
        new_right = self._divide_right_lambda(new_right, item["lam_right"])
        i = item["site"]
        self.gammas[i] = new_left
        self.gammas[i + 1] = new_right
        self.lambdas[i + 1] = kept

    def _choose_bond(self, singvals):
        """Return how many leading singular values to keep.

        The count is capped by ``max_bond``; when ``cut_ratio`` is positive
        it is additionally capped by the number of values greater than
        ``cut_ratio * singvals[0]``, but never drops below 1.
        """
        max_possible = int(singvals.shape[0])
        keep = min(max_possible, self.max_bond)
        if self.cut_ratio > 0 and max_possible > 0:
            above = singvals > singvals[0] * self.cut_ratio
            if self.xp is np:
                ratio_keep = int(np.count_nonzero(above))
            else:
                ratio_keep = int(above.sum().detach().cpu().item())
            keep = min(keep, max(1, ratio_keep))
        return keep

    def _nonzero_lambdas(self, lambdas):
        """Return ``lambdas`` with near-zero entries replaced by 1."""
        if self.xp is np:
            return np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
        return self.xp.where(
            self.xp.abs(lambdas) > 1e-300,
            lambdas,
            self.xp.ones_like(lambdas),
        )

    def _divide_left_lambda(self, tensor, lambdas):
        """Divide ``tensor`` by ``lambdas`` along its first (left bond) axis."""
        return tensor / self._nonzero_lambdas(lambdas).reshape(-1, 1, 1)

    def _divide_right_lambda(self, tensor, lambdas):
        """Divide ``tensor`` by ``lambdas`` along its last (right bond) axis."""
        return tensor / self._nonzero_lambdas(lambdas).reshape(1, 1, -1)

    def expectation_ring_xz(self):
        """Return ``0.5 * sum_i <X_i Z_(i+1)>`` around the ring, over the norm.

        Neighbouring pairs inside the chain use the local two-site
        contraction; the closing pair ``(nqubits - 1, 0)`` uses a full-chain
        contraction.

        NOTE(review): for ``nqubits == 1`` the two dict keys of the closing
        term coincide, so only the Z operator is applied.
        """
        xmat = _asarray(self.xp, [[0, 1], [1, 0]], self.dtype)
        zmat = _asarray(self.xp, [[1, 0], [0, -1]], self.dtype)
        value = 0.0
        for site in range(self.nqubits - 1):
            value += 0.5 * _to_float(self._expect_adjacent(site, xmat, zmat))
        value += 0.5 * _to_float(
            self.expect_product_operators({self.nqubits - 1: xmat, 0: zmat})
        )
        return value / self.norm()

    def _expect_adjacent(self, site, op_left, op_right):
        """Return ``Re <theta| op_left (x) op_right |theta>`` for one pair.

        ``theta`` is the two-site tensor of ``site`` and ``site + 1``
        including the three bond-weight vectors around them.
        """
        theta = self.xp.einsum(
            "a,asb,b,btc,c->astc",
            self.lambdas[site],
            self.gammas[site],
            self.lambdas[site + 1],
            self.gammas[site + 1],
            self.lambdas[site + 2],
        )
        op_theta = self.xp.einsum("us,vt,astc->auvc", op_left, op_right, theta)
        return _vdot_real(self.xp, theta, op_theta)

    def expect_product_operators(self, operators):
        """Contract the whole chain with a product of single-site operators.

        Args:
            operators: Mapping from site index to a 2x2 operator; sites not
                present use the identity.

        Returns:
            The real part of the (unnormalised) expectation value.
        """
        env = _asarray(self.xp, [[1.0 + 0.0j]], self.dtype)
        identity = _eye(self.xp, 2, self.dtype, self.device)
        for site in range(self.nqubits):
            # Absorb the bond weights on the right of each site tensor.
            tensor = self.gammas[site] * self.lambdas[site + 1].reshape(1, 1, -1)
            op = operators.get(site, identity)
            env = self.xp.einsum(
                "xy,xsb,st,ytd->bd", env, _conj(self.xp, tensor), op, tensor
            )
        return env.reshape(-1)[0].real

    def norm(self):
        """Return ``<psi|psi>`` as a Python float."""
        return _to_float(self.expect_product_operators({}))
def _gate_sites(gate):
    """Return the sites a gate acts on, control qubits first, as a tuple.

    Missing ``control_qubits`` / ``target_qubits`` attributes count as empty.
    """
    controls = tuple(getattr(gate, "control_qubits", ()))
    targets = tuple(getattr(gate, "target_qubits", ()))
    return controls + targets
def _disjoint_batches(gates):
    """Split a gate sequence into consecutive runs acting on disjoint sites.

    Gates are taken in order; a new run starts as soon as a gate touches a
    site already used by the current run.

    Returns:
        A list of lists of gates, preserving the original order.
    """
    batches = []
    pending, occupied = [], set()
    for gate in gates:
        gate_sites = set(_gate_sites(gate))
        if pending and not occupied.isdisjoint(gate_sites):
            batches.append(pending)
            pending, occupied = [], set()
        pending.append(gate)
        occupied.update(gate_sites)
    if pending:
        batches.append(pending)
    return batches
def _is_two_qubit_batch(batch):
    """Return True when ``batch`` is non-empty and holds only two-site gates.

    Always returns a real ``bool``; an empty batch yields ``False`` rather
    than the empty container itself.
    """
    return bool(batch) and all(len(_gate_sites(gate)) == 2 for gate in batch)
def run_vidal_ring_xz(
    circuit,
    max_bond,
    cut_ratio=1e-12,
    tensor_module="torch",
    workers=1,
    use_batched=False,
):
    """Simulate ``circuit`` with a Vidal executor and return its ring XZ value.

    Args:
        circuit: Object providing ``nqubits`` and a gate ``queue``.
        max_bond: Upper limit on the bond dimension.
        cut_ratio: Relative singular-value cutoff passed to the executor.
        tensor_module: Backend name, ``"torch"`` or ``"numpy"``.
        workers: Passed through to the executor.
        use_batched: Passed through to the executor.

    Returns:
        The result of the executor's ``expectation_ring_xz()`` after the
        whole circuit has been applied.
    """
    settings = {
        "nqubits": circuit.nqubits,
        "max_bond": max_bond,
        "cut_ratio": cut_ratio,
        "tensor_module": tensor_module,
        "workers": workers,
        "use_batched": use_batched,
    }
    executor = VidalTEBDExecutor(**settings)
    executor.run_circuit(circuit)
    return executor.expectation_ring_xz()