"""Functionality relating to actually contracting."""

import functools
import itertools
import operator
import contextlib
import os

from autoray import do, shape, infer_backend_multi, get_lib_fn

from .utils import node_from_single

DEFAULT_IMPLEMENTATION = "auto"

# Scratch state for the optional torch workspace / arena code paths. Both are
# only populated for the duration of a ``Contractor.__call__`` and are reset
# to ``None`` afterwards, so that no large buffers outlive a contraction.
_TORCH_ARENA = None
_TORCH_WORKSPACE_POOL = None

_TRUTHY_ENV_VALUES = frozenset(
    ("1", "true", "yes", "on", "enable", "enabled")
)

# n.b. order matters: longer suffixes must be tested before shorter ones
_SIZE_SUFFIXES = (
    ("tib", 1024**4),
    ("tb", 1000**4),
    ("ti", 1024**4),
    ("t", 1024**4),
    ("gib", 1024**3),
    ("gb", 1000**3),
    ("gi", 1024**3),
    ("g", 1024**3),
    ("mib", 1024**2),
    ("mb", 1000**2),
    ("mi", 1024**2),
    ("m", 1024**2),
    ("kib", 1024),
    ("kb", 1000),
    ("ki", 1024),
    ("k", 1024),
    ("b", 1),
)


def _env_flag(name):
    """Check whether the environment variable ``name`` is set to a 'truthy'
    value (case insensitive ``1, true, yes, on, enable, enabled``).
    """
    value = os.environ.get(name, "0").strip().lower()
    return value in _TRUTHY_ENV_VALUES


def _torch_workspace_enabled():
    """Whether ``QIBOTN_TORCH_WORKSPACE`` requests reuse of torch buffers."""
    return _env_flag("QIBOTN_TORCH_WORKSPACE")


def _torch_arena_enabled():
    """Whether ``QIBOTN_TORCH_ARENA`` requests the preallocated arena."""
    return _env_flag("QIBOTN_TORCH_ARENA")


def _parse_size_bytes(value, default):
    """Parse a human readable size such as ``"70GiB"``, ``"512m"`` or
    ``1e9`` into an integer number of bytes.

    Parameters
    ----------
    value : None, int, float or str
        The size to parse. Suffixes are case insensitive, ``kb, mb, gb, tb``
        are decimal multiples, all other suffixes (``k, ki, kib, ...``) are
        binary multiples, and a bare ``b`` means bytes.
    default : int
        Returned if ``value`` is ``None`` or blank.

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If the numeric part cannot be parsed.
    """
    if value is None:
        return default

    text = str(value).strip().lower()
    if not text:
        # e.g. environment variable set but empty
        return default

    scale = 1
    for suffix, multiplier in _SIZE_SUFFIXES:
        if text.endswith(suffix):
            text = text[: -len(suffix)]
            scale = multiplier
            break

    return int(float(text) * scale)


class _TorchArena:
    """Simple first-fit arena for torch contraction intermediates.

    A single flat buffer is lazily allocated on the first request, with the
    dtype and device of that request. Requests with any other dtype or device,
    requests that do not fit, and empty requests are served by a regular
    ``torch.empty`` instead.

    Parameters
    ----------
    size_bytes : None, int or str, optional
        Size of the arena, parsed by ``_parse_size_bytes``. If not given the
        environment variable ``QIBOTN_TORCH_ARENA_BYTES`` is used, falling
        back to 70 GiB.
    """

    def __init__(self, size_bytes=None):
        self.size_bytes = _parse_size_bytes(
            size_bytes or os.environ.get("QIBOTN_TORCH_ARENA_BYTES"),
            70 * 1024**3,
        )
        self.buffer = None
        self.dtype = None
        self.device = None
        # sorted list of (offset, length) free blocks, in units of elements
        self.free = []
        # map of allocated block offset -> length, in units of elements
        self.allocated = {}

    def _setup(self, dtype, device):
        """Allocate the backing buffer for the given dtype and device."""
        import torch

        element_size = torch.empty(
            (), dtype=dtype, device=device
        ).element_size()
        nelements = self.size_bytes // element_size
        if nelements <= 0:
            raise MemoryError("QIBOTN_TORCH_ARENA_BYTES is too small.")

        self.buffer = torch.empty(nelements, dtype=dtype, device=device)
        self.dtype = dtype
        self.device = device
        self.free = [(0, nelements)]
        self.allocated = {}

    def alloc(self, shape, dtype, device):
        """Get an uninitialized tensor of the given shape, carved out of the
        arena if possible, otherwise freshly allocated.
        """
        import torch

        numel = functools.reduce(operator.mul, shape, 1)
        if numel == 0:
            # an empty block would share its offset with the next allocation,
            # so that releasing one would free the other - bypass the arena
            return torch.empty(shape, dtype=dtype, device=device)

        if self.buffer is None:
            self._setup(dtype, device)
        elif (dtype != self.dtype) or (device != self.device):
            return torch.empty(shape, dtype=dtype, device=device)

        # align every block to 64 bytes
        element_size = self.buffer.element_size()
        align = max(1, 64 // element_size)

        for i, (offset, length) in enumerate(self.free):
            aligned = ((offset + align - 1) // align) * align
            padding = aligned - offset
            if length - padding < numel:
                continue

            # split the free block into (padding, allocation, tail)
            new_blocks = []
            if padding:
                new_blocks.append((offset, padding))
            tail_offset = aligned + numel
            tail_length = (offset + length) - tail_offset
            if tail_length:
                new_blocks.append((tail_offset, tail_length))

            self.free[i : i + 1] = new_blocks
            self.allocated[aligned] = numel
            return self.buffer.narrow(0, aligned, numel).view(shape)

        # no free block is large enough
        return torch.empty(shape, dtype=dtype, device=device)

    def release(self, x):
        """Return the block backing ``x`` to the arena. Anything that is not
        a tensor starting at the offset of a live arena block is ignored.
        """
        if self.buffer is None:
            return
        if not hasattr(x, "untyped_storage"):
            return

        try:
            x_ptr = x.untyped_storage().data_ptr()
            if x_ptr != self.buffer.untyped_storage().data_ptr():
                return
            offset = int(x.storage_offset())
        except Exception:
            # deliberately best effort: not releasing only wastes arena space
            return

        length = self.allocated.pop(offset, None)
        if length is None:
            return

        # reinsert and coalesce adjacent free blocks
        self.free.append((offset, length))
        self.free.sort()
        merged = []
        for block_offset, block_length in self.free:
            if merged and merged[-1][0] + merged[-1][1] == block_offset:
                prev_offset, prev_length = merged[-1]
                merged[-1] = (prev_offset, prev_length + block_length)
            else:
                merged.append((block_offset, block_length))
        self.free = merged


def set_default_implementation(impl):
    """Set the default implementation used by ``make_contractor``."""
    global DEFAULT_IMPLEMENTATION
    DEFAULT_IMPLEMENTATION = impl


def get_default_implementation():
    """Get the default implementation used by ``make_contractor``."""
    return DEFAULT_IMPLEMENTATION


@contextlib.contextmanager
def default_implementation(impl):
    """Context manager for temporarily setting the default implementation."""
    global DEFAULT_IMPLEMENTATION
    old_impl = DEFAULT_IMPLEMENTATION
    DEFAULT_IMPLEMENTATION = impl
    try:
        yield
    finally:
        DEFAULT_IMPLEMENTATION = old_impl


@functools.lru_cache(2**12)
def _sanitize_equation(eq):
    """Get the input and output indices of an equation, computing the output
    implicitly as the sorted sequence of every index that appears exactly once
    if it is not provided.
    """
    # remove spaces
    eq = eq.replace(" ", "")

    if "..." in eq:
        raise NotImplementedError("Ellipsis not supported.")

    if "->" not in eq:
        lhs = eq
        tmp_subscripts = lhs.replace(",", "")
        out = "".join(
            # sorted sequence of indices
            s
            for s in sorted(set(tmp_subscripts))
            # that appear exactly once
            if tmp_subscripts.count(s) == 1
        )
    else:
        lhs, out = eq.split("->")
    return lhs, out


@functools.lru_cache(2**12)
def _parse_einsum_single(eq, shape):
    """Cached parsing of a single term einsum equation into the necessary
    sequence of arguments for axes diagonals, sums, and transposes.

    Returns
    -------
    diag_sels : list[tuple] or None
        Advanced indexing selectors to apply in sequence.
    sum_axes : tuple[int] or None
        Axes to sum over.
    perm : tuple[int] or None
        Final transposition.
    """
    lhs, out = _sanitize_equation(eq)

    # parse each index
    need_to_diag = []
    need_to_sum = []
    seen = set()
    for ix in lhs:
        if ix in need_to_diag:
            continue
        if ix in seen:
            need_to_diag.append(ix)
            continue
        seen.add(ix)
        if ix not in out:
            need_to_sum.append(ix)

    # first handle diagonal reductions
    if need_to_diag:
        diag_sels = []
        sizes = dict(zip(lhs, shape))
        while need_to_diag:
            ixd = need_to_diag.pop()
            dinds = tuple(range(sizes[ixd]))

            # construct advanced indexing object
            selector = tuple(
                dinds if ix == ixd else slice(None) for ix in lhs
            )
            diag_sels.append(selector)

            # after taking the diagonal what are new indices?
            ixd_contig = ixd * lhs.count(ixd)
            if ixd_contig in lhs:
                # contig axes, new axis is at same position
                lhs = lhs.replace(ixd_contig, ixd)
            else:
                # non-contig, new axis is at beginning
                lhs = ixd + lhs.replace(ixd, "")
    else:
        diag_sels = None

    # then sum reductions
    if need_to_sum:
        sum_axes = tuple(map(lhs.index, need_to_sum))
        for ix in need_to_sum:
            lhs = lhs.replace(ix, "")
    else:
        sum_axes = None

    # then transposition
    if lhs == out:
        perm = None
    else:
        perm = tuple(lhs.index(ix) for ix in out)

    return diag_sels, sum_axes, perm


def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
    """If there are no contracted indices, then we can directly transpose and
    insert singleton dimensions into ``a`` and ``b`` such that (broadcast)
    elementwise multiplication performs the einsum.

    No need to cache this as it is within the cached
    ``_parse_eq_to_batch_matmul``.
    """
    desired_a = ""
    desired_b = ""
    new_shape_a = []
    new_shape_b = []
    for ix in out:
        if ix in a_term:
            desired_a += ix
            new_shape_a.append(shape_a[a_term.index(ix)])
        else:
            new_shape_a.append(1)
        if ix in b_term:
            desired_b += ix
            new_shape_b.append(shape_b[b_term.index(ix)])
        else:
            new_shape_b.append(1)

    if desired_a != a_term:
        eq_a = f"{a_term}->{desired_a}"
    else:
        eq_a = None
    if desired_b != b_term:
        eq_b = f"{b_term}->{desired_b}"
    else:
        eq_b = None

    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        None,  # new_shape_ab, not needed since not fusing
        None,  # perm_ab, not needed as we transpose a and b first
        True,  # pure_multiplication=True
    )


@functools.lru_cache(2**12)
def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
    """Cached parsing of a two term einsum equation into the necessary
    sequence of arguments for contraction via batched matrix multiplication.
    The steps we need to specify are:

        1. Remove repeated and trivial indices from the left and right terms,
           and transpose them, done as a single einsum.
        2. Fuse the remaining indices so we have two 3D tensors.
        3. Perform the batched matrix multiplication.
        4. Unfuse the output to get the desired final index order.

    Raises
    ------
    ValueError
        If a term does not match the number of dimensions of its array, or if
        an index appears with two different (non-unit) sizes.
    """
    lhs, out = eq.split("->")
    a_term, b_term = lhs.split(",")

    if len(a_term) != len(shape_a):
        raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
    if len(b_term) != len(shape_b):
        raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")

    bat_inds = []  # appears on A, B, O
    con_inds = []  # appears on A, B, .
    a_keep = []  # appears on A, ., O
    b_keep = []  # appears on ., B, O
    sizes = {}
    singletons = set()

    # parse left term
    seen = set()
    for ix, d in zip(a_term, shape_a):
        if d == 1:
            # everything (including broadcasting) works nicely if simply ignore
            # such dimensions, but we do need to track if they appear in output
            # and thus should be reintroduced later
            singletons.add(ix)
            continue

        # set or check size
        if sizes.setdefault(ix, d) != d:
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )

        if ix in seen:
            continue
        seen.add(ix)

        if ix in b_term:
            if ix in out:
                bat_inds.append(ix)
            else:
                con_inds.append(ix)
        elif ix in out:
            a_keep.append(ix)

    # parse right term
    seen.clear()
    for ix, d in zip(b_term, shape_b):
        if d == 1:
            singletons.add(ix)
            continue

        # broadcast indices don't appear as singletons in output
        singletons.discard(ix)

        # set or check size
        if sizes.setdefault(ix, d) != d:
            raise ValueError(
                f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
            )

        if ix in seen:
            continue
        seen.add(ix)

        if ix not in a_term:
            if ix in out:
                b_keep.append(ix)

    if not con_inds:
        # contraction is pure multiplication, prepare inputs differently
        return _parse_eq_to_pure_multiplication(
            a_term, shape_a, b_term, shape_b, out
        )

    # only need the size one indices that appear in the output
    singletons = [ix for ix in out if ix in singletons]

    # take diagonal, remove any trivial axes and transpose left
    desired_a = "".join((*bat_inds, *a_keep, *con_inds))
    if a_term != desired_a:
        if set(a_term) == set(desired_a):
            # only need to transpose, don't invoke einsum
            eq_a = tuple(a_term.index(ix) for ix in desired_a)
        else:
            eq_a = f"{a_term}->{desired_a}"
    else:
        eq_a = None

    # take diagonal, remove any trivial axes and transpose right
    desired_b = "".join((*bat_inds, *con_inds, *b_keep))
    if b_term != desired_b:
        if set(b_term) == set(desired_b):
            # only need to transpose, don't invoke einsum
            eq_b = tuple(b_term.index(ix) for ix in desired_b)
        else:
            eq_b = f"{b_term}->{desired_b}"
    else:
        eq_b = None

    # then we want to reshape
    if bat_inds:
        lgroups = (bat_inds, a_keep, con_inds)
        rgroups = (bat_inds, con_inds, b_keep)
        ogroups = (bat_inds, a_keep, b_keep)
    else:
        # avoid size 1 batch dimension if no batch indices
        lgroups = (a_keep, con_inds)
        rgroups = (con_inds, b_keep)
        ogroups = (a_keep, b_keep)

    if any(len(group) != 1 for group in lgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_a = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in lgroups
        )
    else:
        new_shape_a = None

    if any(len(group) != 1 for group in rgroups):
        # need to fuse 'kept' and contracted indices
        # (though could allow batch indices to be broadcast)
        new_shape_b = tuple(
            functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
            for ix_group in rgroups
        )
    else:
        new_shape_b = None

    if any(len(group) != 1 for group in ogroups) or singletons:
        new_shape_ab = (1,) * len(singletons) + tuple(
            sizes[ix] for ix_group in ogroups for ix in ix_group
        )
    else:
        new_shape_ab = None

    # then we want to permute the matmul produced output:
    out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
    perm_ab = tuple(out_produced.index(ix) for ix in out)
    if perm_ab == tuple(range(len(perm_ab))):
        perm_ab = None

    return (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        False,  # pure_multiplication=False
    )


def _einsum_single(eq, x, backend=None):
    """Einsum on a single tensor, via three steps: diagonal selection
    (via advanced indexing), axes summations, transposition. The logic for each
    is cached based on the equation and array shape, and each step is only
    performed if necessary.

    The backend's own ``einsum`` is used directly if it can be imported.
    """
    try:
        return do("einsum", eq, x, like=backend)
    except ImportError:
        pass

    diag_sels, sum_axes, perm = _parse_einsum_single(eq, shape(x))

    if diag_sels is not None:
        # diagonal reduction via advanced indexing
        # e.g ababbac->abc
        for selector in diag_sels:
            x = x[selector]

    if sum_axes is not None:
        # trivial removal of axes via summation
        # e.g. abc->c
        x = do("sum", x, sum_axes, like=backend)

    if perm is not None:
        # transpose to desired output
        # e.g. abc->cba
        x = do("transpose", x, perm, like=backend)

    return x


def _do_contraction_via_bmm(
    a,
    b,
    eq_a,
    eq_b,
    new_shape_a,
    new_shape_b,
    new_shape_ab,
    perm_ab,
    pure_multiplication,
    backend,
):
    """Contract ``a`` and ``b`` following the recipe produced by
    ``_parse_eq_to_batch_matmul``, using only backend array operations.
    """
    # prepare left
    if eq_a is not None:
        if isinstance(eq_a, tuple):
            # only transpose
            a = do("transpose", a, eq_a, like=backend)
        else:
            # diagonals, sums, and transpose
            a = _einsum_single(eq_a, a, backend=backend)
    if new_shape_a is not None:
        a = do("reshape", a, new_shape_a, like=backend)

    # prepare right
    if eq_b is not None:
        if isinstance(eq_b, tuple):
            # only transpose
            b = do("transpose", b, eq_b, like=backend)
        else:
            # diagonals, sums, and transpose
            b = _einsum_single(eq_b, b, backend=backend)
    if new_shape_b is not None:
        b = do("reshape", b, new_shape_b, like=backend)

    if pure_multiplication:
        # no contracted indices
        return do("multiply", a, b, like=backend)

    # do the contraction!
    ab = do("matmul", a, b, like=backend)

    # prepare the output
    if new_shape_ab is not None:
        ab = do("reshape", ab, new_shape_ab, like=backend)
    if perm_ab is not None:
        ab = do("transpose", ab, perm_ab, like=backend)

    return ab


def _torch_workspace_pop(shape, dtype, device):
    """Take a pooled buffer matching ``(shape, dtype, device)``, returning
    ``None`` if there is no active pool or no matching buffer.
    """
    pool = _TORCH_WORKSPACE_POOL
    if pool is None:
        return None
    try:
        return pool[(tuple(shape), dtype, device)].pop()
    except (KeyError, IndexError):
        return None


def _torch_workspace_push(x):
    """Offer the (no longer needed) tensor ``x`` to the active pool. Only
    contiguous tensors are kept, since they are later used as ``out=``.
    """
    pool = _TORCH_WORKSPACE_POOL
    if pool is None:
        return
    if not x.is_contiguous():
        return
    pool.setdefault((tuple(x.shape), x.dtype, x.device), []).append(x)


def _torch_matmul_workspace(a, b):
    """Matrix multiply ``a`` and ``b`` into a buffer taken from the arena if
    active, else from the workspace pool, else freshly allocated.
    """
    import torch

    out_shape = torch.broadcast_shapes(a.shape[:-2], b.shape[:-2]) + (
        a.shape[-2],
        b.shape[-1],
    )

    arena = _TORCH_ARENA
    if arena is not None:
        out = arena.alloc(out_shape, a.dtype, a.device)
    else:
        out = _torch_workspace_pop(out_shape, a.dtype, a.device)
        if out is None:
            out = torch.empty(out_shape, dtype=a.dtype, device=a.device)

    if a.ndim == 2 and b.ndim == 2:
        torch.mm(a, b, out=out)
    elif a.ndim == 3 and b.ndim == 3:
        torch.bmm(a, b, out=out)
    else:
        torch.matmul(a, b, out=out)
    return out


def _torch_multiply_workspace(a, b):
    """Elementwise multiply ``a`` and ``b``, writing into an arena buffer if
    the arena is active.
    """
    import torch

    arena = _TORCH_ARENA
    if arena is None:
        return do("multiply", a, b)

    out_shape = torch.broadcast_shapes(a.shape, b.shape)
    out = arena.alloc(out_shape, a.dtype, a.device)
    torch.mul(a, b, out=out)
    return out


def _torch_reshape_workspace(x, new_shape, backend):
    """Reshape ``x``, as a view where possible. If a copy is required and the
    arena is active then the copy is placed in the arena.
    """
    arena = _TORCH_ARENA
    if (arena is None) or not hasattr(x, "view"):
        return do("reshape", x, new_shape, like=backend)

    try:
        return x.view(new_shape)
    except RuntimeError:
        # not viewable, need to copy
        pass

    try:
        out = arena.alloc(tuple(new_shape), x.dtype, x.device)
        out.view(tuple(x.shape)).copy_(x)
        return out
    except Exception:
        # best effort only, fall back to a regular reshape
        return do("reshape", x, new_shape, like=backend)


def _do_contraction_via_bmm_torch_workspace(
    a,
    b,
    eq_a,
    eq_b,
    new_shape_a,
    new_shape_b,
    new_shape_ab,
    perm_ab,
    pure_multiplication,
    backend,
):
    """Same as ``_do_contraction_via_bmm`` but with reshapes, multiplications
    and matrix multiplications writing into reusable torch buffers.
    """
    # prepare left
    if eq_a is not None:
        if isinstance(eq_a, tuple):
            a = do("transpose", a, eq_a, like=backend)
        else:
            a = _einsum_single(eq_a, a, backend=backend)
    if new_shape_a is not None:
        a = _torch_reshape_workspace(a, new_shape_a, backend)

    # prepare right
    if eq_b is not None:
        if isinstance(eq_b, tuple):
            b = do("transpose", b, eq_b, like=backend)
        else:
            b = _einsum_single(eq_b, b, backend=backend)
    if new_shape_b is not None:
        b = _torch_reshape_workspace(b, new_shape_b, backend)

    if pure_multiplication:
        return _torch_multiply_workspace(a, b)

    ab = _torch_matmul_workspace(a, b)

    # prepare the output
    if new_shape_ab is not None:
        ab = _torch_reshape_workspace(ab, new_shape_ab, backend)
    if perm_ab is not None:
        ab = do("transpose", ab, perm_ab, like=backend)

    return ab


def _select_contraction_fn(a, b, backend):
    """Choose between the plain and the torch workspace implementation.

    The workspace implementation is only used if one of the
    ``QIBOTN_TORCH_WORKSPACE`` / ``QIBOTN_TORCH_ARENA`` flags is set and the
    backend is torch. If ``backend`` is not given it is inferred from the
    arrays, since ``Contractor`` calls ``einsum`` and ``tensordot`` without an
    explicit backend.
    """
    if _torch_workspace_enabled() or _torch_arena_enabled():
        if backend is None:
            backend = infer_backend_multi(a, b)
        if backend == "torch":
            return _do_contraction_via_bmm_torch_workspace
    return _do_contraction_via_bmm


def einsum(eq, a, b=None, *, backend=None):
    """Perform arbitrary single and pairwise einsums using only `matmul`,
    `transpose`, `reshape` and `sum`.  The logic for each is cached based on
    the equation and array shape, and each step is only performed if necessary.

    Parameters
    ----------
    eq : str
        The einsum equation.
    a : array_like
        The first array to contract.
    b : array_like, optional
        The second array to contract.
    backend : str, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    if b is None:
        return _einsum_single(eq, a, backend=backend)

    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_eq_to_batch_matmul(eq, shape(a), shape(b))

    do_contraction = _select_contraction_fn(a, b, backend)

    return do_contraction(
        a,
        b,
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
        backend,
    )


def gen_nice_inds():
    """Generate the indices from [a-z, A-Z, reasonable unicode...]."""
    for i in range(26):
        yield chr(ord("a") + i)
    for i in range(26):
        yield chr(ord("A") + i)
    for i in itertools.count(192):
        yield chr(i)


@functools.lru_cache(2**12)
def _parse_tensordot_axes_to_matmul(axes, shape_a, shape_b):
    """Parse a tensordot specification into the necessary sequence of arguments
    for contraction via matrix multiplication. This just converts ``axes``
    into an ``einsum`` eq string then calls ``_parse_eq_to_batch_matmul``.

    Raises
    ------
    ValueError
        If the two sets of axes have different lengths, or if the dimensions
        of a contracted pair of axes do not match.
    """
    ndim_a = len(shape_a)
    ndim_b = len(shape_b)

    if isinstance(axes, int):
        axes_a = tuple(range(ndim_a - axes, ndim_a))
        axes_b = tuple(range(axes))
    else:
        axes_a, axes_b = axes

    num_con = len(axes_a)
    if num_con != len(axes_b):
        raise ValueError(
            f"Axes should have the same length, got {axes_a} and {axes_b}."
        )

    possible_inds = gen_nice_inds()
    inds_a = [next(possible_inds) for _ in range(ndim_a)]
    inds_b = []
    inds_out = inds_a.copy()

    for axb in range(ndim_b):
        if axb not in axes_b:
            # right uncontracted index
            ind = next(possible_inds)
            inds_out.append(ind)
        else:
            # contracted index
            axa = axes_a[axes_b.index(axb)]
            # check that the shapes match
            if shape_a[axa] != shape_b[axb]:
                raise ValueError(
                    f"Dimension mismatch between axes {axa} of {shape_a} and "
                    f"{axb} of {shape_b}: {shape_a[axa]} != {shape_b[axb]}."
                )
            ind = inds_a[axa]
            inds_out.remove(ind)
        inds_b.append(ind)

    eq = f"{''.join(inds_a)},{''.join(inds_b)}->{''.join(inds_out)}"

    return _parse_eq_to_batch_matmul(eq, shape_a, shape_b)


def tensordot(a, b, axes=2, *, backend=None):
    """Perform a tensordot using only `matmul`, `transpose`, `reshape`. The
    logic for each is cached based on the equation and array shape, and each
    step is only performed if necessary.

    Parameters
    ----------
    a, b : array_like
        The arrays to contract.
    axes : int or tuple of (sequence[int], sequence[int])
        The number of axes to contract, or the axes to contract. If an int,
        the last ``axes`` axes of ``a`` and the first ``axes`` axes of ``b``
        are contracted. If a tuple, the axes to contract for ``a`` and ``b``
        respectively.
    backend : str or None, optional
        The backend to use for array operations. If ``None``, dispatch
        automatically based on ``a`` and ``b``.

    Returns
    -------
    array_like
    """
    try:
        axes_a, axes_b = axes[0], axes[1]
    except (TypeError, IndexError):
        # subscripting a plain python int raises TypeError, whereas numpy
        # scalars and 0-d arrays raise IndexError
        axes = int(axes)
    else:
        # ensure hashable
        axes = tuple(map(int, axes_a)), tuple(map(int, axes_b))

    (
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
    ) = _parse_tensordot_axes_to_matmul(axes, shape(a), shape(b))

    do_contraction = _select_contraction_fn(a, b, backend)

    return do_contraction(
        a,
        b,
        eq_a,
        eq_b,
        new_shape_a,
        new_shape_b,
        new_shape_ab,
        perm_ab,
        pure_multiplication,
        backend,
    )


def extract_contractions(
    tree,
    order=None,
    prefer_einsum=False,
):
    """Extract just the information needed to perform the contraction.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree to extract the contractions from.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.

    Returns
    -------
    contractions : tuple
        A tuple of tuples, each containing the information needed to
        perform a pairwise contraction. Each tuple contains:

            - ``p``: the parent node,
            - ``l``: the left child node,
            - ``r``: the right child node,
            - ``tdot``: whether to use ``tensordot`` or ``einsum``,
            - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
                i.e. ``axes`` or ``eq``,
            - ``perm``: the permutation required after the contraction, if
                any (only applies to tensordot).

        If both ``l`` and ``r`` are ``None``, then the operation is a single
        term simplification performed with ``einsum``.
    """
    contractions = []

    # pairwise contractions
    contractions.extend(
        (p, l, r, False, tree.get_einsum_eq(p), None)
        if (prefer_einsum or not tree.get_can_dot(p))
        else (
            p,
            l,
            r,
            True,
            tree.get_tensordot_axes(p),
            tree.get_tensordot_perm(p),
        )
        for p, l, r in tree.traverse(order=order)
    )

    if tree.preprocessing:
        # inplace single term simplifications
        # n.b. these are populated lazily when the other information is
        # computed above, so we do it after
        pre_contractions = (
            (node_from_single(i), None, None, False, eq, None)
            for i, eq in tree.preprocessing.items()
        )
        return (*pre_contractions, *contractions)

    return tuple(contractions)


class Contractor:
    """Default cotengra network contractor.

    Parameters
    ----------
    contractions : tuple[tuple]
        The sequence of contractions to perform. Each contraction should be a
        tuple containing:

            - ``p``: the parent node,
            - ``l``: the left child node,
            - ``r``: the right child node,
            - ``tdot``: whether to use ``tensordot`` or ``einsum``,
            - ``arg``: the argument to pass to ``tensordot`` or ``einsum``
                i.e. ``axes`` or ``eq``,
            - ``perm``: the permutation required after the contraction, if
                any (only applies to tensordot).

        e.g. built by calling ``extract_contractions(tree)``.

    strip_exponent : bool, optional
        If ``True``, eagerly strip the exponent (in log10) from
        intermediate tensors to control numerical problems from leaving the
        range of the datatype. This method then returns the scaled
        'mantissa' output array and the exponent separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, -inf)``.
    implementation : str or tuple[callable, callable], optional
        Either ``"auto"``, ``"cotengra"`` or ``"autoray"``, or a manually
        supplied pair of ``(einsum, tensordot)`` functions. With ``"auto"``
        the cotengra implementation is used for numpy and for torch when no
        input array reports a non-cpu device, ``"autoray"`` otherwise.
    backend : str, optional
        What library to use for ``tensordot``, ``einsum`` and
        ``transpose``, it will be automatically inferred from the input
        arrays if not given.
    progbar : bool, optional
        Whether to show a progress bar.
    """

    __slots__ = (
        "contractions",
        "strip_exponent",
        "check_zero",
        "implementation",
        "backend",
        "progbar",
        "__weakref__",
    )

    def __init__(
        self,
        contractions,
        strip_exponent=False,
        check_zero=False,
        implementation="auto",
        backend=None,
        progbar=False,
    ):
        self.contractions = contractions
        self.strip_exponent = strip_exponent
        self.check_zero = check_zero
        self.implementation = implementation
        self.backend = backend
        self.progbar = progbar

    def __call__(self, *arrays, **kwargs):
        """Contract ``arrays`` using operations listed in ``contractions``.

        Parameters
        ----------
        arrays : sequence of array-like
            The arrays to contract.
        kwargs : dict
            Override the default settings for this contraction only.

        Returns
        -------
        output : array
            The contracted output, it will be scaled if
            ``strip_exponent==True``.
        exponent : float
            The exponent of the output in base 10, returned only if
            ``strip_exponent==True``.

        Raises
        ------
        TypeError
            If any unknown keyword arguments are supplied.
        """
        global _TORCH_ARENA, _TORCH_WORKSPACE_POOL

        backend = kwargs.pop("backend", self.backend)
        progbar = kwargs.pop("progbar", self.progbar)
        check_zero = kwargs.pop("check_zero", self.check_zero)
        strip_exponent = kwargs.pop("strip_exponent", self.strip_exponent)
        implementation = kwargs.pop("implementation", self.implementation)
        if kwargs:
            raise TypeError(f"Unknown keyword arguments: {kwargs}.")

        if backend is None:
            backend = infer_backend_multi(*arrays)

        if implementation == "auto":
            if (backend == "numpy") or (
                backend == "torch"
                and all(
                    getattr(getattr(x, "device", None), "type", "cpu")
                    == "cpu"
                    for x in arrays
                    if hasattr(x, "device")
                )
            ):
                # by default replace numpy's einsum/tensordot, and do the
                # same for torch CPU to control bmm outputs and workspace reuse
                implementation = "cotengra"
            else:
                implementation = "autoray"

        if implementation == "cotengra":
            _einsum, _tensordot = einsum, tensordot
        elif implementation == "autoray":
            try:
                _einsum = get_lib_fn(backend, "einsum")
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _einsum = einsum
            try:
                _tensordot = get_lib_fn(backend, "tensordot")
            except ImportError:
                # fallback to cotengra (matmul) implementation
                _tensordot = tensordot
        else:
            # manually supplied
            _einsum, _tensordot = implementation

        using_torch_arena = (backend == "torch") and _torch_arena_enabled()
        using_torch_workspace = (
            (backend == "torch")
            and (_einsum is einsum)
            and _torch_workspace_enabled()
            and not using_torch_arena
        )

        # temporary storage for intermediates
        N = len(arrays)
        temps = {
            leaf: array
            for leaf, array in zip(map(node_from_single, range(N)), arrays)
        }

        exponent = 0.0 if (strip_exponent is not False) else None

        if progbar:
            import tqdm

            contractions = tqdm.tqdm(self.contractions, total=N - 1)
        else:
            contractions = self.contractions

        if using_torch_arena:
            _TORCH_ARENA = _TorchArena()
        if using_torch_workspace:
            _TORCH_WORKSPACE_POOL = {}

        try:
            p_array = next(iter(temps.values()))

            for p, l, r, tdot, arg, perm in contractions:
                if (l is None) and (r is None):
                    # single term simplification, perform inplace with einsum
                    temps[p] = _einsum(arg, temps[p])
                    p_array = temps[p]
                    continue

                # get input arrays for this contraction
                l_array = temps.pop(l)
                r_array = temps.pop(r)

                if tdot:
                    p_array = _tensordot(l_array, r_array, arg)
                    if perm:
                        p_array = do("transpose", p_array, perm, like=backend)
                else:
                    p_array = _einsum(arg, l_array, r_array)

                if exponent is not None:
                    factor = do(
                        "max", do("abs", p_array, like=backend), like=backend
                    )
                    if check_zero and float(factor) == 0.0:
                        return 0.0, float("-inf")
                    exponent = exponent + do("log10", factor, like=backend)
                    p_array = p_array / factor

                # insert the new intermediate array
                temps[p] = p_array

                if using_torch_workspace:
                    # only recycle intermediates, never the caller's inputs
                    if (len(l) != 1) and hasattr(l_array, "device"):
                        _torch_workspace_push(l_array)
                    if (len(r) != 1) and hasattr(r_array, "device"):
                        _torch_workspace_push(r_array)

                if using_torch_arena:
                    _TORCH_ARENA.release(l_array)
                    _TORCH_ARENA.release(r_array)

            if using_torch_arena and hasattr(p_array, "clone"):
                # The final output may be a view into the arena. Clone it
                # before dropping the arena so a scalar result doesn't keep
                # the whole workspace storage alive.
                p_array = p_array.clone()
        finally:
            # always drop the scratch buffers, including on early return or
            # error, so they neither leak nor get picked up by a later call
            if using_torch_arena:
                _TORCH_ARENA = None
            if using_torch_workspace:
                _TORCH_WORKSPACE_POOL = None

        if exponent is not None:
            return p_array, exponent

        return p_array


class CuQuantumContractor:
    """Contractor which hands the whole contraction to cuQuantum's
    ``Network``, rather than performing pairwise contractions itself.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree, supplies the equation, shapes and, if it is
        complete, the contraction path (and slicing).
    handle_slicing : bool, optional
        If ``True`` use the full (unsliced) equation and shapes of ``tree``
        and pass any sliced indices on to cuQuantum, otherwise use the
        sliced equation and shapes.
    autotune : bool or int, optional
        Number of autotune iterations, ``True`` meaning 3.
    kwargs
        Stored and supplied to ``Network.contract_path``.

    Raises
    ------
    ValueError
        If ``strip_exponent`` is requested or ``tree`` has preprocessing.
    """

    def __init__(
        self,
        tree,
        handle_slicing=False,
        autotune=False,
        **kwargs,
    ):
        if kwargs.pop("strip_exponent", None):
            raise ValueError(
                "strip_exponent=True not supported with cuQuantum"
            )

        if tree.has_preprocessing():
            raise ValueError("Preprocessing not supported with cuQuantum yet.")

        if kwargs.pop("progbar", None):
            import warnings

            warnings.warn("Progress bar not supported with cuQuantum yet.")

        if handle_slicing:
            self.eq = tree.get_eq()
            self.shapes = tree.get_shapes()
        else:
            self.eq = tree.get_eq_sliced()
            self.shapes = tree.get_shapes_sliced()

        if tree.is_complete():
            kwargs.setdefault("optimize", {})
            kwargs["optimize"].setdefault("path", tree.get_path())
            if handle_slicing and tree.sliced_inds:
                kwargs["optimize"].setdefault(
                    "slicing",
                    [(ix, tree.size_dict[ix] - 1) for ix in tree.sliced_inds],
                )

        self.kwargs = kwargs
        self.autotune = 3 if autotune is True else autotune
        self.handle = None
        self.network = None

    def setup(self, *arrays):
        """Create the cuQuantum network for ``arrays``, find its contraction
        path and optionally autotune it.
        """
        import cuquantum

        if hasattr(cuquantum, "bindings"):
            # cuquantum-python >= 25.03
            from cuquantum.tensornet import Network
        else:
            # for cuquantum < 25.03
            from cuquantum import Network

        self.network = Network(
            self.eq,
            *arrays,
        )
        self.network.contract_path(**self.kwargs)
        if self.autotune:
            self.network.autotune(iterations=self.autotune)

    def __call__(
        self,
        *arrays,
        check_zero=False,
        backend=None,
        progbar=False,
    ):
        """Contract ``arrays``, building the network on the first call and
        only resetting its operands on subsequent calls.
        """
        # can't handle these yet
        assert not check_zero
        assert not progbar
        assert backend is None

        if self.network is None:
            self.setup(*arrays)
        else:
            self.network.reset_operands(*arrays)

        return self.network.contract()

    def __del__(self):
        if self.network is not None:
            self.network.free()


def make_contractor(
    tree,
    order=None,
    prefer_einsum=False,
    strip_exponent=False,
    check_zero=False,
    implementation=None,
    autojit=False,
    progbar=False,
):
    """Get a reusable function which performs the contraction corresponding
    to ``tree``. The various options provide defaults that can also be
    overridden when calling the standard contractor.

    Parameters
    ----------
    tree : ContractionTree
        The contraction tree.
    order : str or callable, optional
        Supplied to :meth:`ContractionTree.traverse`, the order in which
        to perform the pairwise contractions given by the tree.
    prefer_einsum : bool, optional
        Prefer to use ``einsum`` for pairwise contractions, even if
        ``tensordot`` can perform the contraction.
    strip_exponent : bool, optional
        If ``True``, the function will strip the exponent from the output
        array and return it separately.
    check_zero : bool, optional
        If ``True``, when ``strip_exponent=True``, explicitly check for
        zero-valued intermediates that would otherwise produce ``nan``,
        instead terminating early if encountered and returning
        ``(0.0, -inf)``.
    implementation : str or tuple[callable, callable], optional
        What library to use to actually perform the contractions. If not
        given, the value of ``get_default_implementation()`` is used. Options
        are

        - "auto": let cotengra choose
        - "autoray": dispatch with autoray, using the ``tensordot`` and
          ``einsum`` implementation of the backend
        - "cotengra": use the ``tensordot`` and ``einsum`` implementation
          of cotengra, which is based on batch matrix multiplication. This
          is faster for some backends like numpy, and also enables
          libraries which don't yet provide ``tensordot`` and ``einsum`` to
          be used.
        - "cuquantum": use the cuquantum library to perform the whole
          contraction (not just individual contractions).
        - tuple[callable, callable]: manually supply the ``einsum`` and
          ``tensordot`` implementations to use, in that order.

    autojit : bool, optional
        If ``True``, use :func:`autoray.autojit` to compile the contraction
        function.
    progbar : bool, optional
        Whether to show progress through the contraction by default.

    Returns
    -------
    fn : callable
        The contraction function, with signature ``fn(*arrays)``.
    """
    if implementation is None:
        implementation = get_default_implementation()

    if implementation == "cuquantum":
        fn = CuQuantumContractor(
            tree,
            strip_exponent=strip_exponent,
            check_zero=check_zero,
            progbar=progbar,
        )
    else:
        fn = Contractor(
            contractions=extract_contractions(tree, order, prefer_einsum),
            strip_exponent=strip_exponent,
            check_zero=check_zero,
            implementation=implementation,
            progbar=progbar,
        )

    if autojit:
        from autoray import autojit as _autojit

        fn = _autojit(fn)

    return fn