import torch
import torch.nn as nn
import torch.nn.functional as F


class mfm(nn.Module):
    """Filter followed by an element-wise max over two halves of its output.

    The filter (a 2-D convolution or a linear layer) produces
    ``2 * out_channels`` features; ``forward`` splits them into two groups of
    ``out_channels`` and returns the element-wise maximum of the two groups.

    Args:
        in_channels: Number of input channels / features.
        out_channels: Number of channels / features returned by ``forward``.
        kernel_size: Convolution kernel size (convolution variant only).
        stride: Convolution stride (convolution variant only).
        padding: Convolution padding (convolution variant only).
        type: ``1`` selects ``nn.Conv2d``; any other value selects
            ``nn.Linear``.  The name shadows the builtin but is kept because
            it is part of the public signature.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1,
                 padding=1, type=1):
        super(mfm, self).__init__()
        self.out_channels = out_channels
        if type == 1:
            self.filter = nn.Conv2d(
                in_channels,
                2 * out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        else:
            self.filter = nn.Linear(in_channels, 2 * out_channels)

    def forward(self, x):
        """Apply the filter and return the max of its two feature halves."""
        x = self.filter(x)
        # Locate the feature axis from the end: Conv2d output is (..., C, H, W)
        # and Linear output is (..., C).  For the 4-D conv / 2-D linear inputs
        # this equals dim 1, and it also stays correct when a Linear input has
        # extra leading dimensions.  Derived from the filter type here rather
        # than cached in __init__ so modules restored without running __init__
        # keep working.
        feature_dim = -3 if isinstance(self.filter, nn.Conv2d) else -1
        first, second = torch.split(x, self.out_channels, feature_dim)
        return torch.max(first, second)


class group(nn.Module):
    """A 1x1 ``mfm`` convolution followed by a configurable ``mfm`` convolution.

    Args:
        in_channels: Channels of the input; the 1x1 stage keeps this count.
        out_channels: Channels produced by the second stage.
        kernel_size: Kernel size of the second stage.
        stride: Stride of the second stage.
        padding: Padding of the second stage.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(group, self).__init__()
        # kernel 1, stride 1, padding 0: channel mixing only.
        self.conv_a = mfm(in_channels, in_channels, 1, 1, 0)
        self.conv = mfm(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        """Run the 1x1 stage and then the main convolution stage."""
        return self.conv(self.conv_a(x))


class resblock(nn.Module):
    """Two 3x3 ``mfm`` convolutions with an identity skip connection.

    The input is added to the output unchanged, so ``in_channels`` must equal
    ``out_channels`` for the addition to be shape-compatible.
    """

    def __init__(self, in_channels, out_channels):
        super(resblock, self).__init__()
        self.conv1 = mfm(in_channels, out_channels,
                         kernel_size=3, stride=1, padding=1)
        self.conv2 = mfm(out_channels, out_channels,
                         kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        out = self.conv2(self.conv1(x))
        return out + x


class network_29layers_v2(nn.Module):
    """Convolutional feature extractor built from ``mfm`` layers.

    The first layer takes a single input channel and the final linear layer
    expects ``8 * 8 * 128`` flattened features.  Every convolution preserves
    the spatial size and there are four 2x poolings, so the input must be
    ``(batch, 1, 128, 128)``.

    Args:
        block: Callable ``block(in_channels, out_channels)`` returning the
            module repeated inside each stage (e.g. ``resblock``).
        layers: Sequence of four ints, the number of blocks in each stage.
        is_train: When true, a bias-free classification layer ``fc2_`` is
            created and ``forward`` additionally returns its output.
        num_classes: Output size of ``fc2_``.
    """

    def __init__(self, block, layers, is_train=False, num_classes=80013):
        super(network_29layers_v2, self).__init__()
        self.is_train = is_train
        self.conv1 = mfm(1, 48, 5, 1, 2)
        self.block1 = self._make_layer(block, layers[0], 48, 48)
        self.group1 = group(48, 96, 3, 1, 1)
        self.block2 = self._make_layer(block, layers[1], 96, 96)
        self.group2 = group(96, 192, 3, 1, 1)
        self.block3 = self._make_layer(block, layers[2], 192, 192)
        self.group3 = group(192, 128, 3, 1, 1)
        self.block4 = self._make_layer(block, layers[3], 128, 128)
        self.group4 = group(128, 128, 3, 1, 1)
        self.fc = nn.Linear(8 * 8 * 128, 256)
        if self.is_train:
            self.fc2_ = nn.Linear(256, num_classes, bias=False)

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Return ``num_blocks`` instances of ``block`` chained sequentially."""
        stages = [block(in_channels, out_channels) for _ in range(num_blocks)]
        return nn.Sequential(*stages)

    @staticmethod
    def _pool(x):
        """Sum of 2x2 max pooling and 2x2 average pooling of ``x``."""
        return F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)

    def forward(self, x):
        """Compute the 256-dimensional output of ``fc`` for ``x``.

        Returns:
            ``fc`` output alone when ``is_train`` is false; otherwise the
            tuple ``(out, fc)`` where ``out`` is ``fc2_`` applied to the
            dropout-regularised ``fc`` output.
        """
        x = self._pool(self.conv1(x))

        x = self.group1(self.block1(x))
        x = self._pool(x)

        x = self.group2(self.block2(x))
        x = self._pool(x)

        # No pooling between stage 3 and stage 4.
        x = self.group3(self.block3(x))
        x = self.group4(self.block4(x))
        x = self._pool(x)

        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        fc = self.fc(x)

        if not self.is_train:
            return fc

        # Dropout is only active while the module is in training mode.
        x = F.dropout(fc, training=self.training)
        out = self.fc2_(x)
        return out, fc