''' |
Code is from:
https://github.com/jaxony/unet-pytorch/blob/master/model.py |
''' |
import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from torch.autograd import Variable |
from collections import OrderedDict |
from torch.nn import init |
import numpy as np |
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 2-D convolution layer with a 3x3 kernel.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels produced by the layer.
        stride: stride of the convolution (default 1).
        padding: zero-padding added on each side (default 1).
        bias: bool, whether the layer has a learnable bias.
        groups: int, number of channel groups passed to ``nn.Conv2d``.

    Returns:
        A freshly constructed ``nn.Conv2d``.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_options)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a module that doubles the spatial resolution of a feature map.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels of the upsampled output.
        mode: string. 'transpose' returns a 2x2, stride-2
            ``nn.ConvTranspose2d``. Any other value returns an
            ``nn.Sequential`` of a bilinear ``nn.Upsample`` (scale 2)
            followed by a 1x1 convolution mapping ``in_channels`` to
            ``out_channels``.

    Returns:
        The constructed ``nn.Module``.
    """
    if mode != 'transpose':
        # Non-parametric upsampling keeps the channel count, so a 1x1
        # convolution performs the channel change afterwards.
        upsample = nn.Upsample(mode='bilinear', scale_factor=2)
        return nn.Sequential(upsample, conv1x1(in_channels, out_channels))
    return nn.ConvTranspose2d(in_channels, out_channels,
                              kernel_size=2, stride=2)
def conv1x1(in_channels, out_channels, groups=1):
    """Build a 2-D convolution layer with a 1x1 kernel and stride 1.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels produced by the layer.
        groups: int, number of channel groups passed to ``nn.Conv2d``.

    Returns:
        A freshly constructed ``nn.Conv2d``.
    """
    pointwise = nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=1, groups=groups)
    return pointwise
class RollOut_Conv(nn.Module):
    """3x3 convolution over a stacked three-plane map with pooled context.

    The input stacks three equally sized planes along dim 2; the code
    names them xz, xy and yz, in that order. Before the convolution each
    plane is concatenated along the channel axis with two pooled maps
    (a mean over dim 2 or dim 3, broadcast back to plane size) taken
    from the other two planes, so the convolution sees
    ``3 * in_channels`` channels. The three augmented planes are stacked
    along dim 2 again, so the output keeps the input's spatial layout.

    Arguments:
        in_channels: int, number of channels of each input plane.
        out_channels: int, number of channels produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super(RollOut_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Every plane is concatenated with two pooled maps, hence 3x.
        self.conv = conv3x3(self.in_channels * 3, self.out_channels)

    def forward(self, row_features):
        """Fuse the three stacked planes and convolve them.

        Arguments:
            row_features: 4-D tensor whose dim 2 holds the three planes
                stacked one after another. Its size along dim 2 must be
                a positive multiple of 3, and each plane must be square
                (``shape[2] // 3 == shape[3]``).

        Returns:
            Tensor with ``out_channels`` channels and the same dims 0, 2
            and 3 as ``row_features``.

        Raises:
            ValueError: if the input is not 4-D, its height is not a
                positive multiple of 3, or the planes are not square.
        """
        if row_features.dim() != 4:
            raise ValueError(
                "RollOut_Conv expects a 4-D tensor, got {} dimension(s)."
                .format(row_features.dim()))
        H, W = row_features.shape[2], row_features.shape[3]
        if H == 0 or H % 3 != 0:
            raise ValueError(
                "RollOut_Conv expects three planes stacked along dim 2, "
                "but height {} is not a positive multiple of 3.".format(H))
        H_per = H // 3
        if H_per != W:
            # The transposed pooled maps below are concatenated with
            # untransposed planes, which only lines up for square planes.
            raise ValueError(
                "RollOut_Conv expects square planes, but each plane is "
                "{}x{} (height x width).".format(H_per, W))

        xz_feature, xy_feature, yz_feature = torch.split(
            row_features, dim=2, split_size_or_sections=H_per)

        # xz plane + mean of xy over dim 2 + mean of yz over dim 3.
        xy_row_pool = torch.mean(
            xy_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        yz_col_pool = torch.mean(
            yz_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        cat_xz_feat = torch.cat(
            [xz_feature, xy_row_pool, yz_col_pool], dim=1)

        # xy plane + mean of xz over dim 2 + mean of transposed yz over
        # dim 3.
        xz_row_pool = torch.mean(
            xz_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        zy_feature = yz_feature.transpose(2, 3)
        zy_col_pool = torch.mean(
            zy_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        cat_xy_feat = torch.cat(
            [xy_feature, xz_row_pool, zy_col_pool], dim=1)

        # yz plane + mean of transposed xy over dim 2 + mean of xz over
        # dim 3.
        xz_col_pool = torch.mean(
            xz_feature, dim=3, keepdim=True).expand(-1, -1, -1, W)
        yx_feature = xy_feature.transpose(2, 3)
        yx_row_pool = torch.mean(
            yx_feature, dim=2, keepdim=True).expand(-1, -1, H_per, -1)
        cat_yz_feat = torch.cat(
            [yz_feature, yx_row_pool, xz_col_pool], dim=1)

        # Restack the augmented planes in the original xz, xy, yz order.
        fuse_row_feat = torch.cat(
            [cat_xz_feat, cat_xy_feat, cat_yz_feat], dim=2)
        return self.conv(fuse_row_feat)
class DownConv(nn.Module):
    """
    Encoder stage: a 3x3 convolution, a RollOut_Conv and a second 3x3
    convolution, each followed by a ReLU, then an optional 2x2 MaxPool.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        """
        Arguments:
            in_channels: int, number of channels of the input tensor.
            out_channels: int, number of channels produced by every
                layer of this stage.
            pooling: bool, whether a 2x2, stride-2 MaxPool is created
                and applied at the end of the stage.
        """
        super(DownConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(in_channels, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the stage.

        Returns:
            A tuple ``(x, before_pool)``: the stage output (pooled when
            pooling is enabled) and the activation just before pooling.
        """
        features = x
        for layer in (self.conv1, self.Rollout_conv, self.conv2):
            features = F.relu(layer(features))
        if not self.pooling:
            return features, features
        return self.pool(features), features
class UpConv(nn.Module):
    """
    Decoder stage: an up-convolution, a merge with the encoder tensor,
    then a 3x3 convolution, a RollOut_Conv and a second 3x3 convolution,
    each followed by a ReLU.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        """
        Arguments:
            in_channels: int, channels of the tensor coming from the
                decoder pathway.
            out_channels: int, channels produced by the stage.
            merge_mode: string, 'concat' concatenates the two tensors
                along the channel axis; any other value adds them.
            up_mode: string, forwarded to ``upconv2x2`` as ``mode``.
        """
        super(UpConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(in_channels, out_channels, mode=up_mode)
        # Concatenation doubles the channels seen by the first convolution.
        merged_channels = (2 * out_channels if merge_mode == 'concat'
                           else out_channels)
        self.conv1 = conv3x3(merged_channels, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway; it is passed
                through the up-convolution before merging
        """
        upsampled = self.upconv(from_up)
        if self.merge_mode != 'concat':
            merged = upsampled + from_down
        else:
            merged = torch.cat((upsampled, from_down), 1)
        for layer in (self.conv1, self.Rollout_conv, self.conv2):
            merged = F.relu(layer(merged))
        return merged
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597
    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).
    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by up_mode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of channels produced by the final
                1x1 convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of encoder stages (at least 1). The last
                stage does not pool, so the network contains
                ``depth - 1`` MaxPools and ``depth - 1`` decoder stages.
            start_filts: int, number of convolutional filters for the
                first conv; each later encoder stage doubles it.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for the
                non-parametric upsampling built by ``upconv2x2``
                (bilinear ``nn.Upsample`` followed by a 1x1 convolution).
            merge_mode: string, how encoder and decoder tensors are
                merged. Choices: 'concat' or 'add'.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``up_mode`` or ``merge_mode`` is not one of
                the listed choices, if up_mode='upsample' is combined
                with merge_mode='add', or if ``depth`` is less than 1.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # Report the offending merge_mode (not up_mode) to the caller.
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        if depth < 1:
            # Without any encoder stage there is no channel count to feed
            # the final convolution.
            raise ValueError("depth must be at least 1, "
                             "got {}.".format(depth))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        # Encoder pathway: channels double at every stage and every
        # stage but the last one pools.
        down_convs = []
        outs = self.in_channels
        for i in range(depth):
            ins = outs
            outs = self.start_filts * (2 ** i)
            pooling = i < depth - 1
            down_convs.append(DownConv(ins, outs, pooling=pooling))

        # Decoder pathway: one stage per pooled encoder stage, halving
        # the channels each time.
        up_convs = []
        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                   merge_mode=merge_mode))

        # ModuleList registers the stages as sub-modules.
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)
        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for ``nn.Conv2d`` weights, zero for biases.

        Modules that are not ``nn.Conv2d`` instances are left untouched.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            # A convolution built with bias=False has no bias tensor.
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply ``weight_init`` to every module of the network."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, feature_plane):
        """Run the encoder, then the decoder with skip connections.

        Arguments:
            feature_plane: input tensor with ``in_channels`` channels.

        Returns:
            Output of the final 1x1 convolution (``num_classes``
            channels); no activation is applied to it.
        """
        x = feature_plane
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)
        for i, module in enumerate(self.up_convs):
            # Skip the deepest encoder output (it is x itself) and walk
            # the remaining ones from deepest to shallowest.
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)
        return self.conv_final(x)
if __name__ == "__main__":
    # Smoke test: one forward pass on three 128x128 planes stacked
    # along the height axis. Use the GPU when present, otherwise fall
    # back to the CPU instead of crashing on .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(32, depth=5, merge_mode='concat', in_channels=32,
                 start_filts=32).to(device).float()
    row_feature = torch.randn((10, 32, 128 * 3, 128)).to(device).float()
    # Only the forward pass is exercised, so skip autograd bookkeeping.
    with torch.no_grad():
        output = model(row_feature)