# NOTE(review): the three lines below are Hugging Face page residue accidentally
# pasted into the source; commented out so the module is importable.
# minchul's picture
# Upload directory
# 984e623 verified
from collections import namedtuple
from torch.nn import Dropout
from torch.nn import MaxPool2d
from torch.nn import Sequential
import torch
import torch.nn as nn
from torch.nn import Conv2d, Linear
from torch.nn import BatchNorm1d, BatchNorm2d
from torch.nn import ReLU, Sigmoid
from torch.nn import Module
from torch.nn import PReLU
from fvcore.nn import flop_count
import numpy as np
def initialize_weights(modules):
    """Initialize conv/linear layers with He (Kaiming) init and batch-norm to identity.

    Args:
        modules: iterable of ``nn.Module`` instances, e.g. ``model.modules()``.
    """
    for module in modules:
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Conv2d and Linear receive identical treatment: fan-out Kaiming
            # normal (suited to the ReLU-family activations used downstream).
            nn.init.kaiming_normal_(module.weight,
                                    mode='fan_out',
                                    nonlinearity='relu')
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.BatchNorm2d):
            # Start batch norm as the identity transform.
            module.weight.data.fill_(1)
            module.bias.data.zero_()
class Flatten(Module):
    """Collapse every dimension after the batch dimension into one."""

    def forward(self, input):
        batch_size = input.size(0)
        # view (not reshape) preserves the original's requirement that the
        # input be contiguous.
        return input.view(batch_size, -1)
class LinearBlock(Module):
    """Convolution followed by batch norm, deliberately without an activation.

    Defaults to a 1x1 ("pointwise") convolution; the absence of a
    non-linearity is what makes the block "linear".
    """

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.conv = Conv2d(in_c, out_c, kernel, stride, padding,
                           groups=groups, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        return self.bn(self.conv(x))
class SEModule(Module):
    """Squeeze-and-Excitation channel attention gate.

    Args:
        channels: number of input/output channels.
        reduction: bottleneck reduction ratio for the two 1x1 convs.
    """

    def __init__(self, channels, reduction):
        super(SEModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = Conv2d(channels, channels // reduction,
                          kernel_size=1, padding=0, bias=False)
        # fc1 gets Xavier init here; fc2 is left to the model-wide initializer.
        nn.init.xavier_uniform_(self.fc1.weight.data)
        self.relu = ReLU(inplace=True)
        self.fc2 = Conv2d(channels // reduction, channels,
                          kernel_size=1, padding=0, bias=False)
        self.sigmoid = Sigmoid()

    def forward(self, x):
        # Squeeze: global average per channel; excite: per-channel gate in (0, 1).
        gate = self.avg_pool(x)
        gate = self.fc1(gate)
        gate = self.relu(gate)
        gate = self.fc2(gate)
        gate = self.sigmoid(gate)
        return x * gate
class BasicBlockIR(Module):
    """Improved-residual basic block: BN-Conv3x3-BN-PReLU-Conv3x3(s)-BN + shortcut.

    Args:
        in_channel: input channel count.
        depth: output channel count.
        stride: stride applied by the second 3x3 conv (and the shortcut).
    """

    def __init__(self, in_channel, depth, stride):
        super(BasicBlockIR, self).__init__()
        if in_channel == depth:
            # Channel counts match: subsample only. MaxPool2d(1, stride)
            # is the identity when stride == 1.
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            # Project to the new width with a strided 1x1 conv + BN.
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(depth),
            PReLU(depth),
            Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        return self.res_layer(x) + self.shortcut_layer(x)
class BottleneckIR(Module):
    """Improved-residual bottleneck: 1x1 reduce -> 3x3 -> 1x1 expand (strided) + shortcut.

    The internal width is ``depth // 4``, the classic ResNet bottleneck ratio.
    """

    def __init__(self, in_channel, depth, stride):
        super(BottleneckIR, self).__init__()
        mid = depth // 4  # bottleneck width
        if in_channel == depth:
            # Same width: subsample only (identity when stride == 1).
            self.shortcut_layer = MaxPool2d(1, stride)
        else:
            self.shortcut_layer = Sequential(
                Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                BatchNorm2d(depth),
            )
        self.res_layer = Sequential(
            BatchNorm2d(in_channel),
            Conv2d(in_channel, mid, (1, 1), (1, 1), 0, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, mid, (3, 3), (1, 1), 1, bias=False),
            BatchNorm2d(mid),
            PReLU(mid),
            Conv2d(mid, depth, (1, 1), stride, 0, bias=False),
            BatchNorm2d(depth),
        )

    def forward(self, x):
        return self.res_layer(x) + self.shortcut_layer(x)
class BasicBlockIRSE(BasicBlockIR):
    """BasicBlockIR with a Squeeze-and-Excitation gate appended to the residual branch."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # Module name "se_block" is kept for state_dict/checkpoint compatibility.
        self.res_layer.add_module("se_block", SEModule(depth, 16))
class BottleneckIRSE(BottleneckIR):
    """BottleneckIR with a Squeeze-and-Excitation gate appended to the residual branch."""

    def __init__(self, in_channel, depth, stride):
        super().__init__(in_channel, depth, stride)
        # Module name "se_block" is kept for state_dict/checkpoint compatibility.
        self.res_layer.add_module("se_block", SEModule(depth, 16))
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
    """Configuration record for one residual unit (not a layer itself)."""
def get_block(in_channel, depth, num_units, stride=2):
    """Build the unit configs for one stage.

    The first unit transitions (possibly strided, possibly changing width);
    the remaining ``num_units - 1`` units keep the width at stride 1.
    """
    units = [Bottleneck(in_channel, depth, stride)]
    units.extend(Bottleneck(depth, depth, 1) for _ in range(num_units - 1))
    return units
def get_blocks(num_layers):
    """Return the per-stage unit configurations for a given network depth.

    Args:
        num_layers: depth in {18, 34, 50, 100, 152, 200}.

    Returns:
        A list of four stages, each a list of ``Bottleneck`` configs
        produced by ``get_block``.

    Raises:
        ValueError: if ``num_layers`` is unsupported. (The original code had
            no else branch and crashed with ``UnboundLocalError`` instead.)
    """
    # (first-stage depth, units per stage). Depths <= 100 use the basic-block
    # channel plan (64..512); deeper nets use the bottleneck plan (256..2048).
    # Each stage doubles the channel count.
    plans = {
        18: (64, (2, 2, 2, 2)),
        34: (64, (3, 4, 6, 3)),
        50: (64, (3, 4, 14, 3)),
        100: (64, (3, 13, 30, 3)),
        152: (256, (3, 8, 36, 3)),
        200: (256, (3, 24, 36, 3)),
    }
    if num_layers not in plans:
        raise ValueError(
            "num_layers should be one of 18, 34, 50, 100, 152, 200; got %r"
            % (num_layers,))
    base_depth, units = plans[num_layers]
    depths = [base_depth * (2 ** i) for i in range(4)]
    # The stem always outputs 64 channels; subsequent stages feed each other.
    in_channels = [64] + depths[:-1]
    return [
        get_block(in_channel=c, depth=d, num_units=n)
        for c, d, n in zip(in_channels, depths, units)
    ]
class Backbone(Module):
    """IR / IR-SE ResNet backbone producing a face-embedding vector.

    Args:
        input_size: spatial input size; first element must be 112 or 224.
        num_layers: network depth in {18, 34, 50, 100, 152, 200}.
        mode: 'ir' for plain improved-residual blocks, 'ir_se' to add
            Squeeze-and-Excitation gates.
        flip: if True, flip the channel axis of the input in ``forward``
            (e.g. RGB <-> BGR).
        output_dim: dimensionality of the output embedding.
    """

    def __init__(self, input_size, num_layers, mode='ir', flip=False, output_dim=512):
        super(Backbone, self).__init__()
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        # Fixed message: the original omitted 200 even though it is accepted.
        assert num_layers in [18, 34, 50, 100, 152, 200], \
            "num_layers should be 18, 34, 50, 100, 152 or 200"
        assert mode in ['ir', 'ir_se'], \
            "mode should be ir or ir_se"
        self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                      BatchNorm2d(64), PReLU(64))
        blocks = get_blocks(num_layers)
        # Depths <= 100 use basic blocks ending at 512 channels; deeper nets
        # use bottleneck blocks ending at 2048 channels.
        if num_layers <= 100:
            unit_module = BasicBlockIR if mode == 'ir' else BasicBlockIRSE
            output_channel = 512
        else:
            unit_module = BottleneckIR if mode == 'ir' else BottleneckIRSE
            output_channel = 2048
        # Four stride-2 stages reduce 112 -> 7 (or 224 -> 14) spatially,
        # which fixes the flattened size fed to the embedding Linear.
        feat_hw = 7 if input_size[0] == 112 else 14
        self.output_layer = Sequential(
            BatchNorm2d(output_channel), Dropout(0.4), Flatten(),
            Linear(output_channel * feat_hw * feat_hw, output_dim),
            BatchNorm1d(output_dim, affine=False))
        modules = [
            unit_module(bottleneck.in_channel, bottleneck.depth,
                        bottleneck.stride)
            for block in blocks for bottleneck in block
        ]
        self.body = Sequential(*modules)
        initialize_weights(self.modules())
        self.flip = flip

    def forward(self, x):
        """Map an image batch (N, 3, H, W) to embeddings (N, output_dim)."""
        if self.flip:
            x = x.flip(1)  # color channel flip
        x = self.input_layer(x)
        # Run the Sequential directly; the original enumerate loop never
        # used its index.
        x = self.body(x)
        return self.output_layer(x)
def IR_18(input_size, output_dim=512):
    """Construct an IR-18 backbone."""
    return Backbone(input_size, 18, 'ir', output_dim=output_dim)
def IR_34(input_size, output_dim=512):
    """Construct an IR-34 backbone."""
    return Backbone(input_size, 34, 'ir', output_dim=output_dim)
def IR_50(input_size, output_dim=512):
    """Construct an IR-50 backbone."""
    return Backbone(input_size, 50, 'ir', output_dim=output_dim)
def IR_101(input_size, output_dim=512):
    """Construct an IR-101 backbone (built from the 100-layer plan)."""
    return Backbone(input_size, 100, 'ir', output_dim=output_dim)
def IR_101_FLIP(input_size, output_dim=512):
    """Construct an IR-101 backbone that flips the input's channel axis."""
    return Backbone(input_size, 100, 'ir', flip=True, output_dim=output_dim)
def IR_152(input_size, output_dim=512):
    """Construct an IR-152 backbone."""
    return Backbone(input_size, 152, 'ir', output_dim=output_dim)
def IR_200(input_size, output_dim=512):
    """Construct an IR-200 backbone."""
    return Backbone(input_size, 200, 'ir', output_dim=output_dim)
def IR_SE_50(input_size, output_dim=512):
    """Construct an IR-SE-50 backbone (Squeeze-and-Excitation variant)."""
    return Backbone(input_size, 50, 'ir_se', output_dim=output_dim)
def IR_SE_101(input_size, output_dim=512):
    """Construct an IR-SE-101 backbone (built from the 100-layer plan)."""
    return Backbone(input_size, 100, 'ir_se', output_dim=output_dim)
def IR_SE_152(input_size, output_dim=512):
    """Construct an IR-SE-152 backbone."""
    return Backbone(input_size, 152, 'ir_se', output_dim=output_dim)
def IR_SE_200(input_size, output_dim=512):
    """Construct an IR-SE-200 backbone."""
    return Backbone(input_size, 200, 'ir_se', output_dim=output_dim)
if __name__ == '__main__':
    # Sanity check: report FLOPs (via fvcore) and trainable-parameter count
    # for an IR-50 at 112x112 input.
    input_shape = (1, 3, 112, 112)
    net = IR_50(input_size=(112, 112))
    net.eval()
    flop_results = flop_count(net, inputs=torch.randn(input_shape), supported_ops={})
    total_flops = np.array(list(flop_results[0].values())).sum()
    print('FLOPs: ', total_flops / 1e9, 'G')
    num_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
    print('Num Params: ', num_params / 1e6, 'M')