ghlee94's picture
Init
2a13495
import torch
import torch.nn as nn
try:
from inplace_abn import InPlaceABN
except ImportError:
InPlaceABN = None
class Conv2dReLU(nn.Sequential):
    """Convolution -> normalization -> ReLU, packaged as an ``nn.Sequential``.

    The three children are, in order: an ``nn.Conv2d``, a normalization slot
    and an activation slot.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (also the width of
            the normalization layer).
        kernel_size: Convolution kernel size.
        padding: Convolution padding (default ``0``).
        stride: Convolution stride (default ``1``).
        use_batchnorm: ``"inplace"`` selects ``InPlaceABN``; any other truthy
            value selects ``nn.BatchNorm2d``; a falsy value disables
            normalization (``nn.Identity``). The convolution carries a bias
            only when normalization is disabled.

    Raises:
        RuntimeError: If ``use_batchnorm == "inplace"`` but the optional
            ``inplace_abn`` package could not be imported.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace_abn = use_batchnorm == "inplace"
        # Checked before building anything so a missing optional dependency
        # fails fast with an actionable message.
        if wants_inplace_abn and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                + "To install see: https://github.com/mapillary/inplace_abn"
            )

        # A normalization layer follows, so a conv bias would be redundant.
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not use_batchnorm,
        )

        if wants_inplace_abn:
            # InPlaceABN is asked for a leaky_relu with activation_param=0.0
            # (presumably a zero negative slope, i.e. a plain ReLU), so the
            # separate activation slot becomes a pass-through.
            norm_layer = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            act_layer = nn.Identity()
        else:
            norm_layer = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
            act_layer = nn.ReLU(inplace=True)

        super().__init__(conv_layer, norm_layer, act_layer)
class SCSEModule(nn.Module):
    """Spatial + channel squeeze-and-excitation gating.

    ``forward`` returns ``x * cSE(x) + x * sSE(x)``:

    * ``cSE``: global average pool -> 1x1 conv bottleneck -> ReLU -> 1x1 conv
      -> sigmoid, giving one gate value per channel.
    * ``sSE``: 1x1 conv to a single channel -> sigmoid, giving one gate value
      per spatial position.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Channel reduction ratio of the cSE bottleneck (default 16).
            The bottleneck width ``in_channels // reduction`` is clamped to
            at least one channel.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        # Without the clamp, in_channels < reduction gives a 0-channel
        # bottleneck: the second conv then has fan_in == 0, PyTorch skips its
        # bias initialisation (leaving uninitialised memory), and the channel
        # gate no longer depends on the input at all.
        hidden_channels = max(1, in_channels // reduction)
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid(),
        )
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        """Apply the channel and spatial gates to ``x`` and sum the results."""
        return x * self.cSE(x) + x * self.sSE(x)
class ArgMax(nn.Module):
    """Module wrapper around ``torch.argmax``.

    Args:
        dim: Dimension to reduce over. With the default ``None`` the index is
            taken over the flattened input.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Return the indices of the maximum values of ``x`` along ``self.dim``."""
        return x.argmax(dim=self.dim)
class Clamp(nn.Module):
    """Module wrapper around ``torch.clamp``: limit every element to ``[min, max]``.

    Args:
        min: Lower bound (default ``0``).
        max: Upper bound (default ``1``).

    The parameter and attribute names shadow the builtins on purpose: they
    are the public keyword interface (e.g. ``Clamp(min=..., max=...)``).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """Return ``x`` with values below ``min`` / above ``max`` clipped."""
        return x.clamp(min=self.min, max=self.max)
class Activation(nn.Module):
    """Build an activation module from a name or a callable.

    Args:
        name: One of ``None``/``"identity"``, ``"sigmoid"``, ``"softmax2d"``
            (softmax over dim 1), ``"softmax"``, ``"logsoftmax"``, ``"tanh"``,
            ``"argmax"``, ``"argmax2d"`` (argmax over dim 1), ``"clamp"``, or
            a callable invoked as ``name(**params)``.
        **params: Keyword arguments forwarded to the selected module's
            constructor. They are ignored for ``"sigmoid"`` and ``"tanh"``.

    Raises:
        ValueError: If ``name`` is not ``None``, not a supported string and
            not callable.
    """

    # Name -> factory taking the ``params`` dict. The lambdas defer the
    # lookup of ArgMax / Clamp until an instance is actually constructed.
    _FACTORIES = {
        "identity": lambda params: nn.Identity(**params),
        "sigmoid": lambda params: nn.Sigmoid(),
        "softmax2d": lambda params: nn.Softmax(dim=1, **params),
        "softmax": lambda params: nn.Softmax(**params),
        "logsoftmax": lambda params: nn.LogSoftmax(**params),
        "tanh": lambda params: nn.Tanh(),
        "argmax": lambda params: ArgMax(**params),
        "argmax2d": lambda params: ArgMax(dim=1, **params),
        "clamp": lambda params: Clamp(**params),
    }

    def __init__(self, name, **params):
        super().__init__()
        if name is None:
            name = "identity"
        # The isinstance guard keeps arbitrary (possibly unhashable) callables
        # away from the dict lookup.
        if isinstance(name, str) and name in self._FACTORIES:
            self.activation = self._FACTORIES[name](params)
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError(
                "Activation should be callable/identity/sigmoid/softmax/softmax2d/"
                f"logsoftmax/tanh/argmax/argmax2d/clamp/None; got {name}"
            )

    def forward(self, x):
        """Apply the wrapped activation to ``x``."""
        return self.activation(x)
class Attention(nn.Module):
    """Select an attention module by name.

    Args:
        name: ``None`` for a pass-through ``nn.Identity``, or ``"scse"`` for
            ``SCSEModule``.
        **params: Keyword arguments forwarded to the selected module's
            constructor.

    Raises:
        ValueError: For any other ``name``.
    """

    def __init__(self, name, **params):
        super().__init__()
        if name is not None and name != "scse":
            raise ValueError(f"Attention {name} is not implemented")
        self.attention = nn.Identity(**params) if name is None else SCSEModule(**params)

    def forward(self, x):
        """Apply the selected attention module to ``x``."""
        return self.attention(x)