Source code for persunraveltorch.nn.unravel

from typing import Tuple
from collections.abc import Iterable, Sequence
from math import pi

import itertools as it

import torch
import torch.nn as nn


__all__ = [ 'Unravel' ]


def _bool_to_sgn(val: bool) -> float:
    if val:
        return -1.0
    else:
        return  1.0


[docs] class Unravel(nn.Module): """Computes the associated unravelled persistence diagram. This module computes the unravelled persistence diagram associated to persistence intervals within a finite range :obj:`range_intervals`. Parameters ---------- range_intervals : :class:`Tuple[float, float]`, optional The finite range containing all persistence intervals, defaults to :obj:`(0.0, pi/2.0)`. neg_x : :class:`bool`, optional Whether the x-axis is negated with respect to Cartesian coordintes, defaults to :obj:`False` neg_y : :class:`bool`, optional Whether the y-axis is negated with respect to Cartesian coordintes, defaults to :obj:`False` """ def __init__(self, *, range_intervals: Tuple[float, float] = (0.0, pi/2.0), neg_x: bool = False, neg_y: bool = False, device = None, dtype = None ) -> None: super().__init__() self._identity = torch.tensor( [[_bool_to_sgn(neg_x), 0.0], [ 0.0, _bool_to_sgn(neg_y)] ], device = device, dtype = dtype ) self._reflection = torch.tensor( [[ 0.0, _bool_to_sgn(neg_y)], [_bool_to_sgn(neg_x), 0.0] ], device = device, dtype = dtype ) min, max = range_intervals diff = min - max self._offsets = torch.tensor( [[_bool_to_sgn(neg_x) * diff, 0.0], [ 0.0, _bool_to_sgn(neg_y) * diff] ], device = device, dtype = dtype ) self._e1 = torch.tensor( [1.0, 0.0], device = device, dtype = dtype ) self._e2 = torch.tensor( [0.0, 1.0], device = device, dtype = dtype ) self._zero = torch.tensor( [0.0, 0.0], device = device, dtype = dtype )
[docs] def forward(self, intervals: Iterable[torch.Tensor], /) -> Sequence[torch.Tensor]: """Computes the associated unravelled persistence diagram. Parameters ---------- intervals : :class:`Iterable[torch.Tensor]` Persistence intervals by degree. Returns ------- :class:`Sequence[torch.Tensor]` The associated unravelled persistence diagram. """ reflection_powers = it.cycle( (self._identity, self._reflection) ) offset_multipliers = it.accumulate( initial = self._zero, iterable = it.cycle( (self._e1, self._e2) ) ) return [ intervals_d @ reflection_power + offset_multiplier @ self._offsets for intervals_d, reflection_power, offset_multiplier in zip(intervals, reflection_powers, offset_multipliers) ]