# Provenance (recovered from a web-page scrape): uploaded by Ricoooo,
# commit 5d21dd2 ('folder'), file size 8.38 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
class SEAttention(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input is global-average-pooled to 1x1 spatially. It then passes
    through a 1x1-conv bottleneck (``in_channels`` -> hidden ->
    ``out_channels``) with ReLU and a final sigmoid, and the input is
    rescaled channel-wise by the resulting gate.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the produced gate. The gate multiplies the
            input directly, so it must be broadcast-compatible with (in
            practice equal to) ``in_channels``.
        reduction: bottleneck reduction ratio. The hidden width is
            ``out_channels // reduction``, clamped to at least 1.
    """

    def __init__(self, in_channels, out_channels, reduction=8):
        super(SEAttention, self).__init__()
        # Clamp to >= 1: when out_channels < reduction, plain floor division
        # yields a zero-channel bottleneck conv that cannot learn any gate.
        hidden = max(1, out_channels // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels=in_channels, out_channels=hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=hidden, out_channels=out_channels, kernel_size=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel SE gate."""
        return self.se(x) * x
class ChannelAttention(nn.Module):
    """CBAM-style channel gate.

    The input is pooled to 1x1 spatially in two ways, by global average and
    by global max. Each pooled tensor goes through the same 1x1-conv MLP
    (``in_channels`` -> hidden -> ``out_channels``). The two MLP outputs are
    summed and passed through a sigmoid.

    The module returns only the gate, not the gated features.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the produced gate.
        reduction: MLP reduction ratio. The hidden width is
            ``out_channels // reduction``, clamped to at least 1.
    """

    def __init__(self, in_channels, out_channels, reduction=8):
        super(ChannelAttention, self).__init__()
        # Clamp to >= 1: when out_channels < reduction, plain floor division
        # yields a zero-channel hidden layer that cannot learn any gate.
        hidden = max(1, out_channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        # One MLP shared by both pooled descriptors.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=hidden, out_channels=out_channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid channel gate computed from ``x``."""
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        return self.sigmoid(avg_out + max_out)
class SpatialAttention(nn.Module):
    """CBAM-style spatial gate.

    Two single-channel maps are built from the input: the mean across
    channels and the max across channels. They are concatenated, convolved
    down to one channel, and passed through a sigmoid. The module returns
    only that one-channel gate.

    Args:
        kernel_size: size of the fusing convolution. Padding is
            ``kernel_size // 2``, which preserves spatial size for odd sizes.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid spatial gate computed from ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(stacked))
class CBAMAttention(nn.Module):
    """CBAM block: a channel gate followed by a spatial gate.

    The features are first multiplied by the channel gate. The refined
    features are then multiplied by the spatial gate computed from them.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the channel gate. It multiplies the input,
            so in practice it equals ``in_channels``.
        reduction: reduction ratio of the channel-gate MLP.
    """

    def __init__(self, in_channels, out_channels, reduction=8):
        super().__init__()
        self.ca = ChannelAttention(in_channels=in_channels, out_channels=out_channels, reduction=reduction)
        self.sa = SpatialAttention()

    def forward(self, x):
        """Apply the channel gate, then the spatial gate, to ``x``."""
        # Order matters: the spatial gate sees the channel-refined features.
        for gate in (self.ca, self.sa):
            x = gate(x) * x
        return x
class h_sigmoid(nn.Module):
    """Hard sigmoid: ``ReLU6(x + 3) / 6``, a piecewise-linear sigmoid substitute.

    Args:
        inplace: forwarded to ``nn.ReLU6``. The in-place clamp only acts on
            the temporary ``x + 3``, so the caller's tensor is left untouched.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.relu = nn.ReLU6(inplace=inplace)

    def forward(self, x):
        shifted = x + 3
        return torch.div(self.relu(shifted), 6)
class h_swish(nn.Module):
    """Hard swish: ``x * h_sigmoid(x)``.

    Args:
        inplace: forwarded to the inner ``h_sigmoid``.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.sigmoid = h_sigmoid(inplace=inplace)

    def forward(self, x):
        gate = self.sigmoid(x)
        return x * gate
class CoordAttention(nn.Module):
    """Coordinate attention: separate sigmoid gates along height and width.

    Processing steps:
    1. The input is average-pooled over width, giving an H x 1 profile.
    2. It is also average-pooled over height, giving a 1 x W profile that is
       transposed to W x 1.
    3. The two profiles are stacked along dim 2 and pass through a shared
       1x1 conv -> InstanceNorm -> h_swish stage with
       ``max(8, in_channels // reduction)`` channels.
    4. The result is split back into its H and W parts. Each part becomes a
       sigmoid gate through its own 1x1 conv.

    The output is ``x * gate_w * gate_h``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of both gates. They multiply ``x`` directly,
            so in practice this equals ``in_channels``.
        reduction: reduction ratio for the shared stage (floored at 8 channels).
    """

    def __init__(self, in_channels, out_channels, reduction=8):
        super().__init__()
        # pool_w keeps the width axis (output 1 x W); pool_h keeps height (H x 1).
        self.pool_w = nn.AdaptiveAvgPool2d((1, None))
        self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
        hidden = max(8, in_channels // reduction)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.InstanceNorm2d(hidden)
        self.act1 = h_swish()
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        _, _, height, width = x.shape
        # (n, c, H, 1) and (n, c, W, 1) profiles, stacked into (n, c, H + W, 1).
        profile_h = self.pool_h(x)
        profile_w = self.pool_w(x).permute(0, 1, 3, 2)
        stacked = torch.cat([profile_h, profile_w], dim=2)
        mixed = self.act1(self.bn1(self.conv1(stacked)))
        feat_h, feat_w = torch.split(mixed, [height, width], dim=2)
        gate_h = torch.sigmoid(self.conv2(feat_h))  # (n, out, H, 1)
        gate_w = torch.sigmoid(self.conv3(feat_w.permute(0, 1, 3, 2)))  # (n, out, 1, W)
        return x * gate_w * gate_h
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions and optional attention.

    Main path: conv3x3(stride) -> InstanceNorm -> ReLU -> conv3x3 ->
    InstanceNorm -> attention.

    Shortcut: the identity, or a strided 1x1 conv + InstanceNorm projection
    when ``in_channels != out_channels`` or ``stride != 1``.

    Output: ``relu(main + shortcut)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        reduction: reduction ratio passed to the attention module.
        stride: stride of the first conv and of the projection shortcut.
        attention: ``'se'``, ``'cbam'`` or ``'coord'``. Any other value,
            including ``None``, selects ``nn.Identity`` (no attention).

    Note: the constructor prints the selected attention type to stdout.
    """

    def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
        super().__init__()
        if in_channels == out_channels and stride == 1:
            self.change = None
        else:
            # Project the shortcut so it matches the main path's output shape.
            self.change = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                          padding=0, stride=stride, bias=False),
                nn.InstanceNorm2d(out_channels),
            )
        self.left = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3,
                      padding=1, stride=stride, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3,
                      padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
        )
        # (key, printed label, lazy constructor). The lambdas defer the class
        # lookup so only the selected attention module is ever built.
        choices = (
            ('se', 'SEAttention', lambda: SEAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
            ('cbam', 'CBAMAttention', lambda: CBAMAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
            ('coord', 'CoordAttention', lambda: CoordAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
        )
        message, build = 'None Attention', nn.Identity
        for key, label, factory in choices:
            if attention == key:
                message, build = label, factory
                break
        print(message)
        self.attention = build()

    def forward(self, x):
        out = self.attention(self.left(x))
        shortcut = x if self.change is None else self.change(x)
        out += shortcut
        return F.relu(out)
class BottleneckBlock(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 conv stack and optional attention.

    Main path: conv1x1(stride) -> InstanceNorm -> ReLU -> conv3x3 ->
    InstanceNorm -> ReLU -> conv1x1 -> InstanceNorm -> attention. Every conv
    keeps ``out_channels`` channels; there is no channel squeeze or expansion.

    Shortcut: the identity, or a strided 1x1 conv + InstanceNorm projection
    when ``in_channels != out_channels`` or ``stride != 1``.

    Output: ``relu(main + shortcut)``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        reduction: reduction ratio passed to the attention module.
        stride: stride of the first conv and of the projection shortcut.
        attention: ``'se'``, ``'cbam'`` or ``'coord'``. Any other value,
            including ``None``, selects ``nn.Identity`` (no attention).

    Note: the constructor prints the selected attention type to stdout.
    """

    def __init__(self, in_channels, out_channels, reduction, stride, attention=None):
        super().__init__()
        if in_channels == out_channels and stride == 1:
            self.change = None
        else:
            # Project the shortcut so it matches the main path's output shape.
            self.change = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                          padding=0, stride=stride, bias=False),
                nn.InstanceNorm2d(out_channels),
            )
        self.left = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                      stride=stride, padding=0, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3,
                      padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=1,
                      padding=0, bias=False),
            nn.InstanceNorm2d(out_channels),
        )
        # (key, printed label, lazy constructor). The lambdas defer the class
        # lookup so only the selected attention module is ever built.
        choices = (
            ('se', 'SEAttention', lambda: SEAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
            ('cbam', 'CBAMAttention', lambda: CBAMAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
            ('coord', 'CoordAttention', lambda: CoordAttention(
                in_channels=out_channels, out_channels=out_channels, reduction=reduction)),
        )
        message, build = 'None Attention', nn.Identity
        for key, label, factory in choices:
            if attention == key:
                message, build = label, factory
                break
        print(message)
        self.attention = build()

    def forward(self, x):
        out = self.attention(self.left(x))
        shortcut = x if self.change is None else self.change(x)
        out += shortcut
        return F.relu(out)
class ResBlock(nn.Module):
    """Stack of ``blocks`` residual blocks of a single type.

    The first block maps ``in_channels`` -> ``out_channels`` with the given
    ``stride``. Every following block keeps ``out_channels`` and uses
    stride 1.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by every block in the stack.
        blocks: number of blocks. A value ``<= 0`` yields an empty stack,
            which returns its input unchanged.
        block_type: either the name of a block class defined in this module
            (``"BasicBlock"`` or ``"BottleneckBlock"``), or a block class /
            factory called as
            ``(in_channels, out_channels, reduction, stride, attention=...)``.
        reduction: forwarded to each block.
        stride: stride of the first block.
        attention: forwarded to each block.

    Raises:
        ValueError: if ``block_type`` is a string that names no known block
            (only checked when at least one block is built).
    """

    # Names accepted for a string ``block_type``. They are resolved through
    # this module's globals instead of ``eval``, so no arbitrary expression
    # can run.
    _BLOCK_TYPES = ("BasicBlock", "BottleneckBlock")

    def __init__(self, in_channels, out_channels, blocks=1, block_type="BottleneckBlock", reduction=8, stride=1, attention=None):
        super(ResBlock, self).__init__()
        layers = []
        if blocks > 0:
            block_cls = self._resolve_block_type(block_type)
            # Only the first block changes channels / resolution.
            layers.append(block_cls(in_channels, out_channels, reduction, stride, attention=attention))
            layers.extend(
                block_cls(out_channels, out_channels, reduction, 1, attention=attention)
                for _ in range(blocks - 1)
            )
        self.layers = nn.Sequential(*layers)

    @classmethod
    def _resolve_block_type(cls, block_type):
        """Return the block constructor for a known block name or a callable."""
        if not isinstance(block_type, str):
            return block_type
        if block_type not in cls._BLOCK_TYPES:
            raise ValueError(
                "Unknown block_type {!r}; expected one of {} or a block class".format(
                    block_type, cls._BLOCK_TYPES
                )
            )
        return globals()[block_type]

    def forward(self, x):
        return self.layers(x)