完善mps的vidal机制,多节点并行;补充tn搜索时dask集群搜索的方式
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled

This commit is contained in:
2026-05-12 15:44:19 +08:00
parent aa122964b4
commit 72f95599bb
32 changed files with 3529 additions and 320 deletions

View File

@@ -62,6 +62,41 @@ def _to_float(x):
return float(x)
def _to_numpy(tensor):
if hasattr(tensor, "detach"):
return tensor.detach().cpu().numpy()
return np.asarray(tensor)
def _tensor_update_to_numpy(update):
result = {
"site": int(update["site"]),
"left": _to_numpy(update["left"]),
"right": _to_numpy(update["right"]),
"lambda": _to_numpy(update["lambda"]),
}
if "truncation_error" in update:
result["truncation_error"] = float(update["truncation_error"])
return result
def _tensor_update_from_numpy(xp, update, dtype):
if xp is np:
return update
result = {
"site": update["site"],
"left": _asarray(xp, update["left"], dtype),
"right": _asarray(xp, update["right"], dtype),
"lambda": xp.as_tensor(
update["lambda"],
dtype=xp.float64 if dtype == xp.complex128 else xp.float32,
),
}
if "truncation_error" in update:
result["truncation_error"] = float(update["truncation_error"])
return result
def _svd(xp, matrix):
return _svd_eigh(xp, matrix)
@@ -92,32 +127,6 @@ def _svd_eigh(xp, matrix):
return umat, singvals, _conj(xp, eigvecs).T
def _batched_svd_eigh(xp, matrices):
"""Batched EVD SVD for a stack of same-shape torch matrices."""
m_dim, n_dim = matrices.shape[-2:]
if m_dim <= n_dim:
grams = matrices @ _conj(xp, matrices).transpose(-1, -2)
eigvals, eigvecs = xp.linalg.eigh(grams)
idx = xp.argsort(eigvals, dim=-1, descending=True)
eigvals = xp.gather(eigvals, -1, idx)
eigvecs = xp.gather(eigvecs, -1, idx.unsqueeze(-2).expand_as(eigvecs))
singvals = _sqrt_clamped(xp, eigvals)
inv_s = _safe_inverse(xp, singvals)
vhs = (eigvecs.conj().transpose(-1, -2) @ matrices) * inv_s.unsqueeze(-1)
return eigvecs, singvals, vhs
grams = _conj(xp, matrices).transpose(-1, -2) @ matrices
eigvals, eigvecs = xp.linalg.eigh(grams)
idx = xp.argsort(eigvals, dim=-1, descending=True)
eigvals = xp.gather(eigvals, -1, idx)
eigvecs = xp.gather(eigvecs, -1, idx.unsqueeze(-2).expand_as(eigvecs))
singvals = _sqrt_clamped(xp, eigvals)
inv_s = _safe_inverse(xp, singvals)
umats = (matrices @ eigvecs) * inv_s.unsqueeze(-2)
return umats, singvals, eigvecs.conj().transpose(-1, -2)
def _eigh(xp, matrix):
if xp is np:
return np.linalg.eigh(matrix)
@@ -126,10 +135,8 @@ def _eigh(xp, matrix):
def _sort_eigh_desc(xp, eigvals, eigvecs):
if xp is np:
idx = np.argsort(eigvals)[::-1].copy()
return eigvals[idx], eigvecs[:, idx]
idx = xp.argsort(eigvals, descending=True)
return eigvals[idx], eigvecs[:, idx]
return eigvals[::-1].copy(), eigvecs[:, ::-1].copy()
return xp.flip(eigvals, dims=(0,)), xp.flip(eigvecs, dims=(1,))
def _sqrt_clamped(xp, eigvals):
@@ -150,8 +157,6 @@ class VidalTEBDExecutor:
max_bond: int
cut_ratio: float = 1e-12
tensor_module: str = "torch"
workers: int = 1
use_batched: bool = False
def __post_init__(self):
self.xp = _backend_module(self.tensor_module)
@@ -169,19 +174,20 @@ class VidalTEBDExecutor:
self.lambdas = [
_ones(self.xp, 1, self.dtype, self.device) for _ in range(self.nqubits + 1)
]
self._accumulated_truncation_error = 0.0
def run_circuit(self, circuit):
for batch in _disjoint_batches(circuit.queue):
if (
self.use_batched
and _is_two_qubit_batch(batch)
and self.workers > 1
and len(batch) > 1
):
self.apply_two_site_batch(batch)
else:
for gate in batch:
self._apply_gate(gate)
def run_circuit(self, circuit, compile_circuit=True):
gates = circuit.queue
if compile_circuit:
gates = _route_non_adjacent_gates(gates, circuit.nqubits)
gates = _fuse_one_site_blocks(gates)
for batch in _disjoint_batches(gates):
for gate in batch:
self._apply_gate(gate)
@property
def truncation_error(self):
return self._accumulated_truncation_error
def _apply_gate(self, gate):
sites = _gate_sites(gate)
@@ -204,40 +210,6 @@ class VidalTEBDExecutor:
umat, singvals, vh = _svd(self.xp, item["matrix"])
self._install_two_site_split(item, umat, singvals, vh)
def apply_two_site_batch(self, batch):
items = []
for gate in batch:
sites = _gate_sites(gate)
op = _asarray(self.xp, gate.matrix(), self.dtype)
items.append(self._build_two_site_matrix(op, sites[0], sites[1]))
if self.xp is not np:
grouped = {}
for idx, item in enumerate(items):
grouped.setdefault(tuple(item["matrix"].shape), []).append(idx)
for indices in grouped.values():
if len(indices) == 1:
item = items[indices[0]]
umat, singvals, vh = _svd(self.xp, item["matrix"])
item["split"] = (umat, singvals, vh)
continue
matrices = self.xp.stack([items[idx]["matrix"] for idx in indices])
umats, singvals, vhs = _batched_svd_eigh(self.xp, matrices)
for out_idx, item_idx in enumerate(indices):
items[item_idx]["split"] = (
umats[out_idx],
singvals[out_idx],
vhs[out_idx],
)
else:
for item in items:
item["split"] = _svd(self.xp, item["matrix"])
for item in items:
self._install_two_site_split(item, *item["split"])
def _build_two_site_matrix(self, op, left_pos, right_pos):
if left_pos > right_pos:
left_pos, right_pos = right_pos, left_pos
@@ -246,101 +218,69 @@ class VidalTEBDExecutor:
)
i = left_pos
lam_left = self.lambdas[i]
lam_mid = self.lambdas[i + 1]
lam_right = self.lambdas[i + 2]
gamma_left = self.gammas[i]
gamma_right = self.gammas[i + 1]
theta = self.xp.einsum(
"a,asb,b,btc,c->astc",
lam_left,
gamma_left,
lam_mid,
gamma_right,
lam_right,
result = _build_theta_svd_matrix(
op, self.xp,
self.lambdas[i], self.lambdas[i + 1], self.lambdas[i + 2],
self.gammas[i], self.gammas[i + 1],
)
gate = op.reshape(2, 2, 2, 2)
theta = self.xp.einsum("uvst,astc->auvc", gate, theta)
chi_left = theta.shape[0]
chi_right = theta.shape[3]
matrix = theta.reshape(chi_left * 2, 2 * chi_right)
return {
"site": i,
"chi_left": chi_left,
"chi_right": chi_right,
"lam_left": lam_left,
"lam_right": lam_right,
"matrix": matrix,
}
result["site"] = i
result["lam_left"] = self.lambdas[i]
result["lam_right"] = self.lambdas[i + 2]
return result
def _install_two_site_split(self, item, umat, singvals, vh):
keep = self._choose_bond(singvals)
umat = umat[:, :keep]
kept = singvals[:keep]
cut = singvals[keep:]
vh = vh[:keep, :]
if cut.shape[0] > 0:
norm_kept = (kept * kept).sum()
norm_cut = (cut * cut).sum()
kept = kept / self.xp.sqrt(norm_kept / (norm_kept + norm_cut))
new_left = umat.reshape(item["chi_left"], 2, keep)
new_right = vh.reshape(keep, 2, item["chi_right"])
new_left = self._divide_left_lambda(new_left, item["lam_left"])
new_right = self._divide_right_lambda(new_right, item["lam_right"])
i = item["site"]
self.gammas[i] = new_left
self.gammas[i + 1] = new_right
self.lambdas[i + 1] = kept
def _choose_bond(self, singvals):
max_possible = int(singvals.shape[0])
keep = min(max_possible, self.max_bond)
if self.cut_ratio > 0 and max_possible > 0:
threshold = singvals[0] * self.cut_ratio
if self.xp is np:
ratio_keep = int(np.count_nonzero(singvals > threshold))
else:
ratio_keep = int((singvals > threshold).sum().detach().cpu().item())
keep = min(keep, max(1, ratio_keep))
return keep
def _divide_left_lambda(self, tensor, lambdas):
if self.xp is np:
safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
else:
safe = self.xp.where(
self.xp.abs(lambdas) > 1e-300,
lambdas,
self.xp.ones_like(lambdas),
)
return tensor / safe.reshape(-1, 1, 1)
def _divide_right_lambda(self, tensor, lambdas):
if self.xp is np:
safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
else:
safe = self.xp.where(
self.xp.abs(lambdas) > 1e-300,
lambdas,
self.xp.ones_like(lambdas),
)
return tensor / safe.reshape(1, 1, -1)
update = _make_two_site_update(item, umat, singvals, vh,
self.max_bond, self.cut_ratio, self.xp)
self._accumulated_truncation_error += update["truncation_error"]
i = update["site"]
self.gammas[i] = update["left"]
self.gammas[i + 1] = update["right"]
self.lambdas[i + 1] = update["lambda"]
def expectation_ring_xz(self):
xmat = _asarray(self.xp, [[0, 1], [1, 0]], self.dtype)
zmat = _asarray(self.xp, [[1, 0], [0, -1]], self.dtype)
value = 0.0
for site in range(self.nqubits - 1):
value += 0.5 * _to_float(self._expect_adjacent(site, xmat, zmat))
value += 0.5 * _to_float(
self.expect_product_operators({self.nqubits - 1: xmat, 0: zmat})
return self.expectation_pauli_sum(
[
(0.5, (("X", site), ("Z", (site + 1) % self.nqubits)))
for site in range(self.nqubits)
]
)
return value / self.norm()
def expectation_pauli_sum(self, terms):
paulis = {
"I": _eye(self.xp, 2, self.dtype, self.device),
"X": _asarray(self.xp, [[0, 1], [1, 0]], self.dtype),
"Y": _asarray(self.xp, [[0, -1j], [1j, 0]], self.dtype),
"Z": _asarray(self.xp, [[1, 0], [0, -1]], self.dtype),
}
value = 0.0
norm = self.norm()
for coeff, ops in terms:
ops = tuple((name.upper(), site) for name, site in ops)
if len(ops) == 0:
term_value = norm
elif len(ops) == 1:
name, site = ops[0]
term_value = _to_float(self._expect_one_site(site, paulis[name]))
elif len(ops) == 2 and abs(ops[0][1] - ops[1][1]) == 1:
(name0, site0), (name1, site1) = sorted(ops, key=lambda item: item[1])
term_value = _to_float(
self._expect_adjacent(site0, paulis[name0], paulis[name1])
)
else:
operators = {site: paulis[name] for name, site in ops}
term_value = _to_float(self.expect_product_operators(operators))
value += float(np.real(coeff)) * term_value
return value / norm
def _expect_one_site(self, site, op):
theta = self.xp.einsum(
"a,asb,b->asb",
self.lambdas[site],
self.gammas[site],
self.lambdas[site + 1],
)
op_theta = self.xp.einsum("us,asb->aub", op, theta)
return _vdot_real(self.xp, theta, op_theta)
def _expect_adjacent(self, site, op_left, op_right):
theta = self.xp.einsum(
@@ -369,6 +309,79 @@ class VidalTEBDExecutor:
return _to_float(self.expect_product_operators({}))
def _build_theta_svd_matrix(op, xp, lam_left, lam_mid, lam_right, gamma_left, gamma_right):
"""Merge and apply a two-site gate, returning the SVD-ready matrix."""
theta = xp.einsum(
"a,asb,b,btc,c->astc",
lam_left, gamma_left, lam_mid, gamma_right, lam_right,
)
gate = op.reshape(2, 2, 2, 2)
theta = xp.einsum("uvst,astc->auvc", gate, theta)
chi_left = theta.shape[0]
chi_right = theta.shape[3]
return {
"chi_left": chi_left,
"chi_right": chi_right,
"matrix": theta.reshape(chi_left * 2, 2 * chi_right),
}
def _choose_bond(singvals, max_bond, cut_ratio, xp):
max_possible = int(singvals.shape[0])
keep = min(max_possible, max_bond)
if cut_ratio > 0 and max_possible > 0:
threshold = singvals[0] * cut_ratio
if xp is np:
ratio_keep = int(np.count_nonzero(singvals > threshold))
else:
ratio_keep = int((singvals > threshold).sum().detach().cpu().item())
keep = min(keep, max(1, ratio_keep))
return keep
def _divide_left_lambda(tensor, lambdas, xp):
if xp is np:
safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
else:
safe = xp.where(xp.abs(lambdas) > 1e-300, lambdas, xp.ones_like(lambdas))
return tensor / safe.reshape(-1, 1, 1)
def _divide_right_lambda(tensor, lambdas, xp):
if xp is np:
safe = np.where(np.abs(lambdas) > 1e-300, lambdas, 1.0)
else:
safe = xp.where(xp.abs(lambdas) > 1e-300, lambdas, xp.ones_like(lambdas))
return tensor / safe.reshape(1, 1, -1)
def _make_two_site_update(item, umat, singvals, vh, max_bond, cut_ratio, xp):
keep = _choose_bond(singvals, max_bond, cut_ratio, xp)
umat = umat[:, :keep]
kept = singvals[:keep]
cut = singvals[keep:]
vh = vh[:keep, :]
discarded_weight = 0.0
if cut.shape[0] > 0:
norm_kept = (kept * kept).sum()
norm_cut = (cut * cut).sum()
discarded_weight = float(_to_float(norm_cut))
kept = kept / xp.sqrt(norm_kept / (norm_kept + norm_cut))
new_left = umat.reshape(item["chi_left"], 2, keep)
new_right = vh.reshape(keep, 2, item["chi_right"])
new_left = _divide_left_lambda(new_left, item["lam_left"], xp)
new_right = _divide_right_lambda(new_right, item["lam_right"], xp)
return {
"site": item["site"],
"left": new_left,
"right": new_right,
"lambda": kept,
"truncation_error": discarded_weight,
}
def _gate_sites(gate):
controls = tuple(getattr(gate, "control_qubits", ()))
targets = tuple(getattr(gate, "target_qubits", ()))
@@ -377,19 +390,90 @@ def _gate_sites(gate):
return targets
# ── SWAP routing for non-adjacent two-qubit gates ──────────────────────
class _SWAPGate:
"""Minimal SWAP gate wrapper for routing non-adjacent gates."""
name = "swap"
control_qubits = ()
def __init__(self, left, right):
self.target_qubits = (left, right)
def matrix(self):
return np.array(
[[1, 0, 0, 0],
[0, 0, 1, 0],
[0, 1, 0, 0],
[0, 0, 0, 1]],
dtype=complex,
)
class _RoutedTwoQubitGate:
"""Wraps a two-qubit gate with remapped physical sites after SWAP routing."""
name = "routed_two_qubit"
control_qubits = ()
def __init__(self, original_gate, left_phys, right_phys):
self.target_qubits = (left_phys, right_phys)
self._matrix = original_gate.matrix()
def matrix(self):
return self._matrix
def _route_non_adjacent_gates(gates, nqubits):
"""Insert SWAP networks to make all two-qubit gates adjacent.
For each non-adjacent two-qubit gate, inserts SWAP gates to bring the
farther qubit adjacent, applies the original gate, then inserts reverse
SWAPs to restore the qubit ordering. The resulting gate sequence
contains only adjacent two-qubit gates and is safe for VidalTEBDExecutor.
"""
routed = []
for gate in gates:
sites = _gate_sites(gate)
if len(sites) <= 1:
routed.append(gate)
continue
left, right = sorted(sites)
if right - left == 1:
routed.append(gate)
continue
# Move qubit 'right' leftwards to sit at left+1
for pos in range(right - 1, left, -1):
routed.append(_SWAPGate(pos, pos + 1))
# Apply the original gate at the adjacent physical positions
routed.append(_RoutedTwoQubitGate(gate, left, left + 1))
# Reverse SWAPs to restore original ordering
for pos in range(left + 1, right):
routed.append(_SWAPGate(pos, pos + 1))
return routed
def _disjoint_batches(gates):
batches = []
current = []
touched = set()
current_arity = None
for gate in gates:
sites = _gate_sites(gate)
arity = len(sites)
site_set = set(sites)
if current and touched & site_set:
if current and (current_arity != arity or touched & site_set):
batches.append(current)
current = []
touched = set()
current_arity = None
current.append(gate)
touched |= site_set
current_arity = arity
if current:
batches.append(current)
return batches
@@ -399,21 +483,59 @@ def _is_two_qubit_batch(batch):
return batch and all(len(_gate_sites(gate)) == 2 for gate in batch)
class _FusedOneSiteGate:
name = "fused_one_site"
def __init__(self, site, matrix):
self.target_qubits = (site,)
self.control_qubits = ()
self._matrix = matrix
def matrix(self):
return self._matrix
def _fuse_one_site_blocks(gates):
fused = []
block = []
def flush_block():
nonlocal block
if not block:
return
per_site = {}
for gate in block:
site = _gate_sites(gate)[0]
mat = gate.matrix()
if site in per_site:
per_site[site] = mat @ per_site[site]
else:
per_site[site] = mat
for site in sorted(per_site):
fused.append(_FusedOneSiteGate(site, per_site[site]))
block = []
for gate in gates:
if len(_gate_sites(gate)) == 1:
block.append(gate)
continue
flush_block()
fused.append(gate)
flush_block()
return fused
def run_vidal_ring_xz(
circuit,
max_bond,
cut_ratio=1e-12,
tensor_module="torch",
workers=1,
use_batched=False,
):
executor = VidalTEBDExecutor(
nqubits=circuit.nqubits,
max_bond=max_bond,
cut_ratio=cut_ratio,
tensor_module=tensor_module,
workers=workers,
use_batched=use_batched,
)
executor.run_circuit(circuit)
return executor.expectation_ring_xz()