"""Vidal/TEBD MPS executor for layer-parallel circuit simulation. This module is intentionally small and focused on the circuit family used by the MPS benchmarks: one-qubit gates and adjacent two-qubit gates on a 1D chain. It keeps the state in Vidal form, so gates acting on disjoint bonds can be applied in parallel without moving a global mixed-canonical center. """ from __future__ import annotations from dataclasses import dataclass import numpy as np def _backend_module(tensor_module): if tensor_module == "torch": import torch return torch if tensor_module == "numpy": return np raise ValueError(f"Unsupported tensor module {tensor_module!r}.") def _asarray(xp, value, dtype): if xp is np: return np.asarray(value, dtype=dtype) return xp.as_tensor(value, dtype=dtype) def _ones(xp, size, dtype, device=None): if xp is np: return np.ones(size, dtype=np.float64 if dtype == np.complex128 else np.float32) real_dtype = xp.float64 if dtype == xp.complex128 else xp.float32 return xp.ones(size, dtype=real_dtype, device=device) def _eye(xp, size, dtype, device=None): if xp is np: return np.eye(size, dtype=dtype) return xp.eye(size, dtype=dtype, device=device) def _conj(xp, tensor): return np.conjugate(tensor) if xp is np else tensor.conj() def _transpose(xp, tensor, axes): return np.transpose(tensor, axes) if xp is np else tensor.permute(*axes) def _vdot(xp, left, right): if xp is np: return np.vdot(left.reshape(-1), right.reshape(-1)) return xp.vdot(left.reshape(-1), right.reshape(-1)) def _to_float(x): if hasattr(x, "detach"): return float(x.detach().cpu().item()) return float(x) def _to_scalar(x): if hasattr(x, "detach"): return x.detach().cpu().item() if isinstance(x, np.ndarray): return x.item() return x def _real_if_close(x, tol=1000): value = np.real_if_close(x, tol=tol) return value.item() if isinstance(value, np.ndarray) else value def _to_numpy(tensor): if hasattr(tensor, "detach"): return tensor.detach().cpu().numpy() return np.asarray(tensor) def _tensor_update_to_numpy(update): result = { "site": int(update["site"]), "left": _to_numpy(update["left"]), "right": _to_numpy(update["right"]), "lambda": _to_numpy(update["lambda"]), } if "truncation_error" in update: result["truncation_error"] = float(update["truncation_error"]) return result def _tensor_update_from_numpy(xp, update, dtype): if xp is np: return update result = { "site": update["site"], "left": _asarray(xp, update["left"], dtype), "right": _asarray(xp, update["right"], dtype), "lambda": xp.as_tensor( update["lambda"], dtype=xp.float64 if dtype == xp.complex128 else xp.float32, ), } if "truncation_error" in update: result["truncation_error"] = float(update["truncation_error"]) return result def _svd(xp, matrix): return _svd_eigh(xp, matrix) def _svd_eigh(xp, matrix): """SVD through Hermitian eigendecomposition. This mirrors the E-style path that is fast for the benchmark matrices and avoids torch's slower general-purpose SVD for many small/medium splits. """ m_dim, n_dim = matrix.shape if m_dim <= n_dim: gram = matrix @ _conj(xp, matrix).T eigvals, eigvecs = _eigh(xp, gram) eigvals, eigvecs = _sort_eigh_desc(xp, eigvals, eigvecs) singvals = _sqrt_clamped(xp, eigvals) inv_s = _safe_inverse(xp, singvals) vh = (_conj(xp, eigvecs).T @ matrix) * inv_s.reshape(-1, 1) return eigvecs, singvals, vh gram = _conj(xp, matrix).T @ matrix eigvals, eigvecs = _eigh(xp, gram) eigvals, eigvecs = _sort_eigh_desc(xp, eigvals, eigvecs) singvals = _sqrt_clamped(xp, eigvals) inv_s = _safe_inverse(xp, singvals) umat = (matrix @ eigvecs) * inv_s.reshape(1, -1) return umat, singvals, _conj(xp, eigvecs).T def _eigh(xp, matrix): if xp is np: return np.linalg.eigh(matrix) return xp.linalg.eigh(matrix) def _sort_eigh_desc(xp, eigvals, eigvecs): if xp is np: return eigvals[::-1].copy(), eigvecs[:, ::-1].copy() return xp.flip(eigvals, dims=(0,)), xp.flip(eigvecs, dims=(1,)) def _sqrt_clamped(xp, eigvals): if xp is np: return np.sqrt(np.maximum(eigvals.real, 0.0)) return xp.sqrt(xp.clamp(eigvals.real, min=0.0)) def _safe_inverse(xp, values): if xp is np: return np.where(values > 1e-300, 1.0 / values, 0.0) return xp.where(values > 1e-300, 1.0 / values, xp.zeros_like(values)) @dataclass class VidalTEBDExecutor: nqubits: int max_bond: int | None cut_ratio: float | None = 1e-12 tensor_module: str = "torch" def __post_init__(self): self.xp = _backend_module(self.tensor_module) if self.xp is np: self.dtype = np.complex128 self.device = None else: self.dtype = self.xp.complex128 self.device = self.xp.device("cpu") self.gammas = [] for _ in range(self.nqubits): tensor = _asarray(self.xp, [[[1.0 + 0.0j], [0.0 + 0.0j]]], self.dtype) self.gammas.append(tensor) self.lambdas = [ _ones(self.xp, 1, self.dtype, self.device) for _ in range(self.nqubits + 1) ] self._accumulated_truncation_error = 0.0 self._max_truncation_error = 0.0 def run_circuit(self, circuit, compile_circuit=True): gates = circuit.queue if compile_circuit: gates = _route_non_adjacent_gates(gates, circuit.nqubits) gates = _fuse_one_site_blocks(gates) for batch in _disjoint_batches(gates): for gate in batch: self._apply_gate(gate) @property def truncation_error(self): return self._accumulated_truncation_error @property def max_truncation_error(self): return self._max_truncation_error def _apply_gate(self, gate): sites = _gate_sites(gate) matrix = _asarray(self.xp, gate.matrix(), self.dtype) if len(sites) == 1: self.apply_one_site(matrix, sites[0]) elif len(sites) == 2: if abs(sites[0] - sites[1]) != 1: raise NotImplementedError("VidalTEBDExecutor supports adjacent gates only.") self.apply_two_site(matrix, sites[0], sites[1]) else: raise NotImplementedError("Only one- and two-qubit gates are supported.") def apply_one_site(self, op, pos): # op[out_phys, in_phys] * gamma[left, in_phys, right] self.gammas[pos] = self.xp.einsum("st,atb->asb", op, self.gammas[pos]) def apply_two_site(self, op, left_pos, right_pos): item = self._build_two_site_matrix(op, left_pos, right_pos) umat, singvals, vh = _svd(self.xp, item["matrix"]) self._install_two_site_split(item, umat, singvals, vh) def _build_two_site_matrix(self, op, left_pos, right_pos): if left_pos > right_pos: left_pos, right_pos = right_pos, left_pos op = _transpose(self.xp, op.reshape(2, 2, 2, 2), (1, 0, 3, 2)).reshape( 4, 4 ) i = left_pos result = _build_theta_svd_matrix( op, self.xp, self.lambdas[i], self.lambdas[i + 1], self.lambdas[i + 2], self.gammas[i], self.gammas[i + 1], ) result["site"] = i result["lam_left"] = self.lambdas[i] result["lam_right"] = self.lambdas[i + 2] return result def _install_two_site_split(self, item, umat, singvals, vh): update = _make_two_site_update(item, umat, singvals, vh, self.max_bond, self.cut_ratio, self.xp) self._accumulated_truncation_error += update["truncation_error"] self._max_truncation_error = max( self._max_truncation_error, update["truncation_error"], ) i = update["site"] self.gammas[i] = update["left"] self.gammas[i + 1] = update["right"] self.lambdas[i + 1] = update["lambda"] def expectation_ring_xz(self): return self.expectation_pauli_sum( [ (0.5, (("X", site), ("Z", (site + 1) % self.nqubits))) for site in range(self.nqubits) ] ) def expectation_pauli_sum(self, terms): paulis = { "I": _eye(self.xp, 2, self.dtype, self.device), "X": _asarray(self.xp, [[0, 1], [1, 0]], self.dtype), "Y": _asarray(self.xp, [[0, -1j], [1j, 0]], self.dtype), "Z": _asarray(self.xp, [[1, 0], [0, -1]], self.dtype), } operator_terms = [ ( coeff, tuple((site, paulis[name.upper()]) for name, site in ops), ) for coeff, ops in terms ] return self.expectation_operator_sum(operator_terms) def expectation_operator_sum(self, terms): value = 0.0 + 0.0j norm = self.norm() for coeff, ops in terms: operators = { int(site): _asarray(self.xp, matrix, self.dtype) for site, matrix in ops } if len(ops) == 0: term_value = norm elif len(operators) == 1: site, matrix = next(iter(operators.items())) term_value = _to_scalar(self._expect_one_site(site, matrix)) elif len(operators) == 2 and abs(max(operators) - min(operators)) == 1: site0, site1 = sorted(operators) term_value = _to_scalar( self._expect_adjacent(site0, operators[site0], operators[site1]) ) else: term_value = _to_scalar(self.expect_product_operators(operators)) value += complex(coeff) * complex(term_value) return _real_if_close(value / norm) def _expect_one_site(self, site, op): theta = self.xp.einsum( "a,asb,b->asb", self.lambdas[site], self.gammas[site], self.lambdas[site + 1], ) op_theta = self.xp.einsum("us,asb->aub", op, theta) return _vdot(self.xp, theta, op_theta) def _expect_adjacent(self, site, op_left, op_right): theta = self.xp.einsum( "a,asb,b,btc,c->astc", self.lambdas[site], self.gammas[site], self.lambdas[site + 1], self.gammas[site + 1], self.lambdas[site + 2], ) op_theta = self.xp.einsum("us,vt,astc->auvc", op_left, op_right, theta) return _vdot(self.xp, theta, op_theta) def expect_product_operators(self, operators): env = _asarray(self.xp, [[1.0 + 0.0j]], self.dtype) identity = _eye(self.xp, 2, self.dtype, self.device) for site in range(self.nqubits): tensor = self.gammas[site] * self.lambdas[site + 1].reshape(1, 1, -1) op = operators.get(site, identity) env = self.xp.einsum( "xy,xsb,st,ytd->bd", env, _conj(self.xp, tensor), op, tensor ) return env.reshape(-1)[0] def norm(self): return float(np.real(_to_scalar(self.expect_product_operators({})))) def expectation_mpo(self, mpo_tensors): """Compute `` / ``. MPO tensors are expected in ``(left_bond, phys_out, phys_in, right_bond)`` order, with physical dimension 2 on every site. """ if len(mpo_tensors) != self.nqubits: raise ValueError( f"Expected {self.nqubits} MPO tensors, got {len(mpo_tensors)}." ) env = _asarray(self.xp, [[[1.0 + 0.0j]]], self.dtype) for site, raw_mpo in enumerate(mpo_tensors): mpo = _asarray(self.xp, raw_mpo, self.dtype) if mpo.ndim != 4 or mpo.shape[1:3] != (2, 2): raise ValueError( "Each MPO tensor must have shape " "(left_bond, 2, 2, right_bond)." ) tensor = self.gammas[site] * self.lambdas[site + 1].reshape(1, 1, -1) env = self.xp.einsum( "xlc,xub,lutr,ctd->brd", env, _conj(self.xp, tensor), mpo, tensor, ) return _real_if_close(_to_scalar(env.reshape(-1)[0]) / self.norm()) def _build_theta_svd_matrix(op, xp, lam_left, lam_mid, lam_right, gamma_left, gamma_right): """Merge and apply a two-site gate, returning the SVD-ready matrix.""" theta = xp.einsum( "a,asb,b,btc,c->astc", lam_left, gamma_left, lam_mid, gamma_right, lam_right, ) gate = op.reshape(2, 2, 2, 2) theta = xp.einsum("uvst,astc->auvc", gate, theta) chi_left = theta.shape[0] chi_right = theta.shape[3] return { "chi_left": chi_left, "chi_right": chi_right, "matrix": theta.reshape(chi_left * 2, 2 * chi_right), } def _choose_bond(singvals, max_bond, cut_ratio, xp): max_possible = int(singvals.shape[0]) keep = max_possible if max_bond is None else min(max_possible, int(max_bond)) if cut_ratio is not None and cut_ratio > 0 and max_possible > 0: threshold = singvals[0] * cut_ratio if xp is np: ratio_keep = int(np.count_nonzero(singvals > threshold)) else: ratio_keep = int((singvals > threshold).sum().detach().cpu().item()) keep = min(keep, max(1, ratio_keep)) return keep def _divide_left_lambda(tensor, lambdas, xp): if xp is np: safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0) else: safe = xp.where(xp.abs(lambdas) > 1e-300, lambdas, xp.ones_like(lambdas)) return tensor / safe.reshape(-1, 1, 1) def _divide_right_lambda(tensor, lambdas, xp): if xp is np: safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0) else: safe = xp.where(xp.abs(lambdas) > 1e-300, lambdas, xp.ones_like(lambdas)) return tensor / safe.reshape(1, 1, -1) def _make_two_site_update(item, umat, singvals, vh, max_bond, cut_ratio, xp): keep = _choose_bond(singvals, max_bond, cut_ratio, xp) umat = umat[:, :keep] kept = singvals[:keep] cut = singvals[keep:] vh = vh[:keep, :] discarded_weight = 0.0 if cut.shape[0] > 0: norm_kept = (kept * kept).sum() norm_cut = (cut * cut).sum() discarded_weight = float(_to_float(norm_cut)) kept = kept / xp.sqrt(norm_kept / (norm_kept + norm_cut)) new_left = umat.reshape(item["chi_left"], 2, keep) new_right = vh.reshape(keep, 2, item["chi_right"]) new_left = _divide_left_lambda(new_left, item["lam_left"], xp) new_right = _divide_right_lambda(new_right, item["lam_right"], xp) return { "site": item["site"], "left": new_left, "right": new_right, "lambda": kept, "truncation_error": discarded_weight, } def _gate_sites(gate): controls = tuple(getattr(gate, "control_qubits", ())) targets = tuple(getattr(gate, "target_qubits", ())) if controls: return controls + targets return targets # ── SWAP routing for non-adjacent two-qubit gates ────────────────────── class _SWAPGate: """Minimal SWAP gate wrapper for routing non-adjacent gates.""" name = "swap" control_qubits = () def __init__(self, left, right): self.target_qubits = (left, right) def matrix(self): return np.array( [[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=complex, ) class _RoutedTwoQubitGate: """Wraps a two-qubit gate with remapped physical sites after SWAP routing.""" name = "routed_two_qubit" control_qubits = () def __init__(self, original_gate, physical_sites): self.target_qubits = tuple(physical_sites) self._matrix = original_gate.matrix() def matrix(self): return self._matrix def _route_non_adjacent_gates(gates, nqubits): """Insert SWAP networks to make all two-qubit gates adjacent. For each non-adjacent two-qubit gate, inserts SWAP gates to bring the farther qubit adjacent, applies the original gate, then inserts reverse SWAPs to restore the qubit ordering. The resulting gate sequence contains only adjacent two-qubit gates and is safe for VidalTEBDExecutor. """ routed = [] for gate in gates: sites = _gate_sites(gate) if len(sites) <= 1: routed.append(gate) continue left, right = sorted(sites) if right - left == 1: routed.append(gate) continue # Move qubit 'right' leftwards to sit at left+1 for pos in range(right - 1, left, -1): routed.append(_SWAPGate(pos, pos + 1)) # Apply the original gate in its original qubit order. For gates like # CNOT(5, 0), sorting the routed sites would swap control and target. physical_map = {left: left, right: left + 1} routed.append(_RoutedTwoQubitGate(gate, [physical_map[site] for site in sites])) # Reverse SWAPs to restore original ordering for pos in range(left + 1, right): routed.append(_SWAPGate(pos, pos + 1)) return routed def _disjoint_batches(gates): batches = [] current = [] touched = set() current_arity = None for gate in gates: sites = _gate_sites(gate) arity = len(sites) site_set = set(sites) if current and (current_arity != arity or touched & site_set): batches.append(current) current = [] touched = set() current_arity = None current.append(gate) touched |= site_set current_arity = arity if current: batches.append(current) return batches def _is_two_qubit_batch(batch): return batch and all(len(_gate_sites(gate)) == 2 for gate in batch) class _FusedOneSiteGate: name = "fused_one_site" def __init__(self, site, matrix): self.target_qubits = (site,) self.control_qubits = () self._matrix = matrix def matrix(self): return self._matrix def _fuse_one_site_blocks(gates): fused = [] block = [] def flush_block(): nonlocal block if not block: return per_site = {} for gate in block: site = _gate_sites(gate)[0] mat = gate.matrix() if site in per_site: per_site[site] = mat @ per_site[site] else: per_site[site] = mat for site in sorted(per_site): fused.append(_FusedOneSiteGate(site, per_site[site])) block = [] for gate in gates: if len(_gate_sites(gate)) == 1: block.append(gate) continue flush_block() fused.append(gate) flush_block() return fused def run_vidal_ring_xz( circuit, max_bond, cut_ratio=1e-12, tensor_module="torch", ): executor = VidalTEBDExecutor( nqubits=circuit.nqubits, max_bond=max_bond, cut_ratio=cut_ratio, tensor_module=tensor_module, ) executor.run_circuit(circuit) return executor.expectation_ring_xz()