"""Source code for gram_test."""

from typing import Callable
from collections.abc import Sequence
import torch


# Public API of this module.
__all__ = ["GramTest"]


class GramTest(torch.nn.Module):
    """Expand a Gram matrix computed against support vectors to full width.

    The kernel ``gram`` is evaluated only against ``support_vectors``. The
    resulting columns are then scattered, along the last dimension, into a
    zero tensor whose last dimension is ``len_train``. Gram column ``j`` lands
    at position ``support_indices[j]``, and every other column stays zero.

    Args:
        gram: Callable ``gram(a, b)`` returning the Gram matrix between ``a``
            and ``b``.
        support_indices: Integer tensor of target column positions in
            ``[0, len_train)``. It must have one entry per column returned by
            ``gram`` (a requirement of ``Tensor.index_copy_``).
        len_train: Size of the last dimension of the output.
        support_vectors: Second argument passed to ``gram``.
        device: Device of the output. ``None`` means "same as the Gram
            output".
        dtype: Dtype of the output. ``None`` means "same as the Gram output".
    """

    def __init__(
        self,
        *,
        gram: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        support_indices: torch.Tensor,
        len_train: int,
        support_vectors: torch.Tensor,
        device=None,
        dtype=None,
    ) -> None:
        super().__init__()
        self._gram = gram
        self._index = support_indices
        self._len_train = len_train
        self._support_vectors = support_vectors
        self._device = device
        self._dtype = dtype

    def forward(self, test: torch.Tensor, /) -> torch.Tensor:
        """Return the Gram matrix of ``test`` against all training positions.

        Args:
            test: First argument passed to ``gram``.

        Returns:
            Tensor of shape ``(*g.shape[:-1], len_train)``, where
            ``g = gram(test, support_vectors)``. Column ``support_indices[j]``
            holds ``g[..., j]``, and all other columns are zero.
        """
        compressed = self._gram(test, self._support_vectors)
        # Default to the Gram output's device/dtype. ``index_copy_`` requires
        # matching dtypes and devices, so a hard float32/CPU fallback breaks
        # for float64 or GPU kernels.
        device = compressed.device if self._device is None else self._device
        dtype = compressed.dtype if self._dtype is None else self._dtype
        expanded = torch.zeros(
            *compressed.shape[:-1],
            self._len_train,
            device=device,
            dtype=dtype,
        )
        # index_copy_ needs an int64 index and a source matching ``expanded``;
        # ``.to`` is a no-op when they already match.
        index = self._index.to(device=device, dtype=torch.long)
        source = compressed.to(device=device, dtype=dtype)
        return expanded.index_copy_(-1, index, source)