import torch
import torch.nn as nn

try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


class Conv2dReLU(nn.Sequential):
    """Conv2d -> normalization -> ReLU, packaged as ``nn.Sequential(conv, bn, relu)``.

    Args:
        in_channels: Number of input channels of the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding. Defaults to 0.
        stride: Convolution stride. Defaults to 1.
        use_batchnorm: ``True`` inserts ``nn.BatchNorm2d``; the string
            ``"inplace"`` inserts ``InPlaceABN`` (which then also supplies the
            activation, so the trailing ReLU slot becomes ``nn.Identity``);
            a falsy value inserts ``nn.Identity`` (no normalization).

    Raises:
        RuntimeError: If ``use_batchnorm == "inplace"`` but the optional
            ``inplace_abn`` package could not be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        if use_batchnorm == "inplace" and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            # The conv bias is disabled whenever a normalization layer follows
            # (use_batchnorm truthy, including "inplace").
            bias=not use_batchnorm,
        )
        relu = nn.ReLU(inplace=True)

        if use_batchnorm == "inplace":
            # InPlaceABN is configured with its own activation ("leaky_relu"
            # with activation_param=0.0 -- presumably a zero negative slope,
            # i.e. plain ReLU), so the standalone ReLU is replaced by a no-op.
            bn = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            relu = nn.Identity()
        elif use_batchnorm:
            bn = nn.BatchNorm2d(out_channels)
        else:
            bn = nn.Identity()

        super(Conv2dReLU, self).__init__(conv, bn, relu)


class SCSEModule(nn.Module):
    """Spatial + channel squeeze-and-excitation gating block (scSE).

    Computes ``x * cSE(x) + x * sSE(x)`` where:

    * ``cSE`` average-pools each channel to 1x1, runs a 1x1-conv bottleneck
      (``in_channels -> in_channels // reduction -> in_channels``) with a ReLU
      in between, then a sigmoid -- one gate per channel;
    * ``sSE`` maps ``x`` to a single channel with a 1x1 conv and a sigmoid --
      one gate per spatial position.

    The output has the same shape as the input (presumably an (N, C, H, W)
    feature map with ``C == in_channels``).

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction ratio of the channel branch.
            NOTE(review): ``in_channels // reduction`` is 0 when
            ``in_channels < reduction``; confirm callers never do that.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // reduction, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // reduction, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        return x * self.cSE(x) + x * self.sSE(x)


class ArgMax(nn.Module):
    """Module wrapper around ``torch.argmax``.

    Args:
        dim: Dimension to reduce. ``None`` (default) returns the index into
            the flattened input, as ``torch.argmax`` does.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return torch.argmax(x, dim=self.dim)


class Clamp(nn.Module):
    """Module wrapper around ``torch.clamp``; limits values to ``[min, max]``.

    Args:
        min: Lower bound. Defaults to 0.
        max: Upper bound. Defaults to 1.
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min, self.max = min, max

    def forward(self, x):
        return torch.clamp(x, self.min, self.max)


class Activation(nn.Module):
    """Select an output activation by name (or factory) and apply it.

    Supported ``name`` values:

    * ``None`` / ``"identity"`` -> ``nn.Identity`` (``params`` are passed but
      ``nn.Identity`` ignores them);
    * ``"sigmoid"``, ``"tanh"`` -> ``nn.Sigmoid`` / ``nn.Tanh`` (``params``
      are not used);
    * ``"softmax2d"`` -> ``nn.Softmax(dim=1, **params)``;
    * ``"softmax"`` / ``"logsoftmax"`` -> ``nn.Softmax(**params)`` /
      ``nn.LogSoftmax(**params)``;
    * ``"argmax"`` -> ``ArgMax(**params)``; ``"argmax2d"`` ->
      ``ArgMax(dim=1, **params)``;
    * ``"clamp"`` -> ``Clamp(**params)``;
    * any other callable -> ``name(**params)``.

    Raises:
        ValueError: If ``name`` is an unrecognized, non-callable value.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None or name == "identity":
            self.activation = nn.Identity(**params)
        elif name == "sigmoid":
            self.activation = nn.Sigmoid()
        elif name == "softmax2d":
            self.activation = nn.Softmax(dim=1, **params)
        elif name == "softmax":
            self.activation = nn.Softmax(**params)
        elif name == "logsoftmax":
            self.activation = nn.LogSoftmax(**params)
        elif name == "tanh":
            self.activation = nn.Tanh()
        elif name == "argmax":
            self.activation = ArgMax(**params)
        elif name == "argmax2d":
            self.activation = ArgMax(dim=1, **params)
        elif name == "clamp":
            self.activation = Clamp(**params)
        elif callable(name):
            # Strings are matched above first, so only non-string factories
            # (e.g. an nn.Module subclass) reach this branch.
            self.activation = name(**params)
        else:
            # List every accepted name (previously "softmax2d" and "identity"
            # were missing from this message although both are supported).
            raise ValueError(
                f"Activation should be callable/sigmoid/softmax/softmax2d/logsoftmax/tanh/"
                f"argmax/argmax2d/clamp/identity/None; got {name}"
            )

    def forward(self, x):
        return self.activation(x)


class Attention(nn.Module):
    """Select an attention block by name.

    Supported ``name`` values: ``None`` -> ``nn.Identity`` (``params``
    ignored); ``"scse"`` -> ``SCSEModule(**params)``.

    Raises:
        ValueError: If ``name`` is not one of the supported values.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None:
            self.attention = nn.Identity(**params)
        elif name == "scse":
            self.attention = SCSEModule(**params)
        else:
            raise ValueError(f"Attention {name} is not implemented")

    def forward(self, x):
        return self.attention(x)