Source code for persunraveltorch.nn.reshear

from typing import Tuple
from enum import StrEnum, auto

import torch
import torch.nn.functional as F


# Names exported when this module is star-imported (``from ... import *``).
__all__ = [ 'ReshearMode', 'Reshear' ]


[docs] class ReshearMode(StrEnum): """Initialization parameter and attribute to :class:`Reshear`.""" REPLICATE = auto() """Replicate the pixels of the left and right most columns horizontally.""" ZERO = auto() """Fill the surrounding pixels with zeros."""
[docs] class Reshear(torch.nn.Module): """Reshears a biplane. The biplanes rendered by :class:`BiplaneFromIntervals` and :class:`BiplaneFromTriangles` are sheared in comparison to the strip indexing the unravelled relative homology lattice to save on memory and computation. This module undoes the sheering inherents to these two classes and can be used for visualization purposes. Parameters ---------- shape : :class:`torch.Size` The shape of the tensors to be processed. Only the last two components matter, the others are ignored. mode : :class:`ReshearMode`, optional See :class:`ReshearMode` for details. The corresponding attribute :attr:`mode` can be changed after initialization. The default is :obj:`ReshearMode.ZERO`. """ def __init__(self, *, shape: torch.Size, mode: ReshearMode = ReshearMode.ZERO, device = None, dtype = None ) -> None: super().__init__() self.mode = mode """:class:`ReshearMode`: See :class:`ReshearMode` for details.""" rows, columns = shape[-2:] indices_bottom = torch.arange( start = 1, end = columns + rows, dtype = torch.long, device = device )[None, :] offsets = torch.arange( start = 1 - rows, end = 1, dtype = torch.long, device = device )[:, None] self._indices = torch.maximum( torch.tensor(0, dtype=torch.long, device=device), torch.minimum( torch.tensor(columns+1, dtype=torch.long, device=device), offsets + indices_bottom ) ) def _pad(self, input: torch.Tensor, /) -> torch.Tensor: match self.mode: case ReshearMode.REPLICATE: return torch.cat( dim = -1, tensors = [ input.narrow( dim = -1, start = 0, length = 1 ), input, input.narrow( dim = -1, start = -1, length = 1 ) ] ) case ReshearMode.ZERO: return F.pad( input, pad = (1, 1, 0, 0), mode = 'constant', value = 0 )
[docs] def forward(self, input: torch.Tensor, /) -> torch.Tensor: """Reshears the input. Parameters ---------- input : :class:`torch.Tensor` The tensor to be resheared. Returns ------- :class:`torch.Tensor` The resheared tensor. """ return self._pad( input ).take_along_dim( dim = -1, indices = self._indices.view( (1,) * (len(input.shape) - 2) + self._indices.shape ) )
[docs] @classmethod def create_n_apply(cls, input: torch.Tensor, *, mode: ReshearMode = ReshearMode.ZERO, device = None, dtype = None ) -> Tuple[torch.Tensor, None]: """Creates a matching instance and applies it. This class method creates an instance with the :obj:`shape` parameter matching the :obj:`input` and applies it to :obj:`input`. Both the processed input tensor as well as the instance are returned. Parameters ---------- input : :class:`torch.Tensor` The tensor to be resheared. mode : :class:`ReshearMode`, optional See :class:`ReshearMode` for details. The corresponding attribute :attr:`mode` can be changed after initialization. The default is :obj:`ReshearMode.ZERO`. Returns ------- :class:`Tuple[torch.Tensor, Reshear]` Pair with the first component containing the sheared tensor and the second component the instance that was used to create it. """ instance = cls( shape = input.shape, mode = mode, device = device, dtype = dtype ) return ( instance(input), instance )