numpy换为torch;修复torch后端计算方式不支持的问题,加速svd;发现并行化err,尝试修复;当前加速比11x,但是最新的并行方式不稳定,可能有精度问题
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
This commit is contained in:
419
src/qibotn/backends/vidal_tebd.py
Normal file
419
src/qibotn/backends/vidal_tebd.py
Normal file
@@ -0,0 +1,419 @@
|
||||
"""Vidal/TEBD MPS executor for layer-parallel circuit simulation.
|
||||
|
||||
This module is intentionally small and focused on the circuit family used by the
|
||||
MPS benchmarks: one-qubit gates and adjacent two-qubit gates on a 1D chain. It
|
||||
keeps the state in Vidal form, so gates acting on disjoint bonds can be applied
|
||||
in parallel without moving a global mixed-canonical center.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _backend_module(tensor_module):
|
||||
if tensor_module == "torch":
|
||||
import torch
|
||||
|
||||
return torch
|
||||
if tensor_module == "numpy":
|
||||
return np
|
||||
raise ValueError(f"Unsupported tensor module {tensor_module!r}.")
|
||||
|
||||
|
||||
def _asarray(xp, value, dtype):
|
||||
if xp is np:
|
||||
return np.asarray(value, dtype=dtype)
|
||||
return xp.as_tensor(value, dtype=dtype)
|
||||
|
||||
|
||||
def _ones(xp, size, dtype, device=None):
|
||||
if xp is np:
|
||||
return np.ones(size, dtype=np.float64 if dtype == np.complex128 else np.float32)
|
||||
real_dtype = xp.float64 if dtype == xp.complex128 else xp.float32
|
||||
return xp.ones(size, dtype=real_dtype, device=device)
|
||||
|
||||
|
||||
def _eye(xp, size, dtype, device=None):
|
||||
if xp is np:
|
||||
return np.eye(size, dtype=dtype)
|
||||
return xp.eye(size, dtype=dtype, device=device)
|
||||
|
||||
|
||||
def _conj(xp, tensor):
|
||||
return np.conjugate(tensor) if xp is np else tensor.conj()
|
||||
|
||||
|
||||
def _transpose(xp, tensor, axes):
|
||||
return np.transpose(tensor, axes) if xp is np else tensor.permute(*axes)
|
||||
|
||||
|
||||
def _vdot_real(xp, left, right):
|
||||
if xp is np:
|
||||
return np.vdot(left.reshape(-1), right.reshape(-1)).real
|
||||
return xp.vdot(left.reshape(-1), right.reshape(-1)).real
|
||||
|
||||
|
||||
def _to_float(x):
|
||||
if hasattr(x, "detach"):
|
||||
return float(x.detach().cpu().item())
|
||||
return float(x)
|
||||
|
||||
|
||||
def _svd(xp, matrix):
    """Factor ``matrix`` into ``(U, S, Vh)``.

    Single dispatch point for the SVD used by the executor; it currently
    always takes the eigendecomposition-based route.
    """
    left_vecs, sing_vals, right_vecs_h = _svd_eigh(xp, matrix)
    return left_vecs, sing_vals, right_vecs_h
|
||||
|
||||
|
||||
def _svd_eigh(xp, matrix):
    """Compute an SVD via a Hermitian eigendecomposition of the Gram matrix.

    The smaller of ``A A^H`` and ``A^H A`` is diagonalised.  Its eigenvalues,
    sorted in descending order and clamped at zero, give the squared singular
    values.  The remaining factor is recovered by projecting ``matrix`` onto
    the eigenvectors and rescaling by the inverse singular values.

    Returns:
        ``(U, S, Vh)`` with ``S`` in descending order.
    """
    rows, cols = matrix.shape
    adjoint = _conj(xp, matrix).T
    wide = rows <= cols

    # Diagonalise whichever Gram matrix is smaller.
    gram = matrix @ adjoint if wide else adjoint @ matrix
    eigvals, eigvecs = _sort_eigh_desc(xp, *_eigh(xp, gram))
    singvals = _sqrt_clamped(xp, eigvals)
    inv_s = _safe_inverse(xp, singvals)
    eigvecs_h = _conj(xp, eigvecs).T

    if wide:
        # Eigenvectors of A A^H are the left singular vectors.
        vh = (eigvecs_h @ matrix) * inv_s.reshape(-1, 1)
        return eigvecs, singvals, vh

    # Eigenvectors of A^H A are the right singular vectors.
    umat = (matrix @ eigvecs) * inv_s.reshape(1, -1)
    return umat, singvals, eigvecs_h
|
||||
|
||||
|
||||
def _batched_svd_eigh(xp, matrices):
    """Batched Gram-matrix SVD for a stack of same-shape torch matrices.

    Works on the last two axes of ``matrices``: the smaller Gram matrix of
    each entry is diagonalised with ``xp.linalg.eigh``, the eigenpairs are
    sorted in descending order, and the other factor is rebuilt by projection.

    Returns:
        ``(U, S, Vh)`` stacked along the leading axes of ``matrices``.
    """
    rows, cols = matrices.shape[-2:]
    adjoints = _conj(xp, matrices).transpose(-1, -2)
    wide = rows <= cols

    grams = matrices @ adjoints if wide else adjoints @ matrices
    eigvals, eigvecs = xp.linalg.eigh(grams)

    # Reorder eigenvalues and the matching eigenvector columns, largest first.
    order = xp.argsort(eigvals, dim=-1, descending=True)
    eigvals = xp.gather(eigvals, -1, order)
    eigvecs = xp.gather(eigvecs, -1, order.unsqueeze(-2).expand_as(eigvecs))

    singvals = _sqrt_clamped(xp, eigvals)
    inv_s = _safe_inverse(xp, singvals)
    eigvecs_h = eigvecs.conj().transpose(-1, -2)

    if wide:
        vhs = (eigvecs_h @ matrices) * inv_s.unsqueeze(-1)
        return eigvecs, singvals, vhs

    umats = (matrices @ eigvecs) * inv_s.unsqueeze(-2)
    return umats, singvals, eigvecs_h
|
||||
|
||||
|
||||
def _eigh(xp, matrix):
|
||||
if xp is np:
|
||||
return np.linalg.eigh(matrix)
|
||||
return xp.linalg.eigh(matrix)
|
||||
|
||||
|
||||
def _sort_eigh_desc(xp, eigvals, eigvecs):
|
||||
if xp is np:
|
||||
idx = np.argsort(eigvals)[::-1].copy()
|
||||
return eigvals[idx], eigvecs[:, idx]
|
||||
idx = xp.argsort(eigvals, descending=True)
|
||||
return eigvals[idx], eigvecs[:, idx]
|
||||
|
||||
|
||||
def _sqrt_clamped(xp, eigvals):
|
||||
if xp is np:
|
||||
return np.sqrt(np.maximum(eigvals.real, 0.0))
|
||||
return xp.sqrt(xp.clamp(eigvals.real, min=0.0))
|
||||
|
||||
|
||||
def _safe_inverse(xp, values):
|
||||
if xp is np:
|
||||
return np.where(values > 1e-300, 1.0 / values, 0.0)
|
||||
return xp.where(values > 1e-300, 1.0 / values, xp.zeros_like(values))
|
||||
|
||||
|
||||
@dataclass
class VidalTEBDExecutor:
    """Matrix-product-state executor in Vidal (Gamma/Lambda) form.

    The state is held as one rank-3 tensor per site in ``gammas`` (axes
    ``(left bond, physical, right bond)``) and one real vector per bond in
    ``lambdas`` (``nqubits + 1`` entries, the first and last being boundary
    bonds of size 1).  The initial state has every site in basis state 0.

    Only one-qubit gates and two-qubit gates on neighbouring sites are
    supported; anything else raises ``NotImplementedError``.
    """

    # Number of sites in the chain.
    nqubits: int
    # Upper bound on the bond dimension kept after a two-site split.
    max_bond: int
    # Singular values not larger than ``cut_ratio`` times the largest one are
    # dropped (at least one is always kept); ``<= 0`` disables this filter.
    cut_ratio: float = 1e-12
    # Backend name understood by ``_backend_module``: "torch" or "numpy".
    tensor_module: str = "torch"
    # The batched two-site path is only taken when this is greater than 1.
    workers: int = 1
    # Opt-in switch for the batched two-site path in ``run_circuit``.
    use_batched: bool = False

    def __post_init__(self):
        """Resolve the backend and build the initial product state."""
        self.xp = _backend_module(self.tensor_module)
        if self.xp is np:
            self.dtype = np.complex128
            self.device = None
        else:
            self.dtype = self.xp.complex128
            self.device = self.xp.device("cpu")

        # Every site starts as a (1, 2, 1) tensor with amplitude 1 on state 0.
        self.gammas = [
            _asarray(self.xp, [[[1.0 + 0.0j], [0.0 + 0.0j]]], self.dtype)
            for _ in range(self.nqubits)
        ]
        self.lambdas = [
            _ones(self.xp, 1, self.dtype, self.device)
            for _ in range(self.nqubits + 1)
        ]

    def run_circuit(self, circuit):
        """Apply every gate of ``circuit.queue`` to the state, in order.

        Gates are grouped into consecutive batches acting on disjoint sites.
        A batch goes through ``apply_two_site_batch`` only when batching is
        enabled, ``workers > 1``, and the batch holds more than one gate, all
        of them two-qubit gates.  Otherwise its gates are applied one by one.
        """
        for batch in _disjoint_batches(circuit.queue):
            take_batched_path = (
                self.use_batched
                and self.workers > 1
                and len(batch) > 1
                and _is_two_qubit_batch(batch)
            )
            if take_batched_path:
                self.apply_two_site_batch(batch)
                continue
            for gate in batch:
                self._apply_gate(gate)

    def _apply_gate(self, gate):
        """Dispatch one gate to the one-site or two-site update.

        Raises:
            NotImplementedError: for two-qubit gates on non-adjacent sites and
                for gates acting on more than two (or zero) sites.
        """
        sites = _gate_sites(gate)
        matrix = _asarray(self.xp, gate.matrix(), self.dtype)
        if len(sites) == 1:
            self.apply_one_site(matrix, sites[0])
        elif len(sites) == 2:
            if abs(sites[0] - sites[1]) != 1:
                raise NotImplementedError(
                    "VidalTEBDExecutor supports adjacent gates only."
                )
            self.apply_two_site(matrix, sites[0], sites[1])
        else:
            raise NotImplementedError("Only one- and two-qubit gates are supported.")

    def apply_one_site(self, op, pos):
        """Contract the 2x2 operator ``op`` into the site tensor at ``pos``."""
        # op[out_phys, in_phys] * gamma[left, in_phys, right]
        self.gammas[pos] = self.xp.einsum("st,atb->asb", op, self.gammas[pos])

    def apply_two_site(self, op, left_pos, right_pos):
        """Apply a 4x4 operator to two neighbouring sites and re-split them.

        Raises:
            NotImplementedError: if the two positions are not adjacent.
        """
        item = self._build_two_site_matrix(op, left_pos, right_pos)
        umat, singvals, vh = _svd(self.xp, item["matrix"])
        self._install_two_site_split(item, umat, singvals, vh)

    def apply_two_site_batch(self, batch):
        """Apply a batch of two-qubit gates that act on disjoint site pairs.

        All two-site matrices are built from the current state before any
        result is written back, so the gates in ``batch`` must not share
        sites.  On non-NumPy backends, matrices of equal shape are stacked and
        factorised together; everything else is factorised one by one.

        Raises:
            NotImplementedError: if any gate acts on non-adjacent sites.  The
                state is left untouched in that case.
        """
        items = []
        for gate in batch:
            sites = _gate_sites(gate)
            op = _asarray(self.xp, gate.matrix(), self.dtype)
            items.append(self._build_two_site_matrix(op, sites[0], sites[1]))

        if self.xp is np:
            for item in items:
                item["split"] = _svd(self.xp, item["matrix"])
        else:
            # Group by matrix shape so that each group can be stacked.
            grouped = {}
            for idx, item in enumerate(items):
                grouped.setdefault(tuple(item["matrix"].shape), []).append(idx)

            for indices in grouped.values():
                if len(indices) == 1:
                    item = items[indices[0]]
                    item["split"] = _svd(self.xp, item["matrix"])
                    continue

                matrices = self.xp.stack([items[idx]["matrix"] for idx in indices])
                umats, singvals, vhs = _batched_svd_eigh(self.xp, matrices)
                for out_idx, item_idx in enumerate(indices):
                    items[item_idx]["split"] = (
                        umats[out_idx],
                        singvals[out_idx],
                        vhs[out_idx],
                    )

        for item in items:
            self._install_two_site_split(item, *item["split"])

    def _build_two_site_matrix(self, op, left_pos, right_pos):
        """Build the gate-applied two-site tensor, flattened for splitting.

        The two site tensors and their three surrounding bond vectors are
        contracted, ``op`` is applied on the physical indices, and the result
        is reshaped to ``(chi_left * 2, 2 * chi_right)``.

        Returns:
            dict with the left site index, the outer bond sizes, the outer
            bond vectors (needed later to strip them again) and the matrix.

        Raises:
            NotImplementedError: if the two positions are not adjacent.  The
                check lives here so that every caller, including the batched
                path, gets it; otherwise a distant pair would silently be
                treated as ``(i, i + 1)``.
        """
        if abs(left_pos - right_pos) != 1:
            raise NotImplementedError("VidalTEBDExecutor supports adjacent gates only.")

        if left_pos > right_pos:
            # Swap the qubit roles of the operator together with the positions.
            left_pos, right_pos = right_pos, left_pos
            op = _transpose(self.xp, op.reshape(2, 2, 2, 2), (1, 0, 3, 2)).reshape(
                4, 4
            )

        i = left_pos
        lam_left = self.lambdas[i]
        lam_mid = self.lambdas[i + 1]
        lam_right = self.lambdas[i + 2]
        gamma_left = self.gammas[i]
        gamma_right = self.gammas[i + 1]

        theta = self.xp.einsum(
            "a,asb,b,btc,c->astc",
            lam_left,
            gamma_left,
            lam_mid,
            gamma_right,
            lam_right,
        )
        gate = op.reshape(2, 2, 2, 2)
        theta = self.xp.einsum("uvst,astc->auvc", gate, theta)

        chi_left = theta.shape[0]
        chi_right = theta.shape[3]
        matrix = theta.reshape(chi_left * 2, 2 * chi_right)
        return {
            "site": i,
            "chi_left": chi_left,
            "chi_right": chi_right,
            "lam_left": lam_left,
            "lam_right": lam_right,
            "matrix": matrix,
        }

    def _install_two_site_split(self, item, umat, singvals, vh):
        """Truncate a factorisation and write it back into the state.

        Keeps the leading ``_choose_bond`` singular values, rescales them so
        that their squared sum equals that of the full spectrum, strips the
        outer bond vectors from the factors, and stores the new site tensors
        and the new middle bond vector.
        """
        keep = self._choose_bond(singvals)
        umat = umat[:, :keep]
        kept = singvals[:keep]
        cut = singvals[keep:]
        vh = vh[:keep, :]

        if cut.shape[0] > 0:
            # Compensate for the weight discarded by the truncation.
            norm_kept = (kept * kept).sum()
            norm_cut = (cut * cut).sum()
            kept = kept / self.xp.sqrt(norm_kept / (norm_kept + norm_cut))

        new_left = umat.reshape(item["chi_left"], 2, keep)
        new_right = vh.reshape(keep, 2, item["chi_right"])
        new_left = self._divide_left_lambda(new_left, item["lam_left"])
        new_right = self._divide_right_lambda(new_right, item["lam_right"])

        i = item["site"]
        self.gammas[i] = new_left
        self.gammas[i + 1] = new_right
        self.lambdas[i + 1] = kept

    def _choose_bond(self, singvals):
        """Return how many leading singular values to keep.

        The count is capped by ``max_bond``.  When ``cut_ratio > 0`` it is
        further limited to the number of values strictly above
        ``cut_ratio * singvals[0]``, but never below 1.
        """
        max_possible = int(singvals.shape[0])
        keep = min(max_possible, self.max_bond)
        if self.cut_ratio > 0 and max_possible > 0:
            threshold = singvals[0] * self.cut_ratio
            if self.xp is np:
                ratio_keep = int(np.count_nonzero(singvals > threshold))
            else:
                ratio_keep = int((singvals > threshold).sum().detach().cpu().item())
            keep = min(keep, max(1, ratio_keep))
        return keep

    def _safe_divisor(self, lambdas):
        """Return ``lambdas`` with entries of magnitude <= 1e-300 replaced by 1."""
        if self.xp is np:
            return np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
        return self.xp.where(
            self.xp.abs(lambdas) > 1e-300,
            lambdas,
            self.xp.ones_like(lambdas),
        )

    def _divide_left_lambda(self, tensor, lambdas):
        """Divide ``tensor`` by ``lambdas`` along its first (left bond) axis."""
        return tensor / self._safe_divisor(lambdas).reshape(-1, 1, 1)

    def _divide_right_lambda(self, tensor, lambdas):
        """Divide ``tensor`` by ``lambdas`` along its last (right bond) axis."""
        return tensor / self._safe_divisor(lambdas).reshape(1, 1, -1)

    def expectation_ring_xz(self):
        """Return the ring X-Z expectation value divided by the state norm.

        Sums ``0.5 * <X_i Z_{i+1}>`` over neighbouring pairs, adds the
        wrap-around term ``0.5 * <X_{n-1} Z_0>`` evaluated by a full-chain
        contraction, and divides the total by ``norm()``.
        """
        xmat = _asarray(self.xp, [[0, 1], [1, 0]], self.dtype)
        zmat = _asarray(self.xp, [[1, 0], [0, -1]], self.dtype)
        value = 0.0
        for site in range(self.nqubits - 1):
            value += 0.5 * _to_float(self._expect_adjacent(site, xmat, zmat))
        value += 0.5 * _to_float(
            self.expect_product_operators({self.nqubits - 1: xmat, 0: zmat})
        )
        return value / self.norm()

    def _expect_adjacent(self, site, op_left, op_right):
        """Return ``Re <theta| op_left (x) op_right |theta>`` for two sites.

        ``theta`` is the two-site tensor at ``site`` and ``site + 1`` with its
        three surrounding bond vectors contracted in.
        NOTE(review): this local formula presumably relies on the state being
        in canonical form - confirm.
        """
        theta = self.xp.einsum(
            "a,asb,b,btc,c->astc",
            self.lambdas[site],
            self.gammas[site],
            self.lambdas[site + 1],
            self.gammas[site + 1],
            self.lambdas[site + 2],
        )
        op_theta = self.xp.einsum("us,vt,astc->auvc", op_left, op_right, theta)
        return _vdot_real(self.xp, theta, op_theta)

    def expect_product_operators(self, operators):
        """Contract the whole chain with a product of single-site operators.

        Args:
            operators: mapping from site index to a 2x2 operator; sites that
                are absent use the identity.

        Returns:
            The real part of the fully contracted (unnormalised) value.
        """
        env = _asarray(self.xp, [[1.0 + 0.0j]], self.dtype)
        identity = _eye(self.xp, 2, self.dtype, self.device)
        for site in range(self.nqubits):
            # Absorb the bond vector to the right of the site into its tensor.
            tensor = self.gammas[site] * self.lambdas[site + 1].reshape(1, 1, -1)
            op = operators.get(site, identity)
            env = self.xp.einsum(
                "xy,xsb,st,ytd->bd", env, _conj(self.xp, tensor), op, tensor
            )
        return env.reshape(-1)[0].real

    def norm(self):
        """Return the squared norm ``<psi|psi>`` as a Python float."""
        return _to_float(self.expect_product_operators({}))
|
||||
|
||||
|
||||
def _gate_sites(gate):
|
||||
controls = tuple(getattr(gate, "control_qubits", ()))
|
||||
targets = tuple(getattr(gate, "target_qubits", ()))
|
||||
if controls:
|
||||
return controls + targets
|
||||
return targets
|
||||
|
||||
|
||||
def _disjoint_batches(gates):
    """Split ``gates`` into consecutive runs that act on pairwise-disjoint sites.

    Gate order is preserved.  A new batch starts as soon as a gate touches a
    site already used by the batch being built.

    Returns:
        A list of non-empty lists of gates (empty if ``gates`` is empty).
    """
    batches = []
    pending, occupied = [], set()
    for gate in gates:
        gate_sites = set(_gate_sites(gate))
        if pending and not occupied.isdisjoint(gate_sites):
            # Conflict: close the running batch and start afresh.
            batches.append(pending)
            pending, occupied = [], set()
        pending.append(gate)
        occupied.update(gate_sites)
    if pending:
        batches.append(pending)
    return batches
|
||||
|
||||
|
||||
def _is_two_qubit_batch(batch):
|
||||
return batch and all(len(_gate_sites(gate)) == 2 for gate in batch)
|
||||
|
||||
|
||||
def run_vidal_ring_xz(
    circuit,
    max_bond,
    cut_ratio=1e-12,
    tensor_module="torch",
    workers=1,
    use_batched=False,
):
    """Simulate ``circuit`` with a fresh executor and return its ring X-Z value.

    A ``VidalTEBDExecutor`` sized from ``circuit.nqubits`` is built with the
    given settings, the circuit is run on it, and the result of
    ``expectation_ring_xz()`` is returned.
    """
    settings = {
        "nqubits": circuit.nqubits,
        "max_bond": max_bond,
        "cut_ratio": cut_ratio,
        "tensor_module": tensor_module,
        "workers": workers,
        "use_batched": use_batched,
    }
    simulator = VidalTEBDExecutor(**settings)
    simulator.run_circuit(circuit)
    return simulator.expectation_ring_xz()
|
||||
Reference in New Issue
Block a user