|
import torch |
|
import torch.nn as nn |
|
|
|
try: |
|
from inplace_abn import InPlaceABN |
|
except ImportError: |
|
InPlaceABN = None |
|
|
|
|
|
class Conv2dReLU(nn.Sequential):
    """Sequential block of ``Conv2d -> normalization -> activation``.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel_size: Convolution kernel size.
        padding: Convolution padding (default 0).
        stride: Convolution stride (default 1).
        use_batchnorm: ``True`` for ``nn.BatchNorm2d`` + ``nn.ReLU``, a falsy
            value for ``nn.Identity`` + ``nn.ReLU``, or ``"inplace"`` for
            ``InPlaceABN`` (leaky ReLU with slope 0.0, i.e. ReLU) followed by
            ``nn.Identity``. The convolution bias is disabled whenever
            ``use_batchnorm`` is truthy.

    Raises:
        RuntimeError: If ``use_batchnorm == "inplace"`` but the optional
            ``inplace_abn`` package could not be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace_abn = use_batchnorm == "inplace"

        # Fail before any layer is built when the optional dependency is missing.
        if wants_inplace_abn and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        # A following normalization layer makes the conv bias redundant.
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )

        if wants_inplace_abn:
            # InPlaceABN fuses the activation, so no separate ReLU is needed.
            normalization = InPlaceABN(
                out_channels, activation="leaky_relu", activation_param=0.0
            )
            activation = nn.Identity()
        else:
            normalization = (
                nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
            )
            activation = nn.ReLU(inplace=True)

        super().__init__(convolution, normalization, activation)
|
|
|
|
|
class SCSEModule(nn.Module):
    """Add a channel-gated copy and a spatially-gated copy of the input.

    The channel gate (``cSE``) global-average-pools the input, squeezes it
    through a 1x1-conv bottleneck with ReLU, expands it back to
    ``in_channels`` and applies a sigmoid. The spatial gate (``sSE``) is a
    1x1 conv down to a single channel followed by a sigmoid. The output is
    ``x * cSE(x) + x * sSE(x)`` and has the same shape as ``x``.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction factor for the channel gate.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Clamp to at least one channel: with ``in_channels < reduction`` the
        # plain floor division gives a zero-width bottleneck, which makes the
        # channel gate ignore its input (only the bias would survive).
        squeezed_channels = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        # cSE(x) broadcasts over the spatial dims; sSE(x) broadcasts over channels.
        return x * self.cSE(x) + x * self.sSE(x)
|
|
|
|
|
class ArgMax(nn.Module):
    """Module form of argmax, returning indices of the maximum values.

    Args:
        dim: Dimension to reduce. ``None`` (default) returns the index into
            the flattened tensor.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        reduce_dim = self.dim
        return x.argmax(dim=reduce_dim)
|
|
|
|
|
class Clamp(nn.Module):
    """Module form of ``torch.clamp`` with fixed bounds.

    Args:
        min: Lower bound (default 0).
        max: Upper bound (default 1).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return torch.clamp(x, min=self.min, max=self.max)
|
|
|
|
|
class Activation(nn.Module):
    """Build an activation module from a name, a factory callable, or ``None``.

    Supported names:
        ``None`` / ``"identity"`` -> ``nn.Identity(**params)``
        ``"sigmoid"``             -> ``nn.Sigmoid()``
        ``"softmax2d"``           -> ``nn.Softmax(dim=1, **params)``
        ``"softmax"``             -> ``nn.Softmax(**params)``
        ``"logsoftmax"``          -> ``nn.LogSoftmax(**params)``
        ``"tanh"``                -> ``nn.Tanh()``
        ``"argmax"``              -> ``ArgMax(**params)``
        ``"argmax2d"``            -> ``ArgMax(dim=1, **params)``
        ``"clamp"``               -> ``Clamp(**params)``
        any callable              -> ``name(**params)``

    Note that ``params`` are silently ignored for ``"sigmoid"`` and ``"tanh"``.

    Raises:
        ValueError: If ``name`` is none of the above.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None or name == "identity":
            self.activation = nn.Identity(**params)
        elif name == "sigmoid":
            self.activation = nn.Sigmoid()
        elif name == "softmax2d":
            # Softmax over the channel axis (dim=1).
            self.activation = nn.Softmax(dim=1, **params)
        elif name == "softmax":
            self.activation = nn.Softmax(**params)
        elif name == "logsoftmax":
            self.activation = nn.LogSoftmax(**params)
        elif name == "tanh":
            self.activation = nn.Tanh()
        elif name == "argmax":
            self.activation = ArgMax(**params)
        elif name == "argmax2d":
            self.activation = ArgMax(dim=1, **params)
        elif name == "clamp":
            self.activation = Clamp(**params)
        elif callable(name):
            # e.g. an nn.Module subclass or any factory returning a module.
            self.activation = name(**params)
        else:
            # The message lists every accepted option, including "identity"
            # and "softmax2d".
            raise ValueError(
                "Activation should be callable/identity/sigmoid/softmax/softmax2d/"
                f"logsoftmax/tanh/argmax/argmax2d/clamp/None; got {name}"
            )

    def forward(self, x):
        return self.activation(x)
|
|
|
|
|
class Attention(nn.Module):
    """Select an attention module by name.

    ``None`` gives ``nn.Identity(**params)``; ``"scse"`` gives
    ``SCSEModule(**params)``.

    Raises:
        ValueError: For any other ``name``.
    """

    def __init__(self, name, **params):
        super().__init__()

        if name is None:
            chosen = nn.Identity(**params)
        elif name == "scse":
            chosen = SCSEModule(**params)
        else:
            raise ValueError(f"Attention {name} is not implemented")

        self.attention = chosen

    def forward(self, x):
        return self.attention(x)
|
|