|
''' |
|
Code adapted from:

https://github.com/jaxony/unet-pytorch/blob/master/model.py
(modified so that every encoder and decoder stage also applies a RollOut_Conv layer).
|
''' |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.autograd import Variable |
|
from collections import OrderedDict |
|
from torch.nn import init |
|
import numpy as np |
|
|
|
|
|
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 3x3 ``nn.Conv2d``.

    With the defaults (``stride=1``, ``padding=1``) the spatial size of the
    input is preserved.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: convolution stride.
        padding: zero-padding added on each spatial border.
        bias: whether the convolution has a learnable bias.
        groups: number of blocked connections from input to output channels.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=padding,
                       bias=bias, groups=groups)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
|
|
|
|
|
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a layer that doubles the spatial size of its input.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        mode: ``'transpose'`` selects a learned 2x2 transposed convolution
            with stride 2. Any other value selects bilinear upsampling by a
            factor of 2, followed by a 1x1 convolution that maps
            ``in_channels`` to ``out_channels``.

    Returns:
        An ``nn.Module`` performing the upsampling.
    """
    if mode != 'transpose':
        # Fixed (non-learned) resize, then a 1x1 conv to change channel count.
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels),
        )
    return nn.ConvTranspose2d(in_channels, out_channels,
                              kernel_size=2, stride=2)
|
|
|
|
|
def conv1x1(in_channels, out_channels, groups=1):
    """Build a 1x1 ``nn.Conv2d`` (stride 1, no padding, with bias).

    Mixes channels independently at each pixel, so the spatial size is
    unchanged.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        groups: number of blocked connections from input to output channels.

    Returns:
        The configured ``nn.Conv2d`` module.
    """
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=1, groups=groups)
|
|
|
class RollOut_Conv(nn.Module):
    """3x3 convolution over three stacked feature planes, each enriched with
    pooled context taken from the other two planes.

    The input holds three planes stacked along the height axis (dim 2), named
    xz, xy and yz in that order by ``forward``. For every plane, two
    axis-averaged views of the other two planes are broadcast back to the
    plane's size and concatenated on the channel axis. The three enriched
    planes are then re-stacked along the height and passed through one
    ``conv3x3`` that takes ``3 * in_channels`` channels.

    Args:
        in_channels: channels of each input plane.
        out_channels: channels produced by the convolution.
    """

    def __init__(self,in_channels,out_channels):
        super(RollOut_Conv,self).__init__()

        self.in_channels=in_channels
        self.out_channels=out_channels
        # Each plane is concatenated with two pooled views of other planes,
        # so the convolution sees three times the input channels.
        self.conv = conv3x3(self.in_channels*3, self.out_channels)

    def forward(self,row_features):
        """Apply the roll-out convolution.

        Args:
            row_features: 4-D tensor indexed as (N, C, H, W) with the three
                planes stacked along dim 2.
                NOTE(review): as written, H must be divisible by 3; otherwise
                ``torch.split`` yields more than three chunks and the
                unpacking fails. Each plane must also be square
                (H // 3 == W); otherwise the concatenations below fail on
                mismatched sizes.

        Returns:
            Output of ``self.conv`` with ``out_channels`` channels. Under the
            constraints above the spatial size is (H, W), since ``conv3x3``
            uses its default stride 1 / padding 1.
        """
        H,W=row_features.shape[2],row_features.shape[3]
        H_per=H//3
        # Un-stack the three planes along the height axis.
        xz_feature,xy_feature,yz_feature=torch.split(row_features,dim=2,split_size_or_sections=H_per)

        # xz plane. Channel order: own features, then the xy view, then the yz view.
        # xy averaged over its rows: column c of xz receives xy's column c.
        xy_row_pool=torch.mean(xy_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        # yz averaged over its columns: row r of xz receives yz's row r.
        yz_col_pool=torch.mean(yz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        cat_xz_feat=torch.cat([xz_feature,xy_row_pool,yz_col_pool],dim=1)

        # xy plane. Channel order: own features, then the xz view, then the yz view.
        # xz averaged over its rows: column c of xy receives xz's column c.
        xz_row_pool=torch.mean(xz_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        # yz is transposed, so averaging over dim 3 averages yz over its
        # original rows. Row r of xy then receives yz's column r.
        zy_feature=yz_feature.transpose(2,3)
        zy_col_pool=torch.mean(zy_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        cat_xy_feat=torch.cat([xy_feature,xz_row_pool,zy_col_pool],dim=1)

        # yz plane. Channel order: own features, then the xy view, then the xz view.
        # xz averaged over its columns: row r of yz receives xz's row r.
        xz_col_pool=torch.mean(xz_feature,dim=3,keepdim=True).expand(-1,-1,-1,W)
        # xy is transposed, so averaging over dim 2 averages xy over its
        # original columns. Column c of yz then receives xy's row c.
        yx_feature=xy_feature.transpose(2,3)
        yx_row_pool=torch.mean(yx_feature,dim=2,keepdim=True).expand(-1,-1,H_per,-1)
        cat_yz_feat=torch.cat([yz_feature,yx_row_pool,xz_col_pool],dim=1)

        # Re-stack the enriched planes along the height in the input's order.
        fuse_row_feat=torch.cat([cat_xz_feat,cat_xy_feat,cat_yz_feat],dim=2)

        x = self.conv(fuse_row_feat)

        return x
|
|
|
|
|
class DownConv(nn.Module):
    """Encoder stage: conv3x3 -> RollOut_Conv -> conv3x3, then optional max-pool.

    A ReLU follows each of the three convolutions. When ``pooling`` is true,
    a 2x2 max-pool with stride 2 is applied to the activated result.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by every convolution of the stage.
        pooling: whether to downsample after the convolutions.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        super(DownConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        # Attribute names and creation order determine state_dict keys and
        # the order in which parameters are created, so keep them stable.
        self.conv1 = conv3x3(in_channels, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        if pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Run the stage.

        Returns:
            A ``(output, before_pool)`` tuple. ``before_pool`` is the
            activation after the last convolution (used by ``UNet`` as the
            skip connection). ``output`` is that tensor max-pooled when
            ``self.pooling`` is set, or the same tensor otherwise.
        """
        for layer in (self.conv1, self.Rollout_conv, self.conv2):
            x = F.relu(layer(x))
        skip = x
        if self.pooling:
            x = self.pool(x)
        return x, skip
|
|
|
|
|
class UpConv(nn.Module):
    """Decoder stage: upsample, merge with a skip connection, then
    conv3x3 -> RollOut_Conv -> conv3x3, each followed by a ReLU.

    Args:
        in_channels: channels of the decoder tensor before upsampling.
        out_channels: channels after upsampling and of every convolution.
        merge_mode: ``'concat'`` stacks the upsampled tensor and the skip
            tensor on the channel axis (so ``conv1`` takes
            ``2 * out_channels``). Any other value adds them element-wise.
        up_mode: passed to ``upconv2x2`` as its ``mode``.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        super(UpConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        # Creation order (upconv, conv1, Rollout_conv, conv2) is kept stable
        # because it fixes state_dict keys and parameter creation order.
        self.upconv = upconv2x2(in_channels, out_channels, mode=up_mode)
        conv1_in = 2 * out_channels if merge_mode == 'concat' else out_channels
        self.conv1 = conv3x3(conv1_in, out_channels)
        self.Rollout_conv = RollOut_Conv(out_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)

    def forward(self, from_down, from_up):
        """Upsample ``from_up``, merge it with ``from_down``, and convolve.

        Arguments:
            from_down: tensor from the encoder pathway (skip connection).
            from_up: tensor from the decoder pathway, upsampled here.

        Returns:
            The activated output of ``conv2``.
        """
        upsampled = self.upconv(from_up)
        if self.merge_mode == 'concat':
            merged = torch.cat((upsampled, from_down), 1)
        else:
            merged = upsampled + from_down
        out = merged
        for layer in (self.conv1, self.Rollout_conv, self.conv2):
            out = F.relu(layer(out))
        return out
|
|
|
|
|
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after bilinear
        upsampling to reduce channel dimensionality by a factor of 2.
        With up_mode='transpose', this channel halving is done by the
        transpose convolution itself.
    (5) every encoder and decoder stage applies a RollOut_Conv between
        its two 3x3 convolutions.

    NOTE(review): RollOut_Conv runs at every resolution. It splits the height
    into three equal planes and requires those planes to be square. The
    skip-connection merges also need matching sizes. So the input height
    should be 3 * width, and the width divisible by 2 ** (depth - 1)
    (e.g. height 3 * 128, width 128 with depth=5).
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of output channels of the final 1x1
                convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of encoder stages. depth - 1 MaxPools are
                applied. Must be at least 1.
            start_filts: int, number of convolutional filters for the
                first conv; doubled at each deeper encoder stage.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for bilinear
                upsampling followed by a 1x1 convolution.
            merge_mode: string, how skip connections are merged. Choices:
                'concat' or 'add'.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if up_mode or merge_mode is unknown, if up_mode is
                'upsample' while merge_mode is 'add', or if depth < 1.
        """
        super(UNet, self).__init__()

        if up_mode not in ('transpose', 'upsample'):
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))
        self.up_mode = up_mode

        if merge_mode not in ('concat', 'add'):
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))
        self.merge_mode = merge_mode

        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        # Without at least one encoder stage there is no channel count to
        # feed the final 1x1 convolution.
        if depth < 1:
            raise ValueError("depth must be at least 1, got {}".format(depth))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        # Encoder: channels grow start_filts * 2**i; every stage except the
        # deepest one max-pools its output.
        down_convs = []
        outs = in_channels
        for i in range(depth):
            ins = outs
            outs = start_filts * (2 ** i)
            down_convs.append(DownConv(ins, outs, pooling=i < depth - 1))

        # Decoder: each stage halves the channel count.
        up_convs = []
        for _ in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                   merge_mode=merge_mode))

        # ModuleList registers the stages so their parameters are tracked.
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)
        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-normal init for ``nn.Conv2d`` weights; zero their biases.

        Other module types (including ``nn.ConvTranspose2d``, which is not a
        subclass of ``nn.Conv2d``) are left untouched. Convolutions built
        with ``bias=False`` are supported.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply ``weight_init`` to this module and every submodule."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, feature_plane):
        """Encode, decode with skip connections, and project to num_classes.

        Args:
            feature_plane: input tensor with ``in_channels`` channels on dim 1.

        Returns:
            Output of the final 1x1 convolution (``num_classes`` channels).
        """
        x = feature_plane
        encoder_outs = []
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # The deepest encoder output is already in x; pair the decoder stages
        # with the remaining encoder activations from deep to shallow.
        for module, before_pool in zip(self.up_convs,
                                       reversed(encoder_outs[:-1])):
            x = module(before_pool, x)

        return self.conv_final(x)
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: one forward pass on random input. Use the GPU when one is
    # available instead of failing on `.cuda()` on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = UNet(32, depth=5, merge_mode='concat', in_channels=32,
                 start_filts=32).to(device).float()
    # Height 3 * 128 stacks three 128 x 128 planes along dim 2, the layout
    # RollOut_Conv splits apart at every level.
    row_feature = torch.randn((10, 32, 128 * 3, 128)).to(device).float()

    # No gradients are needed for a shape check; skip building the graph.
    with torch.no_grad():
        output = model(row_feature)
    print(tuple(output.shape))
|
|
|
|