Source code for train_test

from typing import Callable, NamedTuple
from collections.abc import Sequence
from statistics import mean

import torch

class Result(NamedTuple):
    """Outcome of evaluating one batch: its loss and its correct count."""

    # Loss tensor; ``accumulate_loss`` reads it through ``.item()``, so it
    # is expected to hold a single value.
    loss: torch.Tensor
    # Number of predictions that matched their label.
    nb_correct: float
class ProcessBatch(NamedTuple):
    """Callable that runs a model on one batch and scores its output.

    Calling an instance with inputs and labels returns a ``Result``
    holding the loss and the number of correct predictions.
    """

    # Model invoked on the batch inputs.
    model: torch.nn.Module
    # Loss function, called as ``loss_fn(pred, label)``.
    loss_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    # When true, predictions are rounded; otherwise the arg-max over
    # dimension 1 is taken as the predicted class.
    binary: bool = False

    def __call__(self, X: torch.Tensor, label: torch.Tensor) -> Result:
        """Run the model on ``X`` and compare its output with ``label``.

        Args:
            X: Batch of inputs passed to the model.
            label: Expected outputs, compared element-wise with the
                predictions.

        Returns:
            A ``Result`` with the loss and the count of predictions equal
            to their label.
        """
        pred = self.model(X)
        loss = self.loss_fn(pred, label)
        # NOTE(review): rounding presumes binary predictions are already
        # probabilities rather than raw logits -- confirm with callers.
        predicted = pred.round() if self.binary else pred.argmax(1)
        # Cast the boolean matches to float so the count is a Python float.
        nb_correct = (predicted == label).type(torch.float).sum().item()
        return Result(loss=loss, nb_correct=nb_correct)
class Optimize(NamedTuple):
    """Callable that performs one optimizer update from a batch result."""

    # Optimizer stepped on every call.
    optimizer: torch.optim.Optimizer

    def __call__(self, result: Result) -> Result:
        """Back-propagate ``result.loss`` and step the optimizer.

        Args:
            result: Batch result whose loss is back-propagated.

        Returns:
            The same ``result`` object that was passed in.
        """
        result.loss.backward()
        self.optimizer.step()
        # Reset gradients after the step so they do not accumulate into
        # the next backward pass.
        self.optimizer.zero_grad()
        return result
def accumulate_loss(results: Sequence[Result]) -> float:
    """Return the mean of the per-result loss values.

    Args:
        results: Results whose ``loss`` tensors are averaged.

    Returns:
        Arithmetic mean of ``loss.item()`` over ``results``.

    Raises:
        statistics.StatisticsError: If ``results`` is empty.
    """
    losses = [entry.loss.item() for entry in results]
    return mean(losses)
def accumulate_score(results: Sequence[Result]) -> float:
    """Return the total number of correct predictions across results.

    Args:
        results: Results whose ``nb_correct`` counts are added together.

    Returns:
        Sum of ``nb_correct`` over ``results``.
    """
    correct_counts = [entry.nb_correct for entry in results]
    return sum(correct_counts)