import torch
import torch.nn as nn
class StandardScaler(nn.Module):
    """Standardize tensors as ``(input - mean) / std``.

    ``mean`` and ``std`` are registered as non-persistent buffers. As a
    result, ``.to()``, ``.cuda()`` and ``.double()`` move and cast them along
    with the rest of the model. ``state_dict()`` stays identical to the
    plain-attribute version: no new keys are added, so existing checkpoints
    still load.

    Args:
        std: Divisor applied after centering. It must broadcast against the
            forward input. No epsilon is added, so zeros in ``std`` produce
            ``inf``/``nan`` in the output.
        mean: Offset subtracted from the input. It must broadcast against the
            forward input. Presumably these are per-feature statistics
            computed elsewhere; confirm against callers.
    """

    def __init__(self,
                 std: torch.Tensor,
                 mean: torch.Tensor
                 ) -> None:
        super().__init__()
        # register_buffer only accepts tensors. torch.as_tensor returns an
        # existing tensor as-is (no copy) and wraps plain numbers.
        # persistent=False keeps these out of state_dict, matching the
        # previous plain-attribute behaviour for saved checkpoints.
        self.register_buffer("std", torch.as_tensor(std), persistent=False)
        self.register_buffer("mean", torch.as_tensor(mean), persistent=False)

    def forward(self,
                input: torch.Tensor,
                /) -> torch.Tensor:
        """Return ``(input - mean) / std`` using standard torch broadcasting."""
        return (input - self.mean) / self.std