import math
from functools import partial

import torch
from torch import nn
from torch.nn import functional as F


class SwishImplementation(torch.autograd.Function):
    """Swish activation ``x * sigmoid(x)`` with a hand-written backward pass.

    Only the input tensor is saved for backward; the sigmoid is recomputed
    there instead of being stored from the forward pass.
    """

    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is a
        # deprecated alias for it.
        (i,) = ctx.saved_tensors
        sigmoid_i = torch.sigmoid(i)
        # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x)))
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    """``nn.Module`` wrapper around :class:`SwishImplementation`."""

    def forward(self, x):
        return SwishImplementation.apply(x)


def drop_connect(inputs, p, training):
    """Drop connect: randomly zero whole samples of a batch during training.

    Each sample is kept with probability ``1 - p`` and zeroed otherwise.
    Kept samples are rescaled by ``1 / (1 - p)`` so the expected value is
    unchanged. When ``training`` is false, ``inputs`` is returned unchanged.

    Args:
        inputs: batch tensor. The random mask has shape ``[batch, 1, 1, 1]``,
            so a 4-D (N, C, H, W) input is assumed.
        p: drop probability.
        training: whether the calling module is in training mode.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    if not training:
        return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + torch.rand(
        [batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device
    )
    binary_tensor = torch.floor(random_tensor)
    return inputs / keep_prob * binary_tensor


def get_same_padding_conv2d(image_size=None):
    """Return a ``Conv2dStaticSamePadding`` constructor bound to ``image_size``."""
    return partial(Conv2dStaticSamePadding, image_size=image_size)


def get_width_and_height_from_size(x):
    """Obtain (height, width) from an int or a 2-element list/tuple.

    An int ``x`` yields ``(x, x)``. A list or tuple is returned as-is.

    Raises:
        TypeError: if ``x`` is neither an int nor a list/tuple.
    """
    if isinstance(x, int):
        return x, x
    if isinstance(x, (list, tuple)):
        return x
    raise TypeError(
        f"size must be an int, list or tuple, got {type(x).__name__}"
    )


def calculate_output_image_size(input_image_size, stride):
    """Compute the output image size of a Conv2dSamePadding with ``stride``.

    With "same" padding each spatial dimension becomes
    ``ceil(size / stride)``.

    Args:
        input_image_size: int, list/tuple ``(height, width)``, or None.
        stride: int, or a sequence whose first element is used for both dims.

    Returns:
        ``[height, width]``, or None if ``input_image_size`` is None.
    """
    if input_image_size is None:
        return None
    image_height, image_width = get_width_and_height_from_size(input_image_size)
    stride = stride if isinstance(stride, int) else stride[0]
    image_height = int(math.ceil(image_height / stride))
    image_width = int(math.ceil(image_width / stride))
    return [image_height, image_width]


class Conv2dStaticSamePadding(nn.Conv2d):
    """2D convolution with TensorFlow-style "same" padding for a fixed image size.

    The padding is computed once from ``image_size`` in ``__init__``. When
    the total padding is odd, the extra pixel goes on the bottom/right side.
    """

    def __init__(self, in_channels, out_channels, kernel_size, image_size=None, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, **kwargs)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

        # Calculate padding based on image size and save it.
        assert image_size is not None
        ih, iw = (image_size, image_size) if isinstance(image_size, int) else image_size
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            # ZeroPad2d order is (left, right, top, bottom).
            self.static_padding = nn.ZeroPad2d(
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
            )
        else:
            self.static_padding = Identity()

    def forward(self, x):
        x = self.static_padding(x)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
        return x


class Identity(nn.Module):
    """Module that returns its input unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, input):
        return input


# MBConvBlock
class MBConvBlock(nn.Module):
    """Mobile inverted bottleneck block with squeeze-and-excitation.

    Pipeline:
        1. Optional 1x1 expansion (only when ``expand_ratio != 1``).
        2. Depthwise convolution.
        3. Squeeze-and-excitation.
        4. 1x1 projection.
        5. Residual connection (with optional drop connect), applied when
           ``stride == 1`` and ``input_filters == output_filters``.

    The original author's example configuration was:
    kernel 3x3, 32 input channels, 16 output channels, stride 1.

    Args:
        ksize: depthwise kernel size (int, or per-dimension tuple/list).
        input_filters: number of input channels.
        output_filters: number of output channels.
        expand_ratio: channel expansion factor for the hidden layer.
        stride: depthwise convolution stride.
        image_size: not used by this block; kept for interface compatibility.
        drop_connect_rate: drop probability passed to ``drop_connect`` on the
            residual path (0 disables it).
    """

    def __init__(self, ksize, input_filters, output_filters, expand_ratio=1, stride=1, image_size=224,
                 drop_connect_rate=0.):
        super().__init__()
        self._bn_mom = 0.1
        self._bn_eps = 0.01
        self._se_ratio = 0.25
        self._input_filters = input_filters
        self._output_filters = output_filters
        self._expand_ratio = expand_ratio
        self._kernel_size = ksize
        self._stride = stride
        self._drop_connect_rate = drop_connect_rate

        inp = self._input_filters
        oup = self._input_filters * self._expand_ratio
        if self._expand_ratio != 1:
            self._expand_conv = nn.Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution.
        k = self._kernel_size
        s = self._stride
        # k // 2 keeps the spatial size at stride 1 for any odd kernel.
        # A hard-coded 1 only worked for k == 3 and broke the residual add
        # for larger kernels.
        if isinstance(k, (tuple, list)):
            padding = tuple(d // 2 for d in k)
        else:
            padding = k // 2
        self._depthwise_conv = nn.Conv2d(in_channels=oup, out_channels=oup, groups=oup, kernel_size=k,
                                         stride=s, padding=padding, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze and Excitation layer; the bottleneck width is based on the
        # block's input channels, not on the expanded width.
        num_squeezed_channels = max(1, int(self._input_filters * self._se_ratio))
        self._se_reduce = nn.Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
        self._se_expand = nn.Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase.
        final_oup = self._output_filters
        self._project_conv = nn.Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs):
        """Run the block.

        :param inputs: input tensor
        :return: output of block
        """
        # Expansion and Depthwise Convolution.
        x = inputs
        if self._expand_ratio != 1:
            expand = self._expand_conv(inputs)
            bn0 = self._bn0(expand)
            x = self._swish(bn0)
        depthwise = self._depthwise_conv(x)
        bn1 = self._bn1(depthwise)
        x = self._swish(bn1)

        # Squeeze and Excitation: global pool -> reduce -> swish -> expand,
        # then gate the channels with a sigmoid.
        x_squeezed = F.adaptive_avg_pool2d(x, 1)
        x_squeezed = self._se_reduce(x_squeezed)
        x_squeezed = self._swish(x_squeezed)
        x_squeezed = self._se_expand(x_squeezed)
        x = torch.sigmoid(x_squeezed) * x

        x = self._bn2(self._project_conv(x))

        # Skip connection and drop connect.
        input_filters, output_filters = self._input_filters, self._output_filters
        if self._stride == 1 and input_filters == output_filters:
            if self._drop_connect_rate != 0:
                x = drop_connect(x, p=self._drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x


if __name__ == '__main__':
    x = torch.randn(1, 3, 112, 112)
    mbconv = MBConvBlock(ksize=3, input_filters=3, output_filters=3, expand_ratio=4, stride=1)
    print(mbconv)
    out = mbconv(x)
    print(out.shape)