from typing import Callable, Optional, Tuple
from collections.abc import Sequence
import itertools as it
from math import pi
import torch
from torch.nn.functional import relu
__all__ = [ 'HilbertKernel' ]
def _make_inner(f: Callable[[torch.nn.Module,
torch.Tensor,
torch.Tensor
],
torch.Tensor]
) -> Callable[[torch.nn.Module,
torch.Tensor,
torch.Tensor
],
torch.Tensor]:
return lambda self, a, b: f(
self,
a.unsqueeze(-3),
b.unsqueeze(-2)
).squeeze(-1).sum( dim = (-2, -1) )
[docs]
class HilbertKernel(torch.nn.Module):
    """Kernel induced by the embedding as Hilbert functions.

    This is the kernel that is induced by the feature map
    that sends graded persistence intervals
    to the Hilbert function of the corresponding
    unravelled relative homology lattice
    using the inner product of square-integrable functions.

    Parameters
    ----------
    range_intervals : :class:`Tuple[float, float]`, optional
        The finite range containing all persistence intervals,
        defaults to :obj:`(0.0, pi/2.0)`.
    partial_info : :class:`bool`, optional
        If this is set to :obj:`True`,
        the inner product of the corresponding truncated Hilbert functions
        is computed.
        This is sensible,
        whenever persistence intervals are only known
        up to a certain degree.
        The corresponding attribute :attr:`partial_info`
        can be changed after initialization.
        The default is :obj:`False`.
    device : optional
        Device on which the internal scalar constants are created;
        forwarded to :func:`torch.tensor`.
    dtype : optional
        Data type of the internal scalar constants;
        forwarded to :func:`torch.tensor`.
    """

    def __init__(self,
                 *,
                 range_intervals: Tuple[float, float] = (0.0, pi/2.0),
                 partial_info: bool = False,
                 device=None,
                 dtype=None
                 ) -> None:
        super().__init__()
        # Named ``lower``/``upper`` to avoid shadowing the builtins
        # ``min`` and ``max``.
        lower, upper = range_intervals
        self.partial_info = partial_info
        """:class:`bool`:
        Whether Hilbert functions should be truncated
        before computing the inner product;
        also see the description of the corresponding parameter.
        """
        # The constants are registered as buffers so that
        # ``Module.to(...)`` moves them along with the module.
        # ``persistent=False`` keeps them out of ``state_dict``.
        self.register_buffer(
            '_strip_width',
            torch.tensor(upper - lower, dtype=dtype, device=device),
            persistent=False
        )
        self.register_buffer(
            '_min',
            torch.tensor(lower, dtype=dtype, device=device),
            persistent=False
        )
        self.register_buffer(
            '_zero',
            torch.tensor(0, dtype=dtype, device=device),
            persistent=False
        )

    @_make_inner
    def _area_same_degrees(self,
                           a: torch.Tensor,
                           b: torch.Tensor
                           ) -> torch.Tensor:
        """Pairwise term for two sets of intervals of the same degree.

        Writing ``(a0, a1)`` and ``(b0, b1)`` for the two entries along
        the last axis of ``a`` and ``b``, the value for each pair is
        ``relu(min(a1, b1) - max(a0, b0))
        * relu(strip_width + min(a0, b0) - max(a1, b1))``.
        The ``_make_inner`` decorator supplies broadcastable views of
        the inputs and sums the result over all pairs.
        """
        join = torch.maximum(a, b)
        meet = torch.minimum(a, b)
        overlap = relu(meet.narrow(-1, 1, 1) - join.narrow(-1, 0, 1))
        complement = relu(
            self._strip_width + meet.narrow(-1, 0, 1) - join.narrow(-1, 1, 1)
        )
        return overlap * complement

    @_make_inner
    def _area_consecutive_degrees(self,
                                  a: torch.Tensor,
                                  b: torch.Tensor
                                  ) -> torch.Tensor:
        """Pairwise term for intervals of two consecutive degrees.

        ``forward`` calls this with ``a`` taken from one degree and
        ``b`` from the following degree of the other input.
        With ``(a0, a1)`` and ``(b0, b1)`` as the entries along the
        last axis, the value for each pair is
        ``relu(min(b0, a1) - a0) * relu(b1 - max(a1, b0))``.
        The ``_make_inner`` decorator supplies broadcastable views of
        the inputs and sums the result over all pairs.
        """
        a_lower = a.narrow(-1, 0, 1)
        a_upper = a.narrow(-1, 1, 1)
        b_lower = b.narrow(-1, 0, 1)
        b_upper = b.narrow(-1, 1, 1)
        return (
            relu(torch.minimum(b_lower, a_upper) - a_lower) *
            relu(b_upper - torch.maximum(a_upper, b_lower))
        )

    @_make_inner
    def _area_last_degree_partial_info(self,
                                       a: torch.Tensor,
                                       b: torch.Tensor
                                       ) -> torch.Tensor:
        """Pairwise term for the last degree when ``partial_info`` is set.

        With ``(a0, a1)`` and ``(b0, b1)`` as the entries along the
        last axis, the value for each pair is
        ``relu(min(a1, b1) - max(a0, b0)) * relu(min(a0, b0) - min_range)``,
        where ``min_range`` is the lower end of ``range_intervals``.
        The ``_make_inner`` decorator supplies broadcastable views of
        the inputs and sums the result over all pairs.
        """
        join = torch.maximum(a, b)
        meet = torch.minimum(a, b)
        overlap = relu(meet.narrow(-1, 1, 1) - join.narrow(-1, 0, 1))
        truncated = relu(meet.narrow(-1, 0, 1) - self._min)
        return overlap * truncated

    def _peel_off_residual(self,
                           intervals_pairs: Sequence[Tuple[torch.Tensor,
                                                           torch.Tensor]
                                                     ]
                           ) -> Tuple[Sequence[Tuple[torch.Tensor,
                                                     torch.Tensor]],
                                      torch.Tensor]:
        """Split off the last pair of degrees if ``partial_info`` is set.

        Parameters
        ----------
        intervals_pairs : sequence of pairs of tensors
            Pairs of interval tensors by degree.
            This has to support ``len``, indexing and slicing,
            so a bare ``zip`` object is not sufficient.

        Returns
        -------
        tuple
            The pairs to be treated as regular same-degree pairs,
            and the residual term.
            With ``partial_info`` unset, or with no pairs at all,
            all pairs are returned along with a zero residual.
            Otherwise the last pair is removed and evaluated by
            ``_area_last_degree_partial_info``.
        """
        if not self.partial_info or not intervals_pairs:
            return intervals_pairs, self._zero
        last01, last02 = intervals_pairs[-1]
        return (
            intervals_pairs[:-1],
            self._area_last_degree_partial_info(last01, last02)
        )

    def forward(self,
                intervals01: Sequence[torch.Tensor],
                intervals02: Sequence[torch.Tensor],
                /) -> torch.Tensor:
        r"""Computes the inner product of the corresponding Hilbert functions.

        Parameters
        ----------
        intervals01 : :class:`Sequence[torch.Tensor]`
            Persistence intervals
            for the first input
            as a :class:`Sequence[torch.Tensor]` by degree.
            So each item of this sequence
            is a :class:`torch.Tensor` of shape
            :math:`([\dots,] k, 2)`,
            where :math:`k` is the number of persistence intervals
            in the corresponding degree.
        intervals02 : :class:`Sequence[torch.Tensor]`
            Persistence intervals
            for the second input
            as a :class:`Sequence[torch.Tensor]` by degree
            analogous to the first parameter :obj:`intervals01`.

        Returns
        -------
        :class:`torch.Tensor`
            The inner product of the corresponding Hilbert functions.
        """
        # ``zip`` returns a one-shot iterator that can be neither
        # indexed nor sliced, which ``_peel_off_residual`` needs to do
        # when ``partial_info`` is set.  Hence the tuple.
        same_degree_pairs, residual = self._peel_off_residual(
            tuple(zip(intervals01, intervals02))
        )
        return (
            sum(self._area_same_degrees(first, second)
                for first, second in same_degree_pairs
                ) +
            sum(self._area_consecutive_degrees(lower, higher)
                for lower, higher in zip(intervals01, intervals02[1:])
                ) +
            sum(self._area_consecutive_degrees(lower, higher)
                for lower, higher in zip(intervals02, intervals01[1:])
                ) +
            residual
        )