Source code for persunraveltorch.nn.hilbert_gram

from collections.abc import Sequence
import torch

from .hilbert_kernel import HilbertKernel


__all__ = [ 'HilbertGram' ]


[docs] class HilbertGram(HilbertKernel): r"""As :class:`HilbertKernel` but for different sizes in dimension :obj:`-3`. If the first input consists of tensors of shape :math:`([\dots,] m, k, 2)` for some :math:`{m \in \N}` and the second input consists of tensors of shape :math:`([\dots,] n, k, 2)` for some :math:`{n \in \N}`, then the output has shape :math:`([\dots,] m, n)` containing all results of pairwise kernel computations as a Gram matrix. Parameters and attributes are inherited from :class:`HilbertKernel`. """
[docs] def forward(self, intervals01: Sequence[torch.Tensor], intervals02: Sequence[torch.Tensor], /) -> torch.Tensor: """Computes the Gram matrix of the corresponding Hilbert functions. Parameters ---------- intervals01 : :class:`Sequence[torch.Tensor]` Persistence intervals for the first input as a :class:`Sequence[torch.Tensor]` by degree. Each item of this sequence is a :class:`torch.Tensor` of shape :math:`([\dots,] m, k, 2)`, where :math:`k` is the number of persistence intervals in the corresponding degree. intervals02 : :class:`Sequence[torch.Tensor]` Persistence intervals for the second input as a :class:`Sequence[torch.Tensor]` by degree. Each item of this sequence is a :class:`torch.Tensor` of shape :math:`([\dots,] n, k, 2)`, where :math:`k` is the number of persistence intervals in the corresponding degree. Returns ------- :class:`torch.Tensor` The Gram matrix of the corresponding Hilbert functions. """ return super().forward( [ intervals_sd.unsqueeze(-3) for intervals_sd in intervals01 ], [ intervals_sd.unsqueeze(-4) for intervals_sd in intervals02 ] )