Spaces:
Runtime error
Runtime error
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import importlib | |
def class_for_name(module_name, class_name):
    """Import ``module_name`` and return its attribute named ``class_name``.

    Raises:
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute ``class_name``.
    """
    module = importlib.import_module(module_name)
    resolved = getattr(module, class_name)
    return resolved
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with reflect padding.

    The padding equals ``dilation``, so at ``stride=1`` the spatial size of the
    input is preserved.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride.
        groups: number of channel groups.
        dilation: kernel dilation (also used as the padding amount).
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
        padding_mode="reflect",
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 ``nn.Conv2d`` (pointwise channel projection).

    The padding is left at 0, so ``padding_mode="reflect"`` adds no border here.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: convolution stride (used to subsample spatially).
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
        padding_mode="reflect",
    )
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Path: conv3x3 -> norm -> ReLU -> conv3x3 -> norm. The result is added to
    the input (passed through ``downsample`` when one is given), and a final
    ReLU is applied. Norm layers are built with ``track_running_stats=False``;
    for ``nn.BatchNorm2d`` this means batch statistics are used in both train
    and eval mode.

    Args:
        inplanes: input channels.
        planes: output channels (``expansion`` is 1).
        stride: stride of the first 3x3 convolution.
        downsample: optional module applied to the input for the shortcut.
        groups, base_width: only 1 and 64 are accepted.
        dilation: only 1 is accepted.
        norm_layer: normalization class; defaults to ``nn.BatchNorm2d``.

    Raises:
        ValueError: if ``groups != 1`` or ``base_width != 64``.
        NotImplementedError: if ``dilation > 1``.
    """

    expansion = 1

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        if not (groups == 1 and base_width == 64):
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        def make_norm(channels):
            return norm_layer(channels, track_running_stats=False, affine=True)

        # With stride != 1, conv1 and self.downsample both reduce the spatial size.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = make_norm(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = make_norm(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual bottleneck block with channel expansion 4.

    The spatial stride is applied at the 3x3 convolution (``conv2``), as in
    torchvision. The original "Deep residual learning for image recognition"
    paper (https://arxiv.org/abs/1512.03385) puts it on the first 1x1
    convolution instead. This variant is also known as ResNet V1.5 and improves
    accuracy according to
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    Args:
        inplanes: input channels.
        planes: base width; the output has ``planes * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional module applied to the input for the shortcut.
        groups: groups of the 3x3 convolution.
        base_width: scales the inner width as
            ``int(planes * base_width / 64) * groups``.
        dilation: dilation of the 3x3 convolution.
        norm_layer: normalization class; defaults to ``nn.BatchNorm2d``. It is
            built with ``track_running_stats=False, affine=True``.
    """

    expansion = 4

    def __init__(
        self,
        inplanes,
        planes,
        stride=1,
        downsample=None,
        groups=1,
        base_width=64,
        dilation=1,
        norm_layer=None,
    ):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.0)) * groups
        out_planes = planes * self.expansion
        # With stride != 1, conv2 and self.downsample both reduce the spatial size.
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width, track_running_stats=False, affine=True)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width, track_running_stats=False, affine=True)
        self.conv3 = conv1x1(width, out_planes)
        self.bn3 = norm_layer(out_planes, track_running_stats=False, affine=True)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        out = x
        # conv1/bn1 and conv2/bn2 are each followed by ReLU.
        for conv_layer, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = self.relu(norm(conv_layer(out)))
        out = self.bn3(self.conv3(out))
        residual = x if self.downsample is None else self.downsample(x)
        out += residual
        return self.relu(out)
class conv(nn.Module):
    """Conv2d (reflect padding) -> BatchNorm2d -> ELU.

    The padding is ``(kernel_size - 1) // 2``, which preserves the spatial size
    at stride 1 for odd kernel sizes. BatchNorm is created with
    ``track_running_stats=False``, so batch statistics are always used.

    Args:
        num_in_layers: input channels.
        num_out_layers: output channels.
        kernel_size: square kernel size.
        stride: convolution stride.
    """

    def __init__(self, num_in_layers, num_out_layers, kernel_size, stride):
        super(conv, self).__init__()
        self.kernel_size = kernel_size
        pad = (kernel_size - 1) // 2
        self.conv = nn.Conv2d(
            num_in_layers,
            num_out_layers,
            kernel_size=kernel_size,
            stride=stride,
            padding=pad,
            padding_mode="reflect",
        )
        self.bn = nn.BatchNorm2d(
            num_out_layers, track_running_stats=False, affine=True
        )

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return F.elu(y, inplace=True)
class upconv(nn.Module):
    """Bilinear upsampling by ``scale`` followed by a stride-1 ``conv`` block.

    Args:
        num_in_layers: input channels.
        num_out_layers: output channels.
        kernel_size: kernel size of the ``conv`` block.
        scale: spatial upsampling factor passed to ``interpolate``.
    """

    def __init__(self, num_in_layers, num_out_layers, kernel_size, scale):
        super(upconv, self).__init__()
        self.scale = scale
        self.conv = conv(num_in_layers, num_out_layers, kernel_size, 1)

    def forward(self, x):
        upsampled = F.interpolate(
            x, scale_factor=self.scale, mode="bilinear", align_corners=True
        )
        return self.conv(upsampled)
class ResUNet(nn.Module):
    """ResNet-style encoder with a two-level U-Net decoder producing a feature map.

    Encoder: a 7x7 stride-2 stem (no max-pool) followed by ``layer1``..``layer3``,
    each of stride 2, so the deepest features are at 1/16 of the input size.
    Decoder: two bilinear x2 upsampling steps. Each is concatenated with the
    matching encoder features (``layer2`` output, then ``layer1`` output) and
    fused by a ``conv`` block, then a 1x1 convolution is applied. The output is
    therefore at roughly 1/4 of the input resolution.

    Args:
        encoder: one of "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152". Selects the residual block type (``BasicBlock`` or
            ``Bottleneck``) and the stage depths; only three stages are built.
        coarse_out_ch: number of output channels for the coarse features.
        fine_out_ch: number of output channels for the fine features; forced to
            0 when ``coarse_only`` is True.
        norm_layer: normalization class, called as
            ``norm_layer(C, track_running_stats=False, affine=True)``.
            Defaults to ``nn.BatchNorm2d``.
        coarse_only: if True, only the coarse channels are produced.

    Raises:
        AssertionError: if ``encoder`` is not one of the supported names.
    """

    def __init__(
        self,
        encoder="resnet34",
        coarse_out_ch=32,
        fine_out_ch=32,
        norm_layer=None,
        coarse_only=False,
    ):
        super(ResUNet, self).__init__()
        # encoder name -> (residual block class, blocks per stage). Only the
        # first three depths are used because layer4 is never built.
        # NOTE(review): "resnet18" uses the ResNet-34 depths [3, 4, 6, 3] rather
        # than [2, 2, 2, 2]; kept so weights trained with this class still load.
        encoder_specs = {
            "resnet18": (BasicBlock, [3, 4, 6, 3]),
            "resnet34": (BasicBlock, [3, 4, 6, 3]),
            "resnet50": (Bottleneck, [3, 4, 6, 3]),
            "resnet101": (Bottleneck, [3, 4, 23, 3]),
            "resnet152": (Bottleneck, [3, 8, 36, 3]),
        }
        assert encoder in encoder_specs, "Incorrect encoder type"
        block, layers = encoder_specs[encoder]
        # Output channels of layer1..layer4. Deriving them from block.expansion
        # keeps the decoder input widths in line with the encoder actually
        # built (BasicBlock: 64..512, Bottleneck: 256..2048).
        filters = [width * block.expansion for width in (64, 128, 256, 512)]

        self.coarse_only = coarse_only
        if self.coarse_only:
            fine_out_ch = 0
        self.coarse_out_ch = coarse_out_ch
        self.fine_out_ch = fine_out_ch
        out_ch = coarse_out_ch + fine_out_ch

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer
        self.dilation = 1
        replace_stride_with_dilation = [False, False, False]
        self.inplanes = 64
        self.groups = 1
        self.base_width = 64

        # Stem: 7x7 stride-2 convolution; there is no max-pool.
        self.conv1 = nn.Conv2d(
            3,
            self.inplanes,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            padding_mode="reflect",
        )
        self.bn1 = norm_layer(self.inplanes, track_running_stats=False, affine=True)
        self.relu = nn.ReLU(inplace=True)
        # layer1 is strided as well (stride=2).
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(
            block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
        )
        self.layer3 = self._make_layer(
            block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
        )

        # decoder
        self.upconv3 = upconv(filters[2], 128, 3, 2)
        self.iconv3 = conv(filters[1] + 128, 128, 3, 1)
        self.upconv2 = upconv(128, 64, 3, 2)
        self.iconv2 = conv(filters[0] + 64, out_ch, 3, 1)

        # fine-level conv
        self.out_conv = nn.Conv2d(out_ch, out_ch, 1, 1)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        """Build one residual stage made of ``blocks`` blocks of type ``block``.

        The first block receives ``stride``. When the stride or the channel count
        changes, it also receives a 1x1-conv + norm ``downsample`` shortcut.
        ``self.inplanes`` is updated to the stage's output width. If ``dilate``
        is True, the stride is turned into dilation instead.
        """
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(
                    planes * block.expansion, track_running_stats=False, affine=True
                ),
            )
        layers = []
        layers.append(
            block(
                self.inplanes,
                planes,
                stride,
                downsample,
                self.groups,
                self.base_width,
                previous_dilation,
                norm_layer,
            )
        )
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(
                    self.inplanes,
                    planes,
                    groups=self.groups,
                    base_width=self.base_width,
                    dilation=self.dilation,
                    norm_layer=norm_layer,
                )
            )
        return nn.Sequential(*layers)

    def skipconnect(self, x1, x2):
        """Resize ``x1`` spatially to ``x2`` and concatenate the two along channels.

        ``x1`` is zero-padded symmetrically in H and W to match ``x2`` (a
        negative size difference crops instead). The result is
        ``torch.cat([x2, x1], dim=1)``.
        """
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2))
        # for padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return x

    def forward(self, x):
        """Encode ``x`` and decode it into a ``coarse_out_ch + fine_out_ch`` map.

        Callers that need separate coarse and fine features can split the
        channels. Use ``x_out[:, :self.coarse_out_ch]`` for the coarse part and
        ``x_out[:, -self.fine_out_ch:]`` for the fine part. When ``coarse_only``
        is set (``fine_out_ch == 0``), treat ``x_out`` as the coarse part with no
        fine part, because ``x_out[:, -0:]`` would select every channel.
        """
        x = self.relu(self.bn1(self.conv1(x)))
        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x = self.upconv3(x3)
        x = self.skipconnect(x2, x)
        x = self.iconv3(x)
        x = self.upconv2(x)
        x = self.skipconnect(x1, x)
        x = self.iconv2(x)
        x_out = self.out_conv(x)
        return x_out