metaformer / models /MBConv.py
diaoqishuai
first commit
4a3ad95
raw
history blame
6.49 kB
import math
from functools import partial
import torch
from torch import nn
from torch.nn import functional as F
class SwishImplementation(torch.autograd.Function):
@staticmethod
def forward(ctx, i):
result = i * torch.sigmoid(i)
ctx.save_for_backward(i)
return result
@staticmethod
def backward(ctx, grad_output):
i = ctx.saved_variables[0]
sigmoid_i = torch.sigmoid(i)
return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))
class MemoryEfficientSwish(nn.Module):
def forward(self, x):
return SwishImplementation.apply(x)
def drop_connect(inputs, p, training):
""" Drop connect. """
if not training: return inputs
batch_size = inputs.shape[0]
keep_prob = 1 - p
random_tensor = keep_prob
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
binary_tensor = torch.floor(random_tensor)
output = inputs / keep_prob * binary_tensor
return output
def get_same_padding_conv2d(image_size=None):
return partial(Conv2dStaticSamePadding, image_size=image_size)
def get_width_and_height_from_size(x):
""" Obtains width and height from a int or tuple """
if isinstance(x, int): return x, x
if isinstance(x, list) or isinstance(x, tuple): return x
else: raise TypeError()
def calculate_output_image_size(input_image_size, stride):
"""
计算出 Conv2dSamePadding with a stride.
"""
if input_image_size is None: return None
image_height, image_width = get_width_and_height_from_size(input_image_size)
stride = stride if isinstance(stride, int) else stride[0]
image_height = int(math.ceil(image_height / stride))
image_width = int(math.ceil(image_width / stride))
return [image_height, image_width]
class Conv2dStaticSamePadding(nn.Conv2d):
""" 2D Convolutions like TensorFlow, for a fixed image size"""
def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
super().__init__(in_channels, out_channels, kernel_size, **kwargs)
self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2
# Calculate padding based on image size and save it
assert image_size is not None
ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
kh, kw = self.weight.size()[-2:]
sh, sw = self.stride
oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
if pad_h > 0 or pad_w > 0:
self.static_padding = nn.ZeroPad2d((pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2))
else:
self.static_padding = Identity()
def forward(self, x):
x = self.static_padding(x)
x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
class Identity(nn.Module):
def __init__(self, ):
super(Identity, self).__init__()
def forward(self, input):
return input
# #MBConvBlock
class MBConvBlock(nn.Module):
'''
层 ksize3*3 输入32 输出16 conv1 stride步长1
'''
def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1,image_size=224,drop_connect_rate=0.):
super().__init__()
self._bn_mom = 0.1
self._bn_eps = 0.01
self._se_ratio = 0.25
self._input_filters = input_filters
self._output_filters = output_filters
self._expand_ratio = expand_ratio
self._kernel_size = ksize
self._stride = stride
self._drop_connect_rate = drop_connect_rate
inp = self._input_filters
oup = self._input_filters * self._expand_ratio
if self._expand_ratio != 1:
self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1,bias=False)
self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
# Depthwise convolution
k = self._kernel_size
s = self._stride
self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup, groups=oup,
kernel_size=k, stride=s, padding=1,bias=False)
self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)
# Squeeze and Excitation layer, if desired
num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)
# Output phase
final_oup = self._output_filters
self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1,bias=False)
self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
self._swish = MemoryEfficientSwish()
def forward(self, inputs):
"""
:param inputs: input tensor
:return: output of block
"""
# Expansion and Depthwise Convolution
x = inputs
if self._expand_ratio != 1:
expand = self._expand_conv(inputs)
bn0 = self._bn0(expand)
x = self._swish(bn0)
depthwise = self._depthwise_conv(x)
bn1 = self._bn1(depthwise)
x = self._swish(bn1)
# Squeeze and Excitation
x_squeezed = F.adaptive_avg_pool2d(x, 1)
x_squeezed = self._se_reduce(x_squeezed)
x_squeezed = self._swish(x_squeezed)
x_squeezed = self._se_expand(x_squeezed)
x = torch.sigmoid(x_squeezed) * x
x = self._bn2(self._project_conv(x))
# Skip connection and drop connect
input_filters, output_filters = self._input_filters, self._output_filters
if self._stride == 1 and input_filters == output_filters:
if self._drop_connect_rate!=0:
x = drop_connect(x, p=self._drop_connect_rate, training=self.training)
x = x + inputs # skip connection
return x
if __name__ == '__main__':
input=torch.randn(1,3,112,112)
mbconv=MBConvBlock(ksize=3,input_filters=3,output_filters=3,expand_ratio=4,stride=1)
print(mbconv)
out=mbconv(input)
print(out.shape)