Source code for persunraveltorch.nn.biplane_from_triangles

from typing import Optional
from collections.abc import Iterable

import itertools as it

import torch

from ._glue_to_fragment import GlueToFragment


__all__ = [ 'BiplaneFromTriangles' ]


[docs] class BiplaneFromTriangles(torch.nn.Module): """Asembles a biplane from triangles. This module takes triangles as created by :class:`RasterTriangle` as an input and assembles a biplane from these triangles. Parameters ---------- pixel_columns : int The number of pixels used to raster each row or scanline of the strip assembled from triangles. max_overhead : Optional[int] If this is set, the input will be processed in batches as small as necessary to limit the number of bytes allocated as overhead to at most :obj:`max_overhead`. So if you're not already processing your data in sufficiently small batches, setting this parameter is recommended. However, if the number of bytes required to process a single sample already exceeds :obj:`max_overhead`, the input is still processed sample by sample. padding : int The amount of horizontal padding being added to both, the left and the right side. """ def __init__(self, *, pixel_columns: int, padding: int, max_overhead: Optional[int] = None, device = None, dtype = None ) -> None: super().__init__() self._pixel_columns = pixel_columns self._padding = padding self._glue_to_fragment = GlueToFragment( pixel_columns = pixel_columns, padding = padding, max_overhead = max_overhead ) self._zero_triangle = torch.zeros( pixel_columns, pixel_columns ) self._zero_fragment = torch.zeros( pixel_columns, pixel_columns + 2 * padding, dtype = dtype, device = device ) self._pad_with_scanline = torch.nn.ZeroPad2d( (0, 0, 1, 1) )
[docs] def forward(self, triangles: Iterable[torch.Tensor], /) -> torch.Tensor: """Asemble a biplane from triangles. Takes triangles as created by :class:`RasterTriangle` as an input and assembles a biplane from these triangles. Parameters ---------- triangles : Iterable[torch.Tensor] Triangles as for example created by :class:`RasterTriangle`. Returns ------- torch.Tensor A biplane assembled from `triangles`. """ triangles_iterator = iter( triangles ) triangle0 = next( triangles_iterator ) zero_triangle = self._zero_triangle.view( (1,) * (len(triangle0.shape) - 2) + self._zero_triangle.shape ).expand( triangle0.shape ) plane00_fragments: list[torch.Tensor] = [] plane01_fragments: list[torch.Tensor] = [] for pair_or_singleton in it.batched( n = 2, iterable = ( self._glue_to_fragment(*triangle_pair) for triangle_pair in it.pairwise( it.chain( [zero_triangle, triangle0], triangles_iterator, [zero_triangle] ) ) ) ): plane01_fragments.append( pair_or_singleton[0] ) plane00_fragments.append( ( pair_or_singleton + (self._zero_fragment,) )[1] ) plane00 = torch.cat( plane00_fragments, dim = -2 ) plane01 = torch.cat( plane01_fragments, dim = -2 ) return self._pad_with_scanline( torch.stack( (plane00, plane01), dim = -3 ) )