File size: 7,420 Bytes
938e515 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from collections import OrderedDict
import torch.nn as nn
from .bn import ABN, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
import torch.nn.functional as functional
class ResidualBlock(nn.Module):
    """Configurable residual block with the activation applied after the residual sum.

    Parameters
    ----------
    in_channels : int
        Number of input channels.
    channels : list of int
        Number of channels in the internal feature maps. Can either have two or three elements: if two, construct
        a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with `1 x 1`, then
        `3 x 3` then `1 x 1` convolutions.
    stride : int
        Stride of the first `3 x 3` convolution (`conv1` for two-element blocks, `conv2` for bottleneck blocks).
        The projection shortcut, when present, uses the same stride.
    dilation : int
        Dilation to apply to the `3 x 3` convolutions.
    groups : int
        Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
        bottleneck blocks.
    norm_act : callable
        Function to create normalization / activation Module.
    dropout: callable
        Function to create Dropout Module.

    Raises
    ------
    ValueError
        If `channels` does not have two or three elements, or if `groups != 1` is requested for a
        two-element (non-bottleneck) block.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        super(ResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        # The shortcut needs a projection whenever the main branch changes resolution or channel count.
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        if not is_bottleneck:
            # The last normalization carries no activation: forward() applies it after the residual sum.
            bn2 = norm_act(channels[1])
            bn2.activation = ACT_NONE
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", bn2)
            ]
            if dropout is not None:
                # Dropout sits between the first normalization and the second convolution.
                layers.insert(2, ("dropout", dropout()))
        else:
            # The last normalization carries no activation: forward() applies it after the residual sum.
            bn3 = norm_act(channels[2])
            bn3.activation = ACT_NONE
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=1, padding=0, bias=False)),
                ("bn1", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=stride, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn2", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False)),
                ("bn3", bn3)
            ]
            if dropout is not None:
                # Dropout sits right before the final 1 x 1 convolution.
                layers.insert(4, ("dropout", dropout()))

        self.convs = nn.Sequential(OrderedDict(layers))

        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
            # Like the last main-branch normalization, the shortcut normalization has no activation.
            self.proj_bn = norm_act(channels[-1])
            self.proj_bn.activation = ACT_NONE

    def forward(self, x):
        """Sum the main branch and the (possibly projected) shortcut, then apply `bn1`'s activation."""
        if hasattr(self, "proj_conv"):
            residual = self.proj_conv(x)
            residual = self.proj_bn(residual)
        else:
            residual = x

        # Out-of-place sum, so the in-place activations below never touch the caller's tensor.
        x = self.convs(x) + residual

        # Re-apply the activation configured on bn1 after the residual sum.
        # NOTE(review): only leaky ReLU and ELU are handled; any other activation reported by bn1
        # (e.g. a plain ReLU, if norm_act supports one) leaves the output un-activated — confirm intended.
        activation = self.convs.bn1.activation
        if activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x, negative_slope=self.convs.bn1.slope, inplace=True)
        elif activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        else:
            return x
class IdentityResidualBlock(nn.Module):
    """Configurable identity-mapping residual block with pre-activation.

    The input is passed through `bn1` (normalization / activation) before the convolutions. The shortcut is either
    the raw input or a `1 x 1` projection of the normalized input. It is added to the convolution output, and no
    activation is applied after the sum.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 stride=1,
                 dilation=1,
                 groups=1,
                 norm_act=ABN,
                 dropout=None):
        """Configurable identity-mapping residual block

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        channels : list of int
            Number of channels in the internal feature maps. Can either have two or three elements: if two,
            construct a residual block with two `3 x 3` convolutions; if three, construct a bottleneck block with
            `1 x 1`, then `3 x 3` then `1 x 1` convolutions.
        stride : int
            Stride of the first convolution: the first `3 x 3` convolution for two-element blocks, the first
            `1 x 1` convolution for bottleneck blocks. The projection shortcut, when present, uses the same stride.
        dilation : int
            Dilation to apply to the `3 x 3` convolutions.
        groups : int
            Number of convolution groups. This is used to create ResNeXt-style blocks and is only compatible with
            bottleneck blocks.
        norm_act : callable
            Function to create normalization / activation Module.
        dropout: callable
            Function to create Dropout Module.

        Raises
        ------
        ValueError
            If `channels` does not have two or three elements, or if `groups != 1` is requested for a
            two-element (non-bottleneck) block.
        """
        super(IdentityResidualBlock, self).__init__()

        # Check parameters for inconsistencies
        if len(channels) != 2 and len(channels) != 3:
            raise ValueError("channels must contain either two or three values")
        if len(channels) == 2 and groups != 1:
            raise ValueError("groups > 1 are only valid if len(channels) == 3")

        is_bottleneck = len(channels) == 3
        # The shortcut needs a projection whenever the main branch changes resolution or channel count.
        need_proj_conv = stride != 1 or in_channels != channels[-1]

        # Pre-activation: normalizes/activates the block input; shared by the main branch and the projection.
        self.bn1 = norm_act(in_channels)

        if not is_bottleneck:
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 3, stride=stride, padding=dilation, bias=False,
                                    dilation=dilation)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    dilation=dilation))
            ]
            if dropout is not None:
                # Dropout sits between bn2 and the second convolution.
                layers.insert(2, ("dropout", dropout()))
        else:
            layers = [
                ("conv1", nn.Conv2d(in_channels, channels[0], 1, stride=stride, padding=0, bias=False)),
                ("bn2", norm_act(channels[0])),
                ("conv2", nn.Conv2d(channels[0], channels[1], 3, stride=1, padding=dilation, bias=False,
                                    groups=groups, dilation=dilation)),
                ("bn3", norm_act(channels[1])),
                ("conv3", nn.Conv2d(channels[1], channels[2], 1, stride=1, padding=0, bias=False))
            ]
            if dropout is not None:
                # Dropout sits right before the final 1 x 1 convolution.
                layers.insert(4, ("dropout", dropout()))

        self.convs = nn.Sequential(OrderedDict(layers))

        if need_proj_conv:
            self.proj_conv = nn.Conv2d(in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Return `convs(bn1(x)) + shortcut`, where the shortcut is `x` or `proj_conv(bn1(x))`."""
        if hasattr(self, "proj_conv"):
            bn1 = self.bn1(x)
            shortcut = self.proj_conv(bn1)
        else:
            # Copy the input before bn1 runs. Presumably norm_act may operate in place on its input,
            # which would otherwise overwrite the identity shortcut.
            shortcut = x.clone()
            bn1 = self.bn1(x)

        out = self.convs(bn1)
        out.add_(shortcut)
        return out
|