from typing import Optional, Tuple
from math import pi
import torch
from torch.nn.functional import relu
__all__ = [ 'RasterTriangle' ]
def _scalar_product(u: torch.Tensor,
                    v: torch.Tensor,
                    /) -> torch.Tensor:
    """Return the scalar product of `u` and `v` along their last dimension.

    All leading dimensions are treated as batch dimensions
    and are broadcast against each other by the matrix product.

    Parameters
    ----------
    u : :class:`torch.Tensor`
        Left factor with the vector components in the last dimension.
    v : :class:`torch.Tensor`
        Right factor with the vector components in the last dimension.

    Returns
    -------
    :class:`torch.Tensor`
        The scalar products, shaped like the broadcast
        of the leading dimensions of `u` and `v`.
    """
    # Row vector times column vector yields a 1x1 matrix per batch entry.
    product = torch.matmul( u.unsqueeze(-2), v.unsqueeze(-1) )
    # Index the two singleton axes away instead of squeezing with a tuple
    # of dims, which is only accepted by PyTorch 2.0 and later.
    return product[..., 0, 0]
def _split_endpoints(intervals: torch.Tensor,
                     /) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split a tensor of intervals into its two endpoint tensors.

    Parameters
    ----------
    intervals : :class:`torch.Tensor`
        Intervals whose last dimension holds the two endpoints,
        so the last dimension has to be of size two.

    Returns
    -------
    :class:`Tuple[torch.Tensor, torch.Tensor]`
        The first and the second endpoints,
        each with the last dimension of `intervals` removed.

    Raises
    ------
    ValueError
        If the last dimension of `intervals` is not of size two.
    """
    # Unbinding the last dimension yields one view per endpoint;
    # the unpacking fails unless there are exactly two of them.
    first_endpoints, second_endpoints = intervals.unbind( dim = -1 )
    return first_endpoints, second_endpoints
class RasterTriangle(torch.nn.Module):
    """Rasters a triangle corresponding to a single patch of a biplane.

    This module takes persistence intervals within a finite range
    :obj:`range_intervals` of two consecutive degrees as input
    and rasters the corresponding single triangular patch
    to an unsheared biplane.

    Parameters
    ----------
    pixel_columns : :class:`int`
        The number of pixels used to raster each row or scanline
        of the strip supporting the relative homology lattice.
    range_intervals : :class:`Tuple[float, float]`, optional
        The finite range containing all persistence intervals,
        defaults to :obj:`(0.0, pi/2.0)`.
    max_overhead : :class:`Optional[int]`, optional
        If this is set,
        then the input will be processed in batches
        as small as necessary
        to limit the number of bytes allocated as overhead
        to at most :obj:`max_overhead`.
        So if you're not already processing your data
        in sufficiently small batches,
        setting this parameter is recommended.
        However,
        if the number of bytes required to process a single sample
        already exceeds :obj:`max_overhead`,
        the input is still processed sample by sample.
        The corresponding attribute :attr:`max_overhead`
        can be changed after initialization.
        The default is :obj:`None`.
    device : optional
        Passed on to the tensor factories
        creating the internal pixel coordinates and constants.
        The default is :obj:`None`.
    dtype : optional
        Passed on to the tensor factories
        creating the internal pixel coordinates and constants.
        The default is :obj:`None`.
    """

    def __init__(self,
                 pixel_columns: int,
                 *,
                 range_intervals: Tuple[float, float] = (0.0, pi/2.0),
                 max_overhead: Optional[int] = None,
                 device = None,
                 dtype = None
                 ) -> None:
        super().__init__()

        self.max_overhead = max_overhead
        """:class:`Optional[int]`:
        Limits overhead as described for :class:`RasterTriangle`."""

        # Named so that the builtins `min` and `max` are not shadowed.
        lower, upper = range_intervals
        strip_width = upper - lower
        side_len = strip_width / pixel_columns

        self._pixel_area = side_len * side_len

        # All tensors are registered as non-persistent buffers:
        # they follow the module through `to()`, `cuda()`, `double()`, ...
        # while the state dict stays free of them.
        self.register_buffer(
            '_strip_width',
            torch.tensor(
                strip_width,
                dtype = dtype,
                device = device
            ),
            persistent = False
        )
        self.register_buffer(
            '_side_len',
            torch.tensor(
                side_len,
                dtype = dtype,
                device = device
            ),
            persistent = False
        )
        # Ascending pixel coordinates as a column, one row per pixel.
        self.register_buffer(
            '_px',
            torch.linspace(
                start = lower,
                end = upper - side_len,
                steps = pixel_columns,
                dtype = dtype,
                device = device
            )[:, None],
            persistent = False
        )
        # The same coordinates in descending order.
        self.register_buffer(
            '_py',
            torch.linspace(
                start = upper - side_len,
                end = lower,
                steps = pixel_columns,
                dtype = dtype,
                device = device
            )[:, None],
            persistent = False
        )

    @property
    def pixel_area(self) -> float:
        """:class:`float`: The area covered by each pixel."""
        return self._pixel_area

    def _outer_scalar_product(self,
                              u0: torch.Tensor,
                              v0: torch.Tensor,
                              /) -> torch.Tensor:
        """Scalar products of all pairs of rows of `u0` and `v0`.

        The last dimension is contracted,
        the second to last dimensions of `u0` and `v0` are paired up
        and any dimensions in front are treated as batch dimensions.
        If :attr:`max_overhead` is set,
        the batch is processed in chunks along its first dimension.

        Parameters
        ----------
        u0 : :class:`torch.Tensor`
            Left factor of shape ``(..., rows_u, n)``.
        v0 : :class:`torch.Tensor`
            Right factor of shape ``(..., rows_v, n)``.

        Returns
        -------
        :class:`torch.Tensor`
            Tensor of shape ``(..., rows_u, rows_v)``.
        """
        if self.max_overhead is None:
            return _scalar_product( u0.unsqueeze(-2), v0.unsqueeze(-3) )

        # Collapse all batch dimensions into a single one.
        # The extra leading axis makes this work for unbatched input
        # with just two dimensions as well.
        u = u0.unsqueeze(0).flatten( end_dim = -3 )
        v = v0.unsqueeze(0).flatten( end_dim = -3 )

        # Bytes taken up by both broadcast operands of a single sample.
        bytes_per_sample = (
            2 * u.shape[1] * v.shape[1] * v.shape[2] * u.element_size()
        )
        # Guard the divisor: it vanishes when there are no intervals at all.
        # At least one sample is processed at a time.
        split_size = max(
            1,
            self.max_overhead // max( 1, bytes_per_sample )
        )

        us = torch.split( u[:, :, None, :], split_size )
        vs = torch.split( v[:, None, :, :], split_size )

        return torch.cat(
            [ _scalar_product(*pair) for pair in zip(us, vs) ]
        ).view(*u0.shape[:-1], v0.shape[-2])

    def _hilbert_function_integrals(self,
                                    x0: torch.Tensor,
                                    y0: torch.Tensor,
                                    /) -> torch.Tensor:
        """Raster the first channel from the endpoints `x0` and `y0`.

        For each endpoint pair and each pixel coordinate
        a horizontal and a vertical segment length is computed,
        both capped at the pixel side length from above
        and at zero from below.
        The result sums the products
        of vertical and horizontal lengths over all endpoint pairs,
        with rows following the descending and columns
        following the ascending pixel coordinates.

        Parameters
        ----------
        x0 : :class:`torch.Tensor`
            First endpoints, enumerated along the last dimension.
        y0 : :class:`torch.Tensor`
            Second endpoints, enumerated along the last dimension.

        Returns
        -------
        :class:`torch.Tensor`
            Tensor whose last two dimensions
            both have one entry per pixel coordinate.
        """
        # Make room for the pixel dimension in front of the endpoints.
        x = x0.unsqueeze(dim=-2)
        y = y0.unsqueeze(dim=-2)

        hor_seg = relu(
            torch.minimum(
                self._side_len,
                torch.minimum(
                    self._px + self._side_len - x,
                    y - self._px
                )
            )
        )
        vert_seg = relu(
            torch.minimum(
                self._side_len,
                torch.minimum(
                    self._py + self._side_len - y,
                    x + self._strip_width - self._py
                )
            )
        )

        return self._outer_scalar_product( vert_seg, hor_seg )

    def _rank_invariant_integrals(self,
                                  x0: torch.Tensor,
                                  y0: torch.Tensor,
                                  /) -> torch.Tensor:
        """Raster the second channel from the endpoints `x0` and `y0`.

        The horizontal segment lengths only depend on `x0`
        and the vertical segment lengths only depend on `y0`,
        both capped at the pixel side length from above
        and at zero from below.
        The result sums the products
        of vertical and horizontal lengths over all endpoint pairs.

        Parameters
        ----------
        x0 : :class:`torch.Tensor`
            First endpoints, enumerated along the last dimension.
        y0 : :class:`torch.Tensor`
            Second endpoints, enumerated along the last dimension.

        Returns
        -------
        :class:`torch.Tensor`
            Tensor whose last two dimensions
            both have one entry per pixel coordinate.
        """
        # Make room for the pixel dimension in front of the endpoints.
        x = x0.unsqueeze(dim=-2)
        y = y0.unsqueeze(dim=-2)

        hor_seg = relu(
            torch.minimum(
                self._side_len,
                self._px + self._side_len - x
            )
        )
        vert_seg = relu(
            torch.minimum(
                self._side_len,
                y - self._py
            )
        )

        return self._outer_scalar_product( vert_seg, hor_seg )

    def forward(self,
                intervals00: torch.Tensor,
                intervals01: torch.Tensor,
                /) -> torch.Tensor:
        """Triangle created from persistence intervals
        as described for :class:`RasterTriangle`.

        Parameters
        ----------
        intervals00 : :class:`torch.Tensor`
            Persistence intervals for the lower
            of the two consecutive degrees.
            The last dimension holds the two endpoints.
        intervals01 : :class:`torch.Tensor`
            Persistence intervals for the higher
            of the two consecutive degrees.
            The last dimension holds the two endpoints.

        Returns
        -------
        :class:`torch.Tensor`
            A triangle with two channels
            describing a single patch
            to an unsheared corresponding biplane.
            The channels are stacked along the third to last dimension.
        """
        a0, b0 = _split_endpoints( intervals00 )
        a1, b1 = _split_endpoints( intervals01 )

        return torch.stack(
            dim = -3,
            tensors = [
                # First channel: intervals of both degrees,
                # those of the higher degree enter with their endpoints
                # swapped and the second one shifted by the strip width.
                self._hilbert_function_integrals(
                    torch.cat(
                        ( a0, b1 - self._strip_width ),
                        dim = -1
                    ),
                    torch.cat(
                        ( b0, a1 ),
                        dim = -1
                    )
                ),
                # Second channel: intervals of the lower degree only.
                self._rank_invariant_integrals(
                    a0,
                    b0
                )
            ]
        ) / self._pixel_area