from typing import Callable
from collections.abc import Sequence
import torch
__all__ = [ 'GramTest' ]
class GramTest(torch.nn.Module):
    """Scatter a kernel matrix against support vectors into a full-width tensor.

    ``gram`` is evaluated only between the test inputs and the stored
    ``support_vectors``. The columns of that result are then written into a
    zero-filled tensor whose last dimension has length ``len_train``, at the
    positions given by ``support_indices``. Every other column stays zero.

    Args:
        gram: Callable ``gram(a, b)`` returning the kernel matrix between
            ``a`` and ``b``.
        support_indices: Positions along the last dimension of the output
            that receive the columns of ``gram(test, support_vectors)``.
            Converted to ``int64`` on the output device at call time.
        len_train: Size of the last dimension of the output.
        support_vectors: Second argument passed to ``gram``.
        device: Device of the output tensor. ``None`` means "same device as
            the computed kernel values".
        dtype: Dtype of the output tensor. ``None`` means "same dtype as the
            computed kernel values".
    """

    def __init__(self,
                 *,
                 gram: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 support_indices: torch.Tensor,
                 len_train: int,
                 support_vectors: torch.Tensor,
                 device=None,
                 dtype=None
                 ) -> None:
        super().__init__()
        self._gram = gram
        self._index = support_indices
        self._len_train = len_train
        self._support_vectors = support_vectors
        self._device = device
        self._dtype = dtype

    def forward(self,
                test: torch.Tensor,
                /) -> torch.Tensor:
        """Return ``gram(test, support_vectors)`` expanded to ``len_train`` columns.

        Args:
            test: First argument passed to ``gram``.

        Returns:
            A tensor of shape ``(*K.shape[:-1], len_train)``, where
            ``K = gram(test, support_vectors)``. It is zero everywhere
            except at the columns listed in ``support_indices``, which hold
            the columns of ``K``.
        """
        compressed = self._gram(test, self._support_vectors)

        # ``index_copy_`` requires destination, index and source to agree on
        # device (and destination/source on dtype). Defaulting to the kernel
        # output's device/dtype avoids a float32/CPU buffer rejecting e.g.
        # float64 or CUDA kernel values.
        device = compressed.device if self._device is None else self._device
        dtype = compressed.dtype if self._dtype is None else self._dtype

        full = torch.zeros(
            (*compressed.shape[:-1], self._len_train),
            device=device,
            dtype=dtype,
        )
        # The stored index is a plain attribute (not a registered buffer), so
        # ``module.to(...)`` does not move it; align it with the output here.
        index = self._index.to(device=full.device, dtype=torch.long)
        source = compressed.to(device=full.device, dtype=full.dtype)
        return full.index_copy_(-1, index, source)