from typing import Callable, NamedTuple
from collections.abc import Sequence
from statistics import mean
import torch
class Result(NamedTuple):
    """Outcome of evaluating a single batch.

    Instances are plain tuples, so they can be unpacked as
    ``loss, nb_correct = result``.
    """

    # Loss tensor as produced by the loss function; kept as a tensor (not a
    # Python float) so it can still be backpropagated.
    loss: torch.Tensor
    # Count of predictions that matched their label in the batch, stored as
    # a float.
    nb_correct: float
class ProcessBatch(NamedTuple):
    """Callable that runs one batch through a model and scores it.

    Calling an instance feeds ``X`` to ``model``, applies ``loss_fn`` to the
    predictions and labels, and counts how many predictions equal the labels.
    """

    # Network applied to each input batch.
    model: torch.nn.Module
    # Called as ``loss_fn(pred, label)`` and expected to return the loss
    # tensor. Callable's argument types must be given as a list; the bare
    # ``Callable[A, B]`` form is invalid (TypeError on Python < 3.10).
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    # When True, predictions are rounded to get the predicted label; when
    # False, the argmax over dimension 1 is used as the predicted class.
    binary: bool = False

    def __call__(self,
                 X: torch.Tensor,
                 label: torch.Tensor
                 ) -> "Result":
        """Evaluate ``model`` on one batch.

        Args:
            X: Input batch, passed unchanged to ``model``.
            label: Targets; passed to ``loss_fn`` and compared element-wise
                with the predicted labels.

        Returns:
            A ``Result`` whose ``loss`` is the tensor returned by ``loss_fn``
            (not detached) and whose ``nb_correct`` is the number of
            element-wise matches between predictions and ``label``, as a
            Python float.
        """
        pred = self.model(X)
        loss = self.loss_fn(pred, label)
        # NOTE(review): ``round()`` assumes ``pred`` already holds
        # probabilities (e.g. post-sigmoid); raw logits would round to
        # arbitrary integers. ``==`` also broadcasts, so a ``(N, 1)``
        # prediction against an ``(N,)`` label yields an ``(N, N)`` mask and
        # overcounts. Confirm the shapes callers use.
        predicted = pred.round() if self.binary else pred.argmax(1)
        nb_correct = (predicted == label).type(torch.float).sum().item()
        return Result(loss=loss, nb_correct=nb_correct)
class Optimize(NamedTuple):
    """Callable that performs one optimizer update from a batch's loss."""

    # Optimizer stepped after each backward pass.
    optimizer: torch.optim.Optimizer

    def __call__(self,
                 result: "Result"
                 ) -> "Result":
        """Backpropagate ``result.loss``, step the optimizer, clear gradients.

        Args:
            result: Batch result whose ``loss`` tensor is differentiated.

        Returns:
            The same ``result`` object that was passed in. The previous
            ``-> torch.Tensor`` annotation did not match this.
        """
        result.loss.backward()
        self.optimizer.step()
        # Gradients are cleared only *after* the step, so any gradients that
        # were already present before the first call are folded into that
        # first update.
        self.optimizer.zero_grad()
        return result
def accumulate_loss(results: Sequence[Result]) -> float:
    """Average the scalar losses of a collection of batch results.

    Each ``loss`` tensor is converted to a Python number with ``.item()``.
    The average is unweighted: every result counts once, regardless of how
    many samples its batch contained.

    Raises:
        statistics.StatisticsError: if ``results`` is empty (from ``mean``).
    """
    batch_losses = [res.loss.item() for res in results]
    return mean(batch_losses)
def accumulate_score(results: Sequence["Result"]) -> float:
    """Total the ``nb_correct`` counts across batch results.

    Args:
        results: Per-batch results; only their ``nb_correct`` field is read.

    Returns:
        The summed count as a float; ``0.0`` when ``results`` is empty.
    """
    # Start from 0.0 rather than the implicit int 0 so that an empty input
    # still returns a float, as the annotation promises.
    return sum((res.nb_correct for res in results), 0.0)