|
""" |
|
Author: Soubhik Sanyal |
|
Copyright (c) 2019, Soubhik Sanyal |
|
All rights reserved. |
|
Loads different resnet models |
|
""" |
|
''' |
|
file: Resnet.py |
|
date: 2018_05_02 |
|
author: zhangxiong(1025679612@qq.com) |
|
mark: copied from pytorch source code |
|
''' |
|
|
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch |
|
from torch.nn.parameter import Parameter |
|
import torch.optim as optim |
|
import numpy as np |
|
import math |
|
import torchvision |
|
|
|
|
|
class ResNet(nn.Module):
    """ResNet backbone that returns pooled features instead of class scores.

    The network consists of a stem (7x7 convolution, batch norm, ReLU,
    max-pool), four residual stages and a fixed 7x7 average pool.  The pooled
    activations are flattened per sample and returned as-is; no fully
    connected layer is created.

    Args:
        block: Residual block class.  It must provide an ``expansion`` class
            attribute and be constructible as
            ``block(inplanes, planes, stride, downsample)``.
        layers: Indexable with four ints, the number of blocks per stage.
        num_classes: Accepted for signature compatibility; it is not used
            anywhere in this class.
    """

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        # Number of channels entering the next block; _make_layer advances it.
        self.inplanes = 64

        # Stem.
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first uses stride 2.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Fixed-size pooling window: layer4 must output at least 7x7.
        self.avgpool = nn.AvgPool2d(7, stride=1)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise every conv and batch-norm layer of the network."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Conv2d):
                kernel_h, kernel_w = module.kernel_size
                fan_out = kernel_h * kernel_w * module.out_channels
                # Zero-mean normal with std sqrt(2 / fan_out).
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one residual stage made of ``blocks`` instances of ``block``.

        Only the first block receives ``stride`` and, when the shape of the
        identity path would not match, a 1x1 conv + batch-norm projection.
        ``self.inplanes`` is updated to the stage's output channel count.
        """
        out_channels = planes * block.expansion

        shortcut = None
        if not (stride == 1 and self.inplanes == out_channels):
            shortcut = nn.Sequential(
                nn.Conv2d(self.inplanes,
                          out_channels,
                          kernel_size=1,
                          stride=stride,
                          bias=False),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.inplanes, planes, stride, shortcut)]
        self.inplanes = out_channels
        stage.extend(block(self.inplanes, planes) for _ in range(1, blocks))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the flattened, average-pooled output of the last stage."""
        feat = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feat = stage(feat)

        pooled = self.avgpool(feat)
        # Collapse everything except the batch dimension.
        return pooled.view(pooled.size(0), -1)
|
|
|
|
|
class Bottleneck(nn.Module):
    """Three-convolution residual block (1x1 -> 3x3 -> 1x1).

    The first 1x1 convolution maps ``inplanes`` to ``planes`` channels, the
    3x3 convolution applies ``stride``, and the last 1x1 convolution widens
    the result to ``planes * expansion`` channels before the shortcut is
    added and a final ReLU is applied.

    Args:
        inplanes: Number of input channels.
        planes: Channel width of the two inner convolutions.
        stride: Stride of the 3x3 convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is used, so its shape
            must already match the block output.
    """

    # Ratio between output channels and ``planes``.
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Derive the output width from ``expansion`` rather than a literal 4
        # so it always agrees with code that sizes layers from
        # ``block.expansion`` (including subclasses that override it).
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes,
                               planes,
                               kernel_size=3,
                               stride=stride,
                               padding=1,
                               bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        # Project the input when a shortcut module was supplied.
        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
|
|
|
|
|
class BasicBlock(nn.Module):
    """Two-convolution residual block (3x3 -> 3x3).

    Args:
        inplanes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input to produce the
            shortcut; when ``None`` the input itself is added.
    """

    # Output channels equal ``planes`` for this block type.
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Only the first convolution is strided.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply the block to ``x`` and return the activated sum."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # Identity shortcut unless a projection module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
|
|
|
|
|
def copy_parameter_from_resnet(model, resnet_dict):
    """Copy matching entries of ``resnet_dict`` into ``model`` in place.

    This is a best-effort transfer: entries whose name does not exist in
    ``model.state_dict()`` are ignored, and entries that cannot be copied
    into the destination tensor (for example because the shapes are
    incompatible) are skipped without raising.

    Args:
        model: ``nn.Module`` whose parameters and buffers are overwritten.
        resnet_dict: Mapping from state-dict names to tensors or
            ``Parameter`` objects, e.g. another module's ``state_dict()``.

    Returns:
        None.  ``model`` is modified in place.
    """
    # The tensors in state_dict() alias the module's own storage, so an
    # in-place copy_ on them updates the model itself.
    cur_state_dict = model.state_dict()

    for name, param in resnet_dict.items():
        if name not in cur_state_dict:
            # Source entry has no counterpart in this model.
            continue
        if isinstance(param, Parameter):
            param = param.data
        try:
            cur_state_dict[name].copy_(param)
        except (RuntimeError, TypeError):
            # Incompatible shape or a non-tensor value: keep the model's
            # current value.  Unlike a bare ``except``, this lets
            # KeyboardInterrupt / SystemExit propagate.
            continue
|
|
|
|
|
|
|
def load_ResNet50Model():
    """Build a ``Bottleneck`` ResNet with stages [3, 4, 6, 3] and fill it
    with the weights of torchvision's ``resnet50(pretrained=True)``.

    Weights are transferred with ``copy_parameter_from_resnet``, so source
    entries without a counterpart in this model are ignored.

    Returns:
        The initialised ``ResNet`` instance.
    """
    stage_depths = [3, 4, 6, 3]
    backbone = ResNet(Bottleneck, stage_depths)
    pretrained_state = torchvision.models.resnet50(
        pretrained=True).state_dict()
    copy_parameter_from_resnet(backbone, pretrained_state)
    return backbone
|
|
|
|
|
def load_ResNet101Model():
    """Build a ``Bottleneck`` ResNet with stages [3, 4, 23, 3] and fill it
    with the weights of torchvision's ``resnet101(pretrained=True)``.

    Weights are transferred with ``copy_parameter_from_resnet``, so source
    entries without a counterpart in this model are ignored.

    Returns:
        The initialised ``ResNet`` instance.
    """
    stage_depths = [3, 4, 23, 3]
    backbone = ResNet(Bottleneck, stage_depths)
    pretrained_state = torchvision.models.resnet101(
        pretrained=True).state_dict()
    copy_parameter_from_resnet(backbone, pretrained_state)
    return backbone
|
|
|
|
|
def load_ResNet152Model():
    """Build a ``Bottleneck`` ResNet with stages [3, 8, 36, 3] and fill it
    with the weights of torchvision's ``resnet152(pretrained=True)``.

    Weights are transferred with ``copy_parameter_from_resnet``, so source
    entries without a counterpart in this model are ignored.

    Returns:
        The initialised ``ResNet`` instance.
    """
    stage_depths = [3, 8, 36, 3]
    backbone = ResNet(Bottleneck, stage_depths)
    pretrained_state = torchvision.models.resnet152(
        pretrained=True).state_dict()
    copy_parameter_from_resnet(backbone, pretrained_state)
    return backbone
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) units.

    Both convolutions use padding 1, so the spatial size is preserved.  The
    first maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        stages = []
        # The second unit consumes the first unit's output channels.
        for source_channels in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(source_channels,
                          out_channels,
                          kernel_size=3,
                          padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both convolution units."""
        return self.double_conv(x)
|
|
|
|
|
class Down(nn.Module):
    """2x2 max-pool followed by a ``DoubleConv``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and apply the double convolution."""
        return self.maxpool_conv(x)
|
|
|
|
|
class Up(nn.Module):
    """Upsample one feature map, merge it with a second one, then apply a
    ``DoubleConv``.

    Args:
        in_channels: Channel count of the concatenated tensor fed to the
            ``DoubleConv``.
        out_channels: Channels produced by the ``DoubleConv``.
        bilinear: When true, upsample with a x2 bilinear ``nn.Upsample``;
            otherwise use a stride-2 ``nn.ConvTranspose2d`` operating on
            ``in_channels // 2`` channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()

        if not bilinear:
            half_channels = in_channels // 2
            self.up = nn.ConvTranspose2d(half_channels,
                                         half_channels,
                                         kernel_size=2,
                                         stride=2)
        else:
            self.up = nn.Upsample(scale_factor=2,
                                  mode='bilinear',
                                  align_corners=True)

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to ``x2``'s size, concatenate and convolve.

        Args:
            x1: Tensor to upsample.
            x2: Tensor whose dim-2 / dim-3 sizes ``x1`` is padded to match.

        Returns:
            ``self.conv`` applied to ``x2`` and the padded ``x1``
            concatenated along dim 1, with ``x2`` first.
        """
        upsampled = self.up(x1)

        # Size gap between the two tensors along dim 2 and dim 3.
        gap_rows = x2.size(2) - upsampled.size(2)
        gap_cols = x2.size(3) - upsampled.size(3)

        # F.pad takes pads for the last dimension first; any odd remainder
        # goes to the trailing side.
        lead_cols = gap_cols // 2
        lead_rows = gap_rows // 2
        pad_spec = [
            lead_cols, gap_cols - lead_cols, lead_rows, gap_rows - lead_rows
        ]
        upsampled = F.pad(upsampled, pad_spec)

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
|
|
|
|
|
class OutConv(nn.Module):
    """Single 1x1 convolution mapping ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        projection = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=1)
        self.conv = projection

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected
|
|
|
|
|
class UNet(nn.Module):
    """U-shaped encoder/decoder that returns normalized 64-channel features.

    The encoder is a ``DoubleConv`` followed by four ``Down`` steps; the
    decoder is four ``Up`` steps, each merged with the encoder output of
    the matching level.

    Note that ``outc`` is constructed but ``forward`` never applies it:
    the output of the last ``Up`` step is passed through ``F.normalize``
    (called with its default arguments) and returned directly.

    Args:
        n_channels: Channels of the input tensor.
        n_classes: Output channels of the (unused in ``forward``) ``outc``.
        bilinear: Forwarded to every ``Up`` step.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels, self.n_classes = n_channels, n_classes
        self.bilinear = bilinear

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 512)

        # Decoder; each input width is the channel sum of the two tensors
        # that the Up step concatenates.
        self.up1 = Up(1024, 256, bilinear)
        self.up2 = Up(512, 128, bilinear)
        self.up3 = Up(256, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Registered as a submodule but not applied in forward().
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Encode ``x``, decode with skip connections and normalize."""
        # Encoder outputs, shallowest first.
        encoded = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3, self.down4):
            encoded.append(down(encoded[-1]))

        # Start from the deepest output and consume skips deepest-first.
        feat = encoded.pop()
        for up in (self.up1, self.up2, self.up3, self.up4):
            feat = up(feat, encoded.pop())

        return F.normalize(feat)
|
|