Files
jaunatisblue 28080dff1d
Some checks failed
Build wheels / build (ubuntu-latest, 3.11) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.12) (push) Has been cancelled
Build wheels / build (ubuntu-latest, 3.13) (push) Has been cancelled
Tests / check (push) Has been cancelled
Tests / build (ubuntu-latest, 3.11) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.12) (push) Has been cancelled
Tests / build (ubuntu-latest, 3.13) (push) Has been cancelled
简化代码;加入.venv下内容
2026-05-18 02:47:40 +08:00

1289 lines
38 KiB
Python

"""Functionality relating to actually contracting."""
import functools
import itertools
import operator
import contextlib
import os
from autoray import do, shape, infer_backend_multi, get_lib_fn
from .utils import node_from_single
DEFAULT_IMPLEMENTATION = "auto"
def _torch_workspace_enabled():
    """Report whether the torch workspace pool has been switched on.

    Reads the ``QIBOTN_TORCH_WORKSPACE`` environment variable (default
    ``"0"``), strips and lower-cases it, and checks it against a fixed set of
    truthy spellings.
    """
    truthy = {"1", "true", "yes", "on", "enable", "enabled"}
    raw = os.environ.get("QIBOTN_TORCH_WORKSPACE", "0")
    return raw.strip().lower() in truthy
def _torch_arena_enabled():
    """Report whether the torch arena allocator has been switched on.

    Reads the ``QIBOTN_TORCH_ARENA`` environment variable (default ``"0"``),
    strips and lower-cases it, and checks it against a fixed set of truthy
    spellings.
    """
    truthy = {"1", "true", "yes", "on", "enable", "enabled"}
    raw = os.environ.get("QIBOTN_TORCH_ARENA", "0")
    return raw.strip().lower() in truthy
def _parse_size_bytes(value, default):
    """Parse a human readable memory size into a whole number of bytes.

    Parameters
    ----------
    value : None, int, float or str
        The size to parse, e.g. ``"70GiB"``, ``"512mb"``, ``"4k"`` or
        ``1024``. Matching is case-insensitive. Suffixes ending in ``b``
        without an ``i`` (``kb``, ``mb``, ``gb``, ``tb``) are decimal, all
        other unit suffixes (``k``, ``ki``, ``kib``, ...) are binary. A bare
        number or a trailing ``b`` means bytes.
    default : int
        Returned unchanged when ``value`` is ``None`` or blank.

    Returns
    -------
    int
        The size in bytes, truncated towards zero.

    Raises
    ------
    ValueError
        If the numeric part cannot be parsed as a float.
    """
    if value is None:
        return default
    text = str(value).strip().lower()
    if not text:
        # e.g. an environment variable that is set but empty
        return default

    # longer suffixes first so that e.g. "gib" is not mistaken for "b"
    suffixes = (
        ("tib", 1024**4),
        ("tb", 1000**4),
        ("ti", 1024**4),
        ("t", 1024**4),
        ("gib", 1024**3),
        ("gb", 1000**3),
        ("gi", 1024**3),
        ("g", 1024**3),
        ("mib", 1024**2),
        ("mb", 1000**2),
        ("mi", 1024**2),
        ("m", 1024**2),
        ("kib", 1024),
        ("kb", 1000),
        ("ki", 1024),
        ("k", 1024),
        ("b", 1),
    )
    scale = 1
    for suffix, multiplier in suffixes:
        if text.endswith(suffix):
            text = text[: -len(suffix)].strip()
            scale = multiplier
            break
    return int(float(text) * scale)
class _TorchArena:
    """Simple first-fit arena for torch CPU contraction intermediates.

    A single flat buffer is lazily allocated on the first request, using the
    dtype and device of that request. Blocks are tracked in units of
    *elements*: ``free`` is a list of ``(offset, length)`` holes and
    ``allocated`` maps the offset of each live block to its element count.
    Requests that cannot be served (different dtype/device, no hole large
    enough, zero elements) fall back to a normal ``torch.empty``.

    Parameters
    ----------
    size_bytes : int or str, optional
        Capacity of the arena. If not given (or falsy), the
        ``QIBOTN_TORCH_ARENA_BYTES`` environment variable is used, and
        failing that 70 GiB.
    """

    def __init__(self, size_bytes=None):
        self.size_bytes = _parse_size_bytes(
            size_bytes or os.environ.get("QIBOTN_TORCH_ARENA_BYTES"),
            70 * 1024**3,
        )
        # the backing 1D tensor, created on first use by ``_setup``
        self.buffer = None
        self.dtype = None
        self.device = None
        # holes as (offset, length), in elements
        self.free = []
        # live blocks: offset -> number of elements
        self.allocated = {}

    def _setup(self, dtype, device):
        """Allocate the backing buffer for ``dtype`` on ``device``.

        Raises
        ------
        MemoryError
            If the configured size cannot hold a single element.
        """
        import torch

        element_size = torch.empty(
            (), dtype=dtype, device=device
        ).element_size()
        nelements = self.size_bytes // element_size
        if nelements <= 0:
            raise MemoryError("QIBOTN_TORCH_ARENA_BYTES is too small.")
        self.buffer = torch.empty(nelements, dtype=dtype, device=device)
        self.dtype = dtype
        self.device = device
        self.free = [(0, nelements)]
        self.allocated = {}

    def alloc(self, shape, dtype, device):
        """Return an uninitialized tensor of ``shape``, from the arena if
        possible, otherwise freshly allocated with ``torch.empty``.
        """
        import torch

        numel = functools.reduce(operator.mul, shape, 1)
        if numel == 0:
            # An empty block would share its offset with the next real
            # allocation, and releasing it would then free that live block.
            # Empty tensors need no storage, so keep them out of the arena.
            return torch.empty(shape, dtype=dtype, device=device)

        if self.buffer is None:
            self._setup(dtype, device)
        elif (dtype != self.dtype) or (device != self.device):
            # the arena only serves the dtype/device it was created with
            return torch.empty(shape, dtype=dtype, device=device)

        element_size = self.buffer.element_size()
        # align block starts to 64 bytes relative to the buffer start
        align = max(1, 64 // element_size)

        for i, (offset, length) in enumerate(self.free):
            aligned = ((offset + align - 1) // align) * align
            padding = aligned - offset
            if length - padding < numel:
                continue

            # split the hole into (leading padding, block, trailing rest)
            new_blocks = []
            if padding:
                new_blocks.append((offset, padding))
            tail_offset = aligned + numel
            tail_length = (offset + length) - tail_offset
            if tail_length:
                new_blocks.append((tail_offset, tail_length))
            self.free[i : i + 1] = new_blocks
            self.allocated[aligned] = numel
            return self.buffer.narrow(0, aligned, numel).view(shape)

        # no hole is large enough
        return torch.empty(shape, dtype=dtype, device=device)

    def release(self, x):
        """Return the block backing ``x`` to the free list.

        ``x`` is matched by storage pointer and storage offset, so any view
        starting at the beginning of a block releases it. Objects that do not
        live in the arena are silently ignored.
        """
        if self.buffer is None:
            return
        if not hasattr(x, "untyped_storage"):
            return
        try:
            own_ptr = self.buffer.untyped_storage().data_ptr()
            if x.untyped_storage().data_ptr() != own_ptr:
                return
            offset = int(x.storage_offset())
        except Exception:
            # best effort: anything we cannot inspect is not ours to free
            return

        length = self.allocated.pop(offset, None)
        if length is None:
            return

        self.free.append((offset, length))
        self.free.sort()

        # coalesce adjacent holes
        merged = []
        for block_offset, block_length in self.free:
            if merged and merged[-1][0] + merged[-1][1] == block_offset:
                prev_offset, prev_length = merged[-1]
                merged[-1] = (prev_offset, prev_length + block_length)
            else:
                merged.append((block_offset, block_length))
        self.free = merged
def set_default_implementation(impl):
    """Set the module-wide default contraction implementation to ``impl``."""
    global DEFAULT_IMPLEMENTATION
    DEFAULT_IMPLEMENTATION = impl
def get_default_implementation():
    """Return the current module-wide default contraction implementation."""
    return DEFAULT_IMPLEMENTATION
@contextlib.contextmanager
def default_implementation(impl):
    """Temporarily use ``impl`` as the default implementation.

    The previous default is restored when the ``with`` block exits, whether
    normally or through an exception.
    """
    global DEFAULT_IMPLEMENTATION
    previous, DEFAULT_IMPLEMENTATION = DEFAULT_IMPLEMENTATION, impl
    try:
        yield
    finally:
        DEFAULT_IMPLEMENTATION = previous
@functools.lru_cache(2**12)
def _sanitize_equation(eq):
    """Split an einsum equation into its input and output parts.

    Spaces are removed first. If no ``->`` is present the output is implied:
    every index that occurs exactly once across the inputs, in sorted order.

    Returns
    -------
    lhs : str
        The comma separated input terms.
    out : str
        The output indices.

    Raises
    ------
    NotImplementedError
        If the equation contains an ellipsis.
    """
    eq = eq.replace(" ", "")

    if "..." in eq:
        raise NotImplementedError("Ellipsis not supported.")

    if "->" in eq:
        lhs, out = eq.split("->")
        return lhs, out

    # implicit output: sorted indices that appear exactly once
    subscripts = eq.replace(",", "")
    singles = [
        ix for ix in sorted(set(subscripts)) if subscripts.count(ix) == 1
    ]
    return eq, "".join(singles)
@functools.lru_cache(2**12)
def _parse_einsum_single(eq, shape):
    """Cached parsing of a single term einsum equation into the arguments
    needed for taking diagonals, summing axes and transposing.

    Returns
    -------
    diag_sels : list[tuple] or None
        Advanced indexing selectors to apply in order, one per repeated
        index, or ``None`` if no index is repeated.
    sum_axes : tuple[int] or None
        Axes to sum over after the diagonals have been taken.
    perm : tuple[int] or None
        Final permutation, or ``None`` if the order is already correct.
    """
    lhs, out = _sanitize_equation(eq)

    # classify every index of the input term
    repeated = []
    to_sum = []
    visited = set()
    for ix in lhs:
        if ix in repeated:
            continue
        if ix in visited:
            repeated.append(ix)
        else:
            visited.add(ix)
            if ix not in out:
                to_sum.append(ix)

    # diagonal reductions come first, last-found repeated index first
    diag_sels = None
    if repeated:
        diag_sels = []
        sizes = dict(zip(lhs, shape))
        for ixd in reversed(repeated):
            positions = tuple(range(sizes[ixd]))
            # advanced index on every occurrence, full slice elsewhere
            diag_sels.append(
                tuple(positions if ix == ixd else slice(None) for ix in lhs)
            )
            # work out where the merged axis ends up
            run = ixd * lhs.count(ixd)
            if run in lhs:
                # occurrences were adjacent: new axis stays in place
                lhs = lhs.replace(run, ixd)
            else:
                # occurrences were spread out: new axis moves to the front
                lhs = ixd + lhs.replace(ixd, "")

    # then summed axes
    sum_axes = None
    if to_sum:
        sum_axes = tuple(lhs.index(ix) for ix in to_sum)
        lhs = "".join(ix for ix in lhs if ix not in to_sum)

    # finally the transposition
    perm = None if lhs == out else tuple(lhs.index(ix) for ix in out)

    return diag_sels, sum_axes, perm
def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
    """Prepare a contraction without contracted indices as a broadcast
    elementwise multiplication.

    Each operand is transposed to follow the order of ``out`` and reshaped
    with size 1 dimensions wherever it lacks an output index. Not cached
    itself since it is only called from the cached
    ``_parse_eq_to_batch_matmul``.

    Returns
    -------
    tuple
        ``(eq_a, eq_b, new_shape_a, new_shape_b, None, None, True)`` where the
        two ``None`` entries stand for the unused output reshape and
        permutation, and the final ``True`` flags pure multiplication.
    """

    def _align(term, term_shape):
        # indices of ``out`` carried by this term, plus the broadcast shape
        kept = []
        broadcast_shape = []
        for ix in out:
            if ix in term:
                kept.append(ix)
                broadcast_shape.append(term_shape[term.index(ix)])
            else:
                broadcast_shape.append(1)
        desired = "".join(kept)
        eq = None if desired == term else f"{term}->{desired}"
        return eq, broadcast_shape

    eq_a, new_shape_a = _align(a_term, shape_a)
    eq_b, new_shape_b = _align(b_term, shape_b)

    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        None,  # new_shape_ab, not needed since not fusing
        None,  # perm_ab, not needed as we transpose a and b first
        True,  # pure_multiplication=True
    )
@functools.lru_cache(2**12)
def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
    """Cached parsing of a two term einsum equation into the necessary
    sequence of arguments for contraction via batched matrix multiplication.
    The steps we need to specify are:

        1. Remove repeated and trivial indices from the left and right terms,
           and transpose them, done as a single einsum.
        2. Fuse the remaining indices so we have two 3D tensors.
        3. Perform the batched matrix multiplication.
        4. Unfuse the output to get the desired final index order.

    Parameters
    ----------
    eq : str
        Explicit two term equation, i.e. of the form ``"a_term,b_term->out"``.
    shape_a, shape_b : tuple[int]
        Shapes of the two operands.

    Returns
    -------
    tuple
        ``(eq_a, eq_b, new_shape_a, new_shape_b, new_shape_ab, perm_ab,
        pure_multiplication)``. ``eq_a`` / ``eq_b`` are ``None`` (nothing to
        do), a tuple (plain transposition) or a single term einsum string.

    Raises
    ------
    ValueError
        If a term does not match the rank of its shape, or an index appears
        with two different non-unit sizes.
    """
    lhs, out = eq.split("->")
    a_term, b_term = lhs.split(",")

    if len(a_term) != len(shape_a):
        raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
    if len(b_term) != len(shape_b):
        raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")

    bat_inds = []  # appears on A, B, O
    con_inds = []  # appears on A, B, .
    a_keep = []  # appears on A, ., O
    b_keep = []  # appears on ., B, O
    sizes = {}
    singletons = set()

    # parse left term
    seen = set()
    for ix, d in zip(a_term, shape_a):
        if d == 1:
            # everything (including broadcasting) works nicely if simply ignore
            # such dimensions, but we do need to track if they appear in output
            # and thus should be reintroduced later
            singletons.add(ix)
            continue

        # set or check size
        if sizes.setdefault(ix, d) != d:
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )

        if ix in seen:
            continue
        seen.add(ix)

        if ix in b_term:
            if ix in out:
                bat_inds.append(ix)
            else:
                con_inds.append(ix)
        elif ix in out:
            a_keep.append(ix)

    # parse right term
    seen.clear()
    for ix, d in zip(b_term, shape_b):
        if d == 1:
            singletons.add(ix)
            continue

        # broadcast indices don't appear as singletons in output
        singletons.discard(ix)

        # set or check size
        if sizes.setdefault(ix, d) != d:
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )

        if ix in seen:
            continue
        seen.add(ix)

        if ix not in a_term:
            if ix in out:
                b_keep.append(ix)

    if not con_inds:
        # contraction is pure multiplication, prepare inputs differently
        return _parse_eq_to_pure_multiplication(
            a_term, shape_a, b_term, shape_b, out
        )

    # only need the size one indices that appear in the output
    singletons = [ix for ix in out if ix in singletons]

    def _preparation(term, desired):
        # how to bring ``term`` into the ``desired`` index order
        if term == desired:
            return None
        if len(term) == len(desired) and set(term) == set(desired):
            # a true permutation: only need to transpose, don't invoke
            # einsum. The length check matters: a term with a repeated index
            # (e.g. 'iij' -> 'ij') has the same index *set* but needs a
            # diagonal, which only the einsum route provides.
            return tuple(term.index(ix) for ix in desired)
        return f"{term}->{desired}"

    # take diagonal, remove any trivial axes and transpose left
    eq_a = _preparation(a_term, "".join((*bat_inds, *a_keep, *con_inds)))
    # take diagonal, remove any trivial axes and transpose right
    eq_b = _preparation(b_term, "".join((*bat_inds, *con_inds, *b_keep)))

    # then we want to reshape
    if bat_inds:
        lgroups = (bat_inds, a_keep, con_inds)
        rgroups = (bat_inds, con_inds, b_keep)
        ogroups = (bat_inds, a_keep, b_keep)
    else:
        # avoid size 1 batch dimension if no batch indices
        lgroups = (a_keep, con_inds)
        rgroups = (con_inds, b_keep)
        ogroups = (a_keep, b_keep)

    def _fused_shape(groups):
        # product of the sizes within each group of indices
        return tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in group), 1)
            for group in groups
        )

    if any(len(group) != 1 for group in lgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_a = _fused_shape(lgroups)
    else:
        new_shape_a = None

    if any(len(group) != 1 for group in rgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_b = _fused_shape(rgroups)
    else:
        new_shape_b = None

    if any(len(group) != 1 for group in ogroups) or singletons:
        new_shape_ab = (1,) * len(singletons) + tuple(
            sizes[ix] for ix_group in ogroups for ix in ix_group
        )
    else:
        new_shape_ab = None

    # then we want to permute the matmul produced output:
    out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
    perm_ab = tuple(out_produced.index(ix) for ix in out)
    if perm_ab == tuple(range(len(perm_ab))):
        perm_ab = None

    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        False,  # pure_multiplication=False
    )
def _einsum_single(eq, x, backend=None):
    """Einsum on a single tensor.

    The backend's own ``einsum`` is tried first. If requesting it raises
    ``ImportError``, the operation is instead performed in up to three steps:
    diagonal selection (via advanced indexing), axes summation and
    transposition. The plan is cached on the equation and array shape and each
    step only runs if needed.
    """
    try:
        return do("einsum", eq, x, like=backend)
    except ImportError:
        # fall through to the manual implementation below
        pass

    diag_sels, sum_axes, perm = _parse_einsum_single(eq, shape(x))

    # diagonal reduction via advanced indexing, e.g. ababbac->abc
    for selector in diag_sels or ():
        x = x[selector]

    # trivial removal of axes via summation, e.g. abc->c
    if sum_axes is not None:
        x = do("sum", x, sum_axes, like=backend)

    # transpose to desired output, e.g. abc->cba
    if perm is not None:
        x = do("transpose", x, perm, like=backend)

    return x
def _do_contraction_via_bmm(
    a,
    b,
    eq_a,
    eq_b,
    new_shape_a,
    new_shape_b,
    new_shape_ab,
    perm_ab,
    pure_multiplication,
    backend,
):
    """Contract ``a`` and ``b`` following a plan from
    ``_parse_eq_to_batch_matmul``.

    Every array operation, including the single term preparation and the
    pure multiplication case, is dispatched with the same ``backend``.

    Parameters
    ----------
    a, b : array_like
        The operands.
    eq_a, eq_b : None, tuple[int] or str
        Preparation of each operand: nothing, a plain transposition, or a
        single term einsum (diagonals, sums and transpose).
    new_shape_a, new_shape_b : sequence[int] or None
        Shape to give each prepared operand, if any.
    new_shape_ab : sequence[int] or None
        Shape to give the matmul result, if any.
    perm_ab : tuple[int] or None
        Final permutation of the result, if any.
    pure_multiplication : bool
        If true, the prepared operands are multiplied elementwise and
        returned directly.
    backend : str or None
        Backend passed as ``like`` to each array operation.
    """
    # prepare left
    if eq_a is not None:
        if isinstance(eq_a, tuple):
            # only transpose
            a = do("transpose", a, eq_a, like=backend)
        else:
            # diagonals, sums, and transpose
            a = _einsum_single(eq_a, a, backend=backend)
    if new_shape_a is not None:
        a = do("reshape", a, new_shape_a, like=backend)

    # prepare right
    if eq_b is not None:
        if isinstance(eq_b, tuple):
            # only transpose
            b = do("transpose", b, eq_b, like=backend)
        else:
            # diagonals, sums, and transpose
            b = _einsum_single(eq_b, b, backend=backend)
    if new_shape_b is not None:
        b = do("reshape", b, new_shape_b, like=backend)

    if pure_multiplication:
        # no contracted indices
        return do("multiply", a, b, like=backend)

    # do the contraction!
    ab = do("matmul", a, b, like=backend)

    # prepare the output
    if new_shape_ab is not None:
        ab = do("reshape", ab, new_shape_ab, like=backend)
    if perm_ab is not None:
        ab = do("transpose", ab, perm_ab, like=backend)

    return ab
def _torch_workspace_pop(shape, dtype, device):
    """Take a pooled buffer matching ``(shape, dtype, device)``.

    Returns ``None`` if the module level pool does not exist yet or holds no
    buffer for that key.
    """
    try:
        pool = _TORCH_WORKSPACE_POOL
    except NameError:
        # the pool global has never been created
        return None

    bucket = pool.get((tuple(shape), dtype, device))
    return bucket.pop() if bucket else None
def _torch_workspace_push(x):
    """Offer ``x`` to the workspace pool for later reuse.

    Does nothing if the module level pool does not exist. Only contiguous
    tensors are stored, keyed by ``(shape, dtype, device)``.
    """
    try:
        pool = _TORCH_WORKSPACE_POOL
    except NameError:
        # the pool global has never been created
        return

    if x.is_contiguous():
        key = (tuple(x.shape), x.dtype, x.device)
        pool.setdefault(key, []).append(x)
def _torch_matmul_workspace(a, b):
    """Matrix multiply ``a`` and ``b`` into a pre-obtained output buffer.

    The output comes from the active arena if one is set, otherwise from the
    workspace pool, otherwise from a fresh ``torch.empty``. ``torch.mm`` is
    used for two 2D operands, ``torch.bmm`` for two 3D operands and
    ``torch.matmul`` for anything else.
    """
    import torch

    batch_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2])
    out_shape = batch_shape + (a.shape[-2], b.shape[-1])

    try:
        arena = _TORCH_ARENA
    except NameError:
        # the arena global has never been created
        arena = None

    if arena is None:
        out = _torch_workspace_pop(out_shape, a.dtype, a.device)
        if out is None:
            out = torch.empty(out_shape, dtype=a.dtype, device=a.device)
    else:
        out = arena.alloc(out_shape, a.dtype, a.device)

    ranks = (a.ndim, b.ndim)
    if ranks == (2, 2):
        torch.mm(a, b, out=out)
    elif ranks == (3, 3):
        torch.bmm(a, b, out=out)
    else:
        torch.matmul(a, b, out=out)
    return out
def _torch_multiply_workspace(a, b):
    """Elementwise (broadcast) multiply ``a`` and ``b``.

    With an active arena the result is written into an arena allocation via
    ``torch.mul(..., out=...)``; otherwise the generic ``multiply`` dispatch
    is used.
    """
    import torch

    out_shape = torch.broadcast_shapes(a.shape, b.shape)

    try:
        arena = _TORCH_ARENA
    except NameError:
        # the arena global has never been created
        arena = None

    if arena is None:
        return do("multiply", a, b)

    out = arena.alloc(out_shape, a.dtype, a.device)
    torch.mul(a, b, out=out)
    return out
def _torch_reshape_workspace(x, new_shape, backend):
    """Reshape ``x`` to ``new_shape``, copying into the arena if needed.

    Without an active arena, or for objects lacking ``view``, this is a plain
    ``reshape`` dispatch. Otherwise a zero-copy ``view`` is attempted first;
    if that raises ``RuntimeError`` the data is copied into a new arena
    allocation of the target shape. Any failure of that copy falls back to
    the plain ``reshape`` dispatch.
    """
    try:
        arena = _TORCH_ARENA
    except NameError:
        # the arena global has never been created
        arena = None

    if (arena is None) or not hasattr(x, "view"):
        return do("reshape", x, new_shape, like=backend)

    try:
        return x.view(new_shape)
    except RuntimeError:
        # layout is not viewable as the new shape, need a copy
        pass

    try:
        out = arena.alloc(tuple(new_shape), x.dtype, x.device)
        # fill through a view with the source's shape so strides are honored
        out.view(tuple(x.shape)).copy_(x)
    except Exception:
        return do("reshape", x, new_shape, like=backend)
    return out
def _do_contraction_via_bmm_torch_workspace(
    a,
    b,
    eq_a,
    eq_b,
    new_shape_a,
    new_shape_b,
    new_shape_ab,
    perm_ab,
    pure_multiplication,
    backend,
):
    """Torch variant of ``_do_contraction_via_bmm`` that routes reshapes,
    the multiplication and the matmul through the workspace helpers so their
    outputs can come from the arena or the workspace pool.

    The arguments have the same meaning as for ``_do_contraction_via_bmm``:
    ``eq_a`` / ``eq_b`` prepare each operand (``None``, a transposition tuple
    or a single term einsum string), the ``new_shape_*`` entries are optional
    reshapes, ``perm_ab`` an optional final permutation, and
    ``pure_multiplication`` selects elementwise multiplication instead of
    matmul. ``backend`` is passed to every dispatched operation.
    """
    # prepare left
    if eq_a is not None:
        if isinstance(eq_a, tuple):
            # only transpose
            a = do("transpose", a, eq_a, like=backend)
        else:
            # diagonals, sums, and transpose
            a = _einsum_single(eq_a, a, backend=backend)
    if new_shape_a is not None:
        a = _torch_reshape_workspace(a, new_shape_a, backend)

    # prepare right
    if eq_b is not None:
        if isinstance(eq_b, tuple):
            # only transpose
            b = do("transpose", b, eq_b, like=backend)
        else:
            # diagonals, sums, and transpose
            b = _einsum_single(eq_b, b, backend=backend)
    if new_shape_b is not None:
        b = _torch_reshape_workspace(b, new_shape_b, backend)

    if pure_multiplication:
        # no contracted indices
        return _torch_multiply_workspace(a, b)

    ab = _torch_matmul_workspace(a, b)

    # prepare the output
    if new_shape_ab is not None:
        ab = _torch_reshape_workspace(ab, new_shape_ab, backend)
    if perm_ab is not None:
        ab = do("transpose", ab, perm_ab, like=backend)

    return ab
def einsum(eq, a, b=None, *, backend=None):
    """Perform arbitrary single and pairwise einsums using only `matmul`,
    `transpose`, `reshape` and `sum`. The logic for each is cached based on
    the equation and array shape, and each step is only performed if necessary.

    When the torch workspace or arena is enabled through the environment and
    the operands are torch arrays, the pairwise contraction is routed through
    the workspace-aware implementation.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract.
    b : array_like, optional
        The second array to contract.
    backend : str, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    if b is None:
        return _einsum_single(eq, a, backend=backend)

    plan = _parse_eq_to_batch_matmul(eq, shape(a), shape(b))

    do_contraction = _do_contraction_via_bmm
    if _torch_workspace_enabled() or _torch_arena_enabled():
        # ``backend`` is usually left as None (the contractor never passes
        # it), so it has to be resolved from the operands here, otherwise the
        # torch workspace path could never be selected.
        resolved = infer_backend_multi(a, b) if backend is None else backend
        if resolved == "torch":
            do_contraction = _do_contraction_via_bmm_torch_workspace

    return do_contraction(a, b, *plan, backend)
def gen_nice_inds():
    """Generate index symbols: ``a-z``, then ``A-Z``, then every character
    from code point 192 upwards.
    """
    for first in (ord("a"), ord("A")):
        for code in range(first, first + 26):
            yield chr(code)
    yield from map(chr, itertools.count(192))
@functools.lru_cache(2**12)
def _parse_tensordot_axes_to_matmul(axes, shape_a, shape_b):
    """Parse a tensordot specification into the necessary sequence of arguments
    for contraction via matrix multiplication. This just converts ``axes``
    into an ``einsum`` eq string then calls ``_parse_eq_to_batch_matmul``.

    Parameters
    ----------
    axes : int or tuple[tuple[int], tuple[int]]
        Number of trailing axes of ``a`` / leading axes of ``b`` to contract,
        or the explicit axes for each operand. Negative axes count from the
        end.
    shape_a, shape_b : tuple[int]
        Shapes of the two operands.

    Raises
    ------
    ValueError
        If the two axes sequences differ in length, an axis is out of range,
        or two contracted dimensions differ in size.
    """
    ndim_a = len(shape_a)
    ndim_b = len(shape_b)

    if isinstance(axes, int):
        if not 0 <= axes <= min(ndim_a, ndim_b):
            raise ValueError(
                f"Cannot contract {axes} axes of shapes {shape_a} and "
                f"{shape_b}."
            )
        axes_a = tuple(range(ndim_a - axes, ndim_a))
        axes_b = tuple(range(axes))
    else:
        axes_a, axes_b = axes

    num_con = len(axes_a)
    if num_con != len(axes_b):
        raise ValueError(
            f"Axes should have the same length, got {axes_a} and {axes_b}."
        )

    def _normalize(axs, ndim, shp):
        # map negative axes to their positive position, rejecting anything
        # out of range rather than silently ignoring it
        normalized = []
        for ax in axs:
            if not -ndim <= ax < ndim:
                raise ValueError(f"Axis {ax} is out of range for shape {shp}.")
            normalized.append(ax % ndim)
        return tuple(normalized)

    axes_a = _normalize(axes_a, ndim_a, shape_a)
    axes_b = _normalize(axes_b, ndim_b, shape_b)

    possible_inds = gen_nice_inds()
    inds_a = [next(possible_inds) for _ in range(ndim_a)]
    inds_b = []
    inds_out = inds_a.copy()

    for axb in range(ndim_b):
        if axb not in axes_b:
            # right uncontracted index
            ind = next(possible_inds)
            inds_out.append(ind)
        else:
            # contracted index
            axa = axes_a[axes_b.index(axb)]
            # check that the shapes match
            if shape_a[axa] != shape_b[axb]:
                raise ValueError(
                    f"Dimension mismatch between axes {axa} of {shape_a} and "
                    f"{axb} of {shape_b}: {shape_a[axa]} != {shape_b[axb]}."
                )
            ind = inds_a[axa]
            inds_out.remove(ind)
        inds_b.append(ind)

    eq = f"{''.join(inds_a)},{''.join(inds_b)}->{''.join(inds_out)}"

    return _parse_eq_to_batch_matmul(eq, shape_a, shape_b)
def tensordot(a, b, axes=2, *, backend=None):
    """Perform a tensordot using only `matmul`, `transpose`, `reshape`. The
    logic for each is cached based on the equation and array shape, and each
    step is only performed if necessary.

    When the torch workspace or arena is enabled through the environment and
    the operands are torch arrays, the contraction is routed through the
    workspace-aware implementation.

    Parameters
    ----------
    a, b : array_like
        The arrays to contract.
    axes : int or tuple of (sequence[int], sequence[int])
        The number of axes to contract, or the axes to contract. If an int,
        the last ``axes`` axes of ``a`` and the first ``axes`` axes of ``b``
        are contracted. If a tuple, the axes to contract for ``a`` and ``b``
        respectively.
    backend : str or None, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    try:
        axes_a, axes_b = axes[0], axes[1]
    except (TypeError, IndexError):
        # a scalar number of axes: python ints are not subscriptable
        # (TypeError), other scalar types may raise IndexError instead
        axes = int(axes)
    else:
        # ensure hashable, for the cached parser
        axes = tuple(map(int, axes_a)), tuple(map(int, axes_b))

    plan = _parse_tensordot_axes_to_matmul(axes, shape(a), shape(b))

    do_contraction = _do_contraction_via_bmm
    if _torch_workspace_enabled() or _torch_arena_enabled():
        # ``backend`` is usually left as None (the contractor never passes
        # it), so it has to be resolved from the operands here, otherwise the
        # torch workspace path could never be selected.
        resolved = infer_backend_multi(a, b) if backend is None else backend
        if resolved == "torch":
            do_contraction = _do_contraction_via_bmm_torch_workspace

    return do_contraction(a, b, *plan, backend)
def extract_contractions(
    tree,
    order=None,
    prefer_einsum=False,
):
    """Extract just the information needed to perform the contraction.

    Parameters
    ----------
    tree : ContractionTree
        The tree to read the contractions from.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.

    Returns
    -------
    contractions : tuple
        A tuple of tuples, each containing the information needed to
        perform a pairwise contraction. Each tuple contains:

            - ``p``: the parent node,
            - ``l``: the left child node,
            - ``r``: the right child node,
            - ``tdot``: whether to use ``tensordot`` or ``einsum``,
            - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
                i.e. ``axes`` or ``eq``,
            - ``perm``: the permutation required after the contraction, if
                any (only applies to tensordot).

        If both ``l`` and ``r`` are ``None``, then the operation is a single
        term simplification performed with ``einsum``.
    """

    def _describe(p, l, r):
        # one pairwise contraction, as either an einsum or a tensordot
        if prefer_einsum or not tree.get_can_dot(p):
            return (p, l, r, False, tree.get_einsum_eq(p), None)
        return (
            p,
            l,
            r,
            True,
            tree.get_tensordot_axes(p),
            tree.get_tensordot_perm(p),
        )

    pairwise = [_describe(p, l, r) for p, l, r in tree.traverse(order=order)]

    # n.b. the preprocessing steps are populated lazily while the pairwise
    # information above is computed, so they must be read afterwards
    if not tree.preprocessing:
        return tuple(pairwise)

    # inplace single term simplifications, performed before anything else
    single_term = [
        (node_from_single(i), None, None, False, eq, None)
        for i, eq in tree.preprocessing.items()
    ]
    return (*single_term, *pairwise)
class Contractor:
    """Default cotengra network contractor.

    Parameters
    ----------
    contractions : tuple[tuple]
        The sequence of contractions to perform. Each contraction should be a
        tuple containing:

            - ``p``: the parent node,
            - ``l``: the left child node,
            - ``r``: the right child node,
            - ``tdot``: whether to use ``tensordot`` or ``einsum``,
            - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
                i.e. ``axes`` or ``eq``,
            - ``perm``: the permutation required after the contraction, if
                any (only applies to tensordot).

        e.g. built by calling ``extract_contractions(tree)``.
    strip_exponent : bool, optional
        If ``True``, eagerly strip the exponent (in log10) from
        intermediate tensors to control numerical problems from leaving the
        range of the datatype. This method then returns the scaled
        'mantissa' output array and the exponent separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, -inf)``.
    implementation : str or tuple[callable, callable], optional
        ``"auto"``, ``"cotengra"``, ``"autoray"`` or an explicit
        ``(einsum, tensordot)`` pair of callables.
    backend : str, optional
        What library to use for ``tensordot``, ``einsum`` and
        ``transpose``, it will be automatically inferred from the input
        arrays if not given.
    progbar : bool, optional
        Whether to show a progress bar.
    """

    __slots__ = (
        "contractions",
        "strip_exponent",
        "check_zero",
        "implementation",
        "backend",
        "progbar",
        "__weakref__",
    )

    def __init__(
        self,
        contractions,
        strip_exponent=False,
        check_zero=False,
        implementation="auto",
        backend=None,
        progbar=False,
    ):
        self.contractions = contractions
        self.strip_exponent = strip_exponent
        self.check_zero = check_zero
        self.implementation = implementation
        self.backend = backend
        self.progbar = progbar

    def __call__(self, *arrays, **kwargs):
        """Contract ``arrays`` using operations listed in ``contractions``.

        Parameters
        ----------
        arrays : sequence of array-like
            The arrays to contract.
        kwargs : dict
            Override the default settings for this contraction only.

        Returns
        -------
        output : array
            The contracted output, it will be scaled if
            ``strip_exponent==True``.
        exponent : float
            The exponent of the output in base 10, returned only if
            ``strip_exponent==True``.

        Raises
        ------
        TypeError
            If an unknown keyword argument is supplied.
        """
        global _TORCH_ARENA, _TORCH_WORKSPACE_POOL

        backend = kwargs.pop("backend", self.backend)
        progbar = kwargs.pop("progbar", self.progbar)
        check_zero = kwargs.pop("check_zero", self.check_zero)
        strip_exponent = kwargs.pop("strip_exponent", self.strip_exponent)
        implementation = kwargs.pop("implementation", self.implementation)
        if kwargs:
            raise TypeError(f"Unknown keyword arguments: {kwargs}.")

        if backend is None:
            backend = infer_backend_multi(*arrays)

        if implementation == "auto":
            if (backend == "numpy") or (
                backend == "torch"
                and all(
                    getattr(getattr(x, "device", None), "type", "cpu") == "cpu"
                    for x in arrays
                    if hasattr(x, "device")
                )
            ):
                # by default replace numpy's einsum/tensordot, and do the
                # same for torch CPU to control bmm outputs and workspace reuse
                implementation = "cotengra"
            else:
                implementation = "autoray"

        if implementation == "cotengra":
            _einsum, _tensordot = einsum, tensordot
        elif implementation == "autoray":
            try:
                _einsum = get_lib_fn(backend, "einsum")
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _einsum = einsum
            try:
                _tensordot = get_lib_fn(backend, "tensordot")
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _tensordot = tensordot
        else:
            # manually supplied
            _einsum, _tensordot = implementation

        using_torch_arena = (backend == "torch") and _torch_arena_enabled()
        if using_torch_arena:
            _TORCH_ARENA = _TorchArena()

        using_torch_workspace = (
            (backend == "torch")
            and (_einsum is einsum)
            and _torch_workspace_enabled()
            and not using_torch_arena
        )
        if using_torch_workspace:
            _TORCH_WORKSPACE_POOL = {}

        try:
            # temporary storage for intermediates
            N = len(arrays)
            temps = {
                leaf: array
                for leaf, array in zip(map(node_from_single, range(N)), arrays)
            }

            exponent = 0.0 if (strip_exponent is not False) else None

            if progbar:
                import tqdm

                contractions = tqdm.tqdm(self.contractions, total=N - 1)
            else:
                contractions = self.contractions

            p_array = next(iter(temps.values()))
            for p, l, r, tdot, arg, perm in contractions:
                if (l is None) and (r is None):
                    # single term simplification, perform inplace with einsum
                    temps[p] = _einsum(arg, temps[p])
                    p_array = temps[p]
                    continue

                # get input arrays for this contraction
                l_array = temps.pop(l)
                r_array = temps.pop(r)

                if tdot:
                    p_array = _tensordot(l_array, r_array, arg)
                    if perm:
                        p_array = do("transpose", p_array, perm, like=backend)
                else:
                    p_array = _einsum(arg, l_array, r_array)

                if exponent is not None:
                    factor = do(
                        "max", do("abs", p_array, like=backend), like=backend
                    )
                    if check_zero and float(factor) == 0.0:
                        # the ``finally`` below drops the arena / pool
                        return 0.0, float("-inf")
                    exponent = exponent + do("log10", factor, like=backend)
                    scaled = p_array / factor
                    if using_torch_arena:
                        # the division produced a new array, so the unscaled
                        # block is dead and can go back to the arena now
                        _TORCH_ARENA.release(p_array)
                    p_array = scaled

                # insert the new intermediate array
                temps[p] = p_array

                if using_torch_workspace:
                    # only recycle intermediates, never the caller's inputs
                    if (len(l) != 1) and hasattr(l_array, "device"):
                        _torch_workspace_push(l_array)
                    if (len(r) != 1) and hasattr(r_array, "device"):
                        _torch_workspace_push(r_array)
                if using_torch_arena:
                    _TORCH_ARENA.release(l_array)
                    _TORCH_ARENA.release(r_array)

            if using_torch_arena and hasattr(p_array, "clone"):
                # The final output may be a view into the arena. Clone it
                # before dropping the arena so a scalar result doesn't keep
                # the whole workspace storage alive.
                p_array = p_array.clone()
        finally:
            # always drop the shared state, also on errors and early returns,
            # so buffers are not kept alive or reused by unrelated calls
            if using_torch_arena:
                _TORCH_ARENA = None
            if using_torch_workspace:
                _TORCH_WORKSPACE_POOL = {}

        if exponent is not None:
            return p_array, exponent

        return p_array
class CuQuantumContractor:
    """Contractor that hands the whole contraction to cuQuantum's
    ``Network``.

    Parameters
    ----------
    tree : ContractionTree
        The tree describing the contraction. Trees with preprocessing steps
        are rejected.
    handle_slicing : bool, optional
        If ``True``, use the full (unsliced) equation and shapes of the tree
        and forward its sliced indices as a ``slicing`` optimize option.
        Otherwise use the sliced equation and shapes.
    autotune : bool or int, optional
        Number of autotune iterations, ``True`` meaning 3. Falsy disables
        autotuning.
    kwargs
        ``strip_exponent`` (must be falsy) and ``progbar`` (ignored with a
        warning) are consumed here, everything else is forwarded to
        ``Network.contract_path``.
    """

    def __init__(
        self,
        tree,
        handle_slicing=False,
        autotune=False,
        **kwargs,
    ):
        # set first, so ``__del__`` is safe even if validation below raises
        self.handle = None
        self.network = None

        if kwargs.pop("strip_exponent", None):
            raise ValueError(
                "strip_exponent=True not supported with cuQuantum"
            )
        if tree.has_preprocessing():
            raise ValueError("Preprocessing not supported with cuQuantum yet.")
        if kwargs.pop("progbar", None):
            import warnings

            warnings.warn("Progress bar not supported with cuQuantum yet.")

        if handle_slicing:
            self.eq = tree.get_eq()
            self.shapes = tree.get_shapes()
        else:
            self.eq = tree.get_eq_sliced()
            self.shapes = tree.get_shapes_sliced()

        if tree.is_complete():
            # work on a copy so a caller supplied ``optimize`` dict is not
            # modified behind their back
            optimize = dict(kwargs.get("optimize") or {})
            optimize.setdefault("path", tree.get_path())
            if handle_slicing and tree.sliced_inds:
                optimize.setdefault(
                    "slicing",
                    [(ix, tree.size_dict[ix] - 1) for ix in tree.sliced_inds],
                )
            kwargs["optimize"] = optimize

        # NOTE(review): any remaining keyword (e.g. ``check_zero``) is passed
        # on to ``Network.contract_path`` unchanged - confirm it accepts them.
        self.kwargs = kwargs
        self.autotune = 3 if autotune is True else autotune

    def setup(self, *arrays):
        """Create the cuQuantum ``Network`` for ``arrays``, find its
        contraction path and optionally autotune it.
        """
        import cuquantum

        if hasattr(cuquantum, "bindings"):
            # cuquantum-python >= 25.03
            from cuquantum.tensornet import Network
        else:
            # for cuquantum < 25.03
            from cuquantum import Network

        self.network = Network(
            self.eq,
            *arrays,
        )
        self.network.contract_path(**self.kwargs)
        if self.autotune:
            self.network.autotune(iterations=self.autotune)

    def __call__(
        self,
        *arrays,
        check_zero=False,
        backend=None,
        progbar=False,
    ):
        """Contract ``arrays``, building the network on the first call and
        only swapping the operands on later calls.
        """
        # can't handle these yet
        assert not check_zero
        assert not progbar
        assert backend is None

        if self.network is None:
            self.setup(*arrays)
        else:
            self.network.reset_operands(*arrays)

        return self.network.contract()

    def __del__(self):
        # ``network`` may be missing if construction never got that far
        network = getattr(self, "network", None)
        if network is not None:
            network.free()
def make_contractor(
    tree,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    implementation=None,
    autojit=False,
    progbar=False,
):
    """Get a reusable function which performs the contraction corresponding
    to ``tree``. The various options provide defaults that can also be
    overridden when calling the standard contractor.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`, the order in which
        to perform the pairwise contractions given by the tree.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        If ``True``, the function will strip the exponent from the output
        array and return it separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, -inf)``.
    implementation : str or tuple[callable, callable], optional
        What library to use to actually perform the contractions. If not
        given, the module default is used. Options are

        - "auto": let cotengra choose
        - "autoray": dispatch with autoray, using the ``tensordot`` and
          ``einsum`` implementation of the backend
        - "cotengra": use the ``tensordot`` and ``einsum`` implementation of
          cotengra, which is based on batch matrix multiplication. This is
          faster for some backends like numpy, and also enables libraries
          which don't yet provide ``tensordot`` and ``einsum`` to be used.
        - "cuquantum": use the cuquantum library to perform the whole
          contraction (not just individual contractions).
        - tuple[callable, callable]: manually supply the ``einsum`` and
          ``tensordot`` implementations to use, in that order.

    autojit : bool, optional
        If ``True``, use :func:`autoray.autojit` to compile the contraction
        function.
    progbar : bool, optional
        Whether to show progress through the contraction by default.

    Returns
    -------
    fn : callable
        The contraction function, with signature ``fn(*arrays)``.
    """
    if implementation is None:
        implementation = get_default_implementation()

    # options understood by both kinds of contractor
    shared = dict(
        strip_exponent=strip_exponent,
        check_zero=check_zero,
        progbar=progbar,
    )

    if implementation == "cuquantum":
        fn = CuQuantumContractor(tree, **shared)
    else:
        fn = Contractor(
            contractions=extract_contractions(tree, order, prefer_einsum),
            implementation=implementation,
            **shared,
        )

    if not autojit:
        return fn

    from autoray import autojit as _autojit

    return _autojit(fn)