import torch
import torch.nn as nn
class ChannelSelect2d(nn.Module):
    """Pick out a single channel of a tensor and drop the channel axis.

    The channel axis is taken to be the third-from-last dimension
    (``dim=-3``), so the input must have at least three dimensions.
    NOTE(review): the ``2d`` suffix suggests inputs laid out as
    ``(..., C, H, W)`` -- confirm against callers.

    Args:
        channel: Index of the channel to extract along ``dim=-3``.
    """

    def __init__(self, channel: int) -> None:
        super().__init__()
        self.channel = channel

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the configured channel of ``input``.

        Args:
            input: Tensor with at least three dimensions; the
                third-from-last dimension is indexed by ``self.channel``.

        Returns:
            ``input`` restricted to ``self.channel`` along ``dim=-3``, with
            that dimension removed (one dimension fewer than ``input``).
        """
        # Keep a length-1 slice at the requested channel, then remove the
        # now-singleton axis so the result has one dimension fewer.
        selected = input.narrow(dim=-3, start=self.channel, length=1)
        return selected.squeeze(dim=-3)