Source code for persunraveltorch.nn.functional.max_pool_biplane

from typing import Optional, Tuple, Union

import torch

from .._kernel_sizes import _size_2_t


# Public names exported by ``from ... import *``: only the functional
# max-pooling entry point.
__all__ = [
    "max_pool_biplane",
]


def max_pool_biplane(
    input: torch.Tensor,
    kernel_size: _size_2_t,
    stride: Optional[_size_2_t] = None,
    padding: _size_2_t = 0,
    dilation: _size_2_t = 1,
    return_indices: bool = False,
    ceil_mode: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Applies max pooling to a biplane.

    See :class:`MaxPoolBiplane` for details.

    The last three dimensions of ``input`` are pooled by
    :func:`torch.nn.functional.max_pool2d` as ``(C, H, W)``. Any dimensions
    before those are treated as batch dimensions and are preserved in the
    result.

    Args:
        input: Tensor of shape ``(*, C, H, W)`` with at least three
            dimensions. A 3-D tensor is treated as a single unbatched sample.
        kernel_size: Forwarded to ``max_pool2d``.
        stride: Forwarded to ``max_pool2d``.
        padding: Forwarded to ``max_pool2d``.
        dilation: Forwarded to ``max_pool2d``.
        return_indices: If ``True``, also return the indices of the maxima,
            reshaped the same way as the pooled output.
        ceil_mode: Forwarded to ``max_pool2d``.

    Returns:
        The pooled tensor of shape ``(*, C, H_out, W_out)``. When
        ``return_indices`` is ``True``, a tuple ``(output, indices)`` of two
        tensors of that shape is returned instead.
    """
    batch_shape = input.shape[:-3]

    # max_pool2d accepts at most one batch dimension, so all leading
    # dimensions are collapsed into one. A 3-D (unbatched) input is passed
    # through unchanged, because ``flatten(end_dim=-4)`` would raise on it.
    if input.dim() == 3:
        flat = input
    else:
        flat = input.flatten(end_dim=-4)

    processed = torch.nn.functional.max_pool2d(
        flat,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        return_indices=return_indices,
        ceil_mode=ceil_mode,
    )

    def _restore(tensor: torch.Tensor) -> torch.Tensor:
        # Split the collapsed batch dimension back into the original
        # leading dimensions.
        return tensor.view(*batch_shape, *tensor.shape[-3:])

    if return_indices:
        return tuple(_restore(tensor) for tensor in processed)
    return _restore(processed)