|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class mfm(nn.Module):
    """Max-Feature-Map (MFM) layer.

    A learnable projection emits ``2 * out_channels`` features. The output is
    the element-wise maximum of the first ``out_channels`` features and the
    next ``out_channels`` features along dim 1.

    Args:
        in_channels: Input channels (conv variant) or input features
            (linear variant).
        out_channels: Number of channels/features left after the max.
        kernel_size: Forwarded to ``nn.Conv2d``; unused by the linear variant.
        stride: Forwarded to ``nn.Conv2d``; unused by the linear variant.
        padding: Forwarded to ``nn.Conv2d``; unused by the linear variant.
        type: ``1`` builds an ``nn.Conv2d``; any other value builds an
            ``nn.Linear``. (The name shadows the builtin ``type``; it is kept
            for compatibility with keyword callers.)
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, type=1):
        super().__init__()
        self.out_channels = out_channels
        doubled = 2 * out_channels
        if type == 1:
            projection = nn.Conv2d(
                in_channels,
                doubled,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
            )
        else:
            projection = nn.Linear(in_channels, doubled)
        # The attribute name ``filter`` determines the state_dict keys
        # (``filter.weight`` / ``filter.bias``), so it must not change.
        self.filter = projection

    def forward(self, x):
        """Project ``x``, split dim 1 into ``out_channels``-sized pieces and
        return the element-wise max of the first two pieces."""
        projected = self.filter(x)
        pieces = projected.split(self.out_channels, dim=1)
        first, second = pieces[0], pieces[1]
        return torch.max(first, second)
|
|
|
|
|
|
class group(nn.Module):
    """Two stacked ``mfm`` convolutions.

    ``conv_a`` is a 1x1, stride-1, unpadded ``mfm`` that keeps the channel
    count at ``in_channels``. ``conv`` then maps ``in_channels`` to
    ``out_channels`` using the given kernel size, stride and padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super().__init__()
        self.conv_a = mfm(in_channels, in_channels, 1, 1, 0)
        self.conv = mfm(in_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        """Return ``conv(conv_a(x))``."""
        return self.conv(self.conv_a(x))
|
|
|
|
|
|
class resblock(nn.Module):
    """Residual block: two 3x3 ``mfm`` convolutions plus an identity shortcut.

    The shortcut has no projection: ``forward`` adds the raw input to the
    branch output. The input must therefore be broadcast-compatible with a
    tensor of ``out_channels`` channels; normally this means
    ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Stride 1 with padding 1 keeps the spatial size for a 3x3 kernel.
        self.conv1 = mfm(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = mfm(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return ``conv2(conv1(x)) + x``."""
        branch = self.conv2(self.conv1(x))
        return branch + x
|
|
|
|
|
|
class network_29layers_v2(nn.Module):
    """LightCNN-29 v2 style backbone built from ``mfm`` and ``group`` layers.

    ``forward`` runs the stages in this order:
    conv1 -> pool -> block1 -> group1 -> pool -> block2 -> group2 -> pool
    -> block3 -> group3 -> block4 -> group4 -> pool -> flatten -> ``fc``.
    Each "pool" is the sum of 2x2 max-pooling and 2x2 average-pooling.

    Every convolution uses stride 1 with size-preserving padding, so the four
    pooling stages divide height and width by 16. ``conv1`` takes a single
    input channel and ``fc`` expects ``8 * 8 * 128`` features. Together this
    means the input must be ``(N, 1, 128, 128)``.

    Args:
        block: Residual block class, called as ``block(in_channels, out_channels)``.
        layers: Four counts giving how many ``block`` instances make up
            ``block1`` .. ``block4``.
        is_train: If True, also build the ``fc2_`` classifier. ``forward``
            then returns ``(logits, features)`` instead of only ``features``.
        num_classes: Output size of ``fc2_``; only used when ``is_train``.
    """

    def __init__(self, block, layers, is_train=False, num_classes=80013):
        super().__init__()
        self.is_train = is_train
        # Construction order is kept so parameter initialization order and
        # state_dict keys match pretrained checkpoints.
        self.conv1 = mfm(1, 48, 5, 1, 2)
        self.block1 = self._make_layer(block, layers[0], 48, 48)
        self.group1 = group(48, 96, 3, 1, 1)
        self.block2 = self._make_layer(block, layers[1], 96, 96)
        self.group2 = group(96, 192, 3, 1, 1)
        self.block3 = self._make_layer(block, layers[2], 192, 192)
        self.group3 = group(192, 128, 3, 1, 1)
        self.block4 = self._make_layer(block, layers[3], 128, 128)
        self.group4 = group(128, 128, 3, 1, 1)
        # 8x8 spatial map (128 / 2**4) with 128 channels after the last pool.
        self.fc = nn.Linear(8 * 8 * 128, 256)
        if self.is_train:
            self.fc2_ = nn.Linear(256, num_classes, bias=False)

    def _make_layer(self, block, num_blocks, in_channels, out_channels):
        """Return an ``nn.Sequential`` of ``num_blocks`` ``block(in_channels, out_channels)``."""
        return nn.Sequential(*[block(in_channels, out_channels) for _ in range(num_blocks)])

    @staticmethod
    def _pool(x):
        """Halve the spatial size by summing 2x2 max- and average-pooling."""
        return F.max_pool2d(x, 2) + F.avg_pool2d(x, 2)

    def forward(self, x):
        """Run the backbone.

        Returns the 256-d ``fc`` features. When the model was built with
        ``is_train=True``, returns ``(fc2_(dropout(features)), features)``
        instead. Dropout is only active in ``self.training`` mode.
        """
        x = self._pool(self.conv1(x))
        x = self._pool(self.group1(self.block1(x)))
        x = self._pool(self.group2(self.block2(x)))
        # No pooling between stage 3 and stage 4.
        x = self.group3(self.block3(x))
        x = self._pool(self.group4(self.block4(x)))

        # flatten (unlike view) also works on an empty batch and on
        # non-contiguous input; otherwise the result is identical.
        fc = self.fc(torch.flatten(x, 1))

        if not self.is_train:
            return fc
        x = F.dropout(fc, training=self.training)
        out = self.fc2_(x)
        return out, fc
|
|
|