Source code for channel_select2d

import torch
import torch.nn as nn

class ChannelSelect2d(nn.Module):
    """Keep a single index along dimension ``-3`` of the input and drop that axis.

    Dimension ``-3`` is the third axis from the end. Presumably this is the
    channel axis of channels-first 2D data, e.g. ``(C, H, W)`` or
    ``(N, C, H, W)``; this is an assumption to be confirmed against callers.
    The input therefore needs at least three dimensions.

    Args:
        channel: Index along dimension ``-3`` to keep.
    """

    def __init__(self, channel: int) -> None:
        super().__init__()
        self.channel: int = channel

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the slice of ``input`` at ``self.channel`` along dimension ``-3``.

        The output has dimension ``-3`` removed, e.g. ``(N, C, H, W)`` becomes
        ``(N, H, W)``. It is a view of ``input``, not a copy.

        Raises:
            An error from ``torch.Tensor.narrow`` if ``self.channel`` is out
            of range for dimension ``-3`` or ``input`` has fewer than three
            dimensions.
        """
        # Equivalent to ``input.select(-3, self.channel)``. narrow + squeeze is
        # kept so the exception raised on out-of-range indices is unchanged.
        picked = input.narrow(dim=-3, start=self.channel, length=1)
        return picked.squeeze(dim=-3)

    def extra_repr(self) -> str:
        """Expose the selected index in ``repr(module)``."""
        return f"channel={self.channel}"