# File: LASA/models/modules/unet.py
# Commit cc9780d (HaolinLiu): "first commit of codes and update readme.md"
'''
The code is based on:
https://github.com/jaxony/unet-pytorch/blob/master/model.py
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from torch.nn import init
import numpy as np
def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Build a 2D convolution layer with a 3x3 kernel.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels produced by the layer.
        stride: stride of the convolution. Default is 1.
        padding: zero-padding added to each spatial border. Default is 1.
        bias: bool, whether the layer gets a learnable bias. Default is True.
        groups: int, number of channel groups. Default is 1.

    Returns:
        A freshly constructed `nn.Conv2d`.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def upconv2x2(in_channels, out_channels, mode='transpose'):
    """Build a layer that doubles the spatial resolution.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels of the upsampled output.
        mode: string. 'transpose' builds a 2x2, stride-2 transposed
            convolution; any other value builds bilinear x2 upsampling
            followed by a 1x1 convolution that maps in_channels to
            out_channels.

    Returns:
        An `nn.ConvTranspose2d`, or an `nn.Sequential` of
        `nn.Upsample` and a 1x1 `nn.Conv2d`.
    """
    if mode != 'transpose':
        # Non-parametric upsampling cannot change the channel count,
        # so a 1x1 convolution is appended to do that.
        upsample = nn.Upsample(mode='bilinear', scale_factor=2)
        return nn.Sequential(upsample, conv1x1(in_channels, out_channels))

    return nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size=2, stride=2)
def conv1x1(in_channels, out_channels, groups=1):
    """Build a 2D convolution layer with a 1x1 kernel and stride 1.

    Arguments:
        in_channels: int, number of channels of the input tensor.
        out_channels: int, number of channels produced by the layer.
        groups: int, number of channel groups. Default is 1.

    Returns:
        A freshly constructed `nn.Conv2d`.
    """
    pointwise = nn.Conv2d(
        in_channels, out_channels, kernel_size=1, groups=groups, stride=1)
    return pointwise
class RollOut_Conv(nn.Module):
    """
    A 3x3 convolution over three feature planes stacked along dim 2.

    The input is split along dim 2 into three equal chunks, named xz, xy
    and yz in that order. Each chunk is concatenated channel-wise with
    axis-averaged features of the other two chunks, the three augmented
    chunks are stacked back along dim 2, and one 3x3 convolution maps the
    resulting 3 * in_channels channels to out_channels.
    """

    def __init__(self, in_channels, out_channels):
        super(RollOut_Conv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Every chunk carries its own features plus two pooled maps,
        # hence three times the input channels.
        self.conv = conv3x3(self.in_channels * 3, self.out_channels)

    @staticmethod
    def _pool_rows(plane, rows):
        # Average over dim 2, then broadcast the single row back to `rows`.
        return plane.mean(dim=2, keepdim=True).expand(-1, -1, rows, -1)

    @staticmethod
    def _pool_cols(plane, cols):
        # Average over dim 3, then broadcast the single column to `cols`.
        return plane.mean(dim=3, keepdim=True).expand(-1, -1, -1, cols)

    def forward(self, row_features):
        """ Forward pass
        Arguments:
            row_features: 4D tensor holding the xz, xy and yz chunks
                stacked along dim 2. Dim 2 must split into exactly three
                chunks, and the channel-wise concatenations below only
                line up when each chunk is square (dim 2 // 3 == dim 3).
        Returns:
            Tensor with out_channels channels and the same dim 2 / dim 3
            sizes as the input.
        """
        total_rows = row_features.size(2)
        cols = row_features.size(3)
        rows = total_rows // 3
        xz_plane, xy_plane, yz_plane = row_features.split(rows, dim=2)

        # xz chunk: row-pooled xy and column-pooled yz.
        xz_aug = torch.cat(
            [xz_plane,
             self._pool_rows(xy_plane, rows),
             self._pool_cols(yz_plane, cols)],
            dim=1)

        # xy chunk: row-pooled xz and column-pooled transposed yz.
        xy_aug = torch.cat(
            [xy_plane,
             self._pool_rows(xz_plane, rows),
             self._pool_cols(yz_plane.transpose(2, 3), cols)],
            dim=1)

        # yz chunk: row-pooled transposed xy and column-pooled xz.
        yz_aug = torch.cat(
            [yz_plane,
             self._pool_rows(xy_plane.transpose(2, 3), rows),
             self._pool_cols(xz_plane, cols)],
            dim=1)

        # Stack the augmented chunks back along the row dimension.
        stacked = torch.cat([xz_aug, xy_aug, yz_aug], dim=2)
        return self.conv(stacked)
class DownConv(nn.Module):
    """
    An encoder block: conv3x3 -> RollOut_Conv -> conv3x3, with a ReLU
    after each of the three layers, optionally followed by a 2x2 MaxPool.
    """

    def __init__(self, in_channels, out_channels, pooling=True):
        """
        Arguments:
            in_channels: int, number of channels of the block input.
            out_channels: int, number of channels produced by every
                layer of the block.
            pooling: bool, whether a 2x2, stride-2 MaxPool is applied to
                the block output. Default is True.
        """
        super(DownConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.pooling = pooling

        self.conv1 = conv3x3(self.in_channels, self.out_channels)
        self.Rollout_conv = RollOut_Conv(self.out_channels,
                                         self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        # The pool layer only exists when pooling is requested.
        if self.pooling:
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """ Forward pass
        Arguments:
            x: input tensor of the block.
        Returns:
            A tuple (output, before_pool). `before_pool` is the
            activation after the last ReLU; `output` is its max-pooled
            version, or the very same tensor when pooling is disabled.
        """
        for layer in (self.conv1, self.Rollout_conv, self.conv2):
            x = F.relu(layer(x))

        before_pool = x
        pooled = self.pool(before_pool) if self.pooling else before_pool
        return pooled, before_pool
class UpConv(nn.Module):
    """
    A decoder block: one up-convolution, a merge with the encoder
    feature map, then conv3x3 -> RollOut_Conv -> conv3x3 with a ReLU
    after each of the three layers.
    """

    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', up_mode='transpose'):
        """
        Arguments:
            in_channels: int, number of channels coming from the decoder
                pathway.
            out_channels: int, number of channels after the
                up-convolution and of the block output.
            merge_mode: string. 'concat' concatenates the two feature
                maps along the channel dimension; any other value adds
                them element-wise.
            up_mode: string, forwarded to `upconv2x2` as its `mode`.
        """
        super(UpConv, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode
        self.up_mode = up_mode

        self.upconv = upconv2x2(self.in_channels, self.out_channels,
                                mode=self.up_mode)

        # Concatenation doubles the channels seen by conv1; addition
        # leaves the channel count unchanged.
        merged_channels = (2 * self.out_channels
                           if self.merge_mode == 'concat'
                           else self.out_channels)
        self.conv1 = conv3x3(merged_channels, self.out_channels)
        self.Rollout_conv = RollOut_Conv(self.out_channels,
                                         self.out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: tensor from the decoder pathway; it is upconv'd
                here before being merged
        """
        upsampled = self.upconv(from_up)
        if self.merge_mode != 'concat':
            merged = upsampled + from_down
        else:
            merged = torch.cat((upsampled, from_down), 1)

        out = F.relu(self.conv1(merged))
        out = F.relu(self.Rollout_conv(out))
        out = F.relu(self.conv2(out))
        return out
class UNet(nn.Module):
    """ `UNet` class is based on https://arxiv.org/abs/1505.04597

    The U-Net is a convolutional encoder-decoder neural network.
    Contextual spatial information (from the decoding,
    expansive pathway) about an input tensor is merged with
    information representing the localization of details
    (from the encoding, compressive pathway).

    Modifications to the original paper:
    (1) padding is used in 3x3 convolutions to prevent loss
        of border pixels
    (2) merging outputs does not require cropping due to (1)
    (3) residual connections can be used by specifying
        UNet(merge_mode='add')
    (4) if non-parametric upsampling is used in the decoder
        pathway (specified by up_mode='upsample'), then an
        additional 1x1 2d convolution occurs after upsampling
        to reduce channel dimensionality by a factor of 2.
        This channel halving happens with the convolution in
        the transpose convolution (specified by up_mode='transpose')
    """

    def __init__(self, num_classes, in_channels=3, depth=5,
                 start_filts=64, up_mode='transpose',
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            num_classes: int, number of output channels of the final
                1x1 convolution.
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of encoder blocks (at least 1). The
                last encoder block does not pool, so the network has
                depth - 1 MaxPools and depth - 1 decoder blocks.
            start_filts: int, number of convolutional filters for the
                first conv; doubled at every following encoder block.
            up_mode: string, type of upconvolution. Choices: 'transpose'
                for transpose convolution or 'upsample' for
                non-parametric upsampling followed by a 1x1 conv.
            merge_mode: string, how encoder and decoder features are
                merged. Choices: 'concat' or 'add'.
            **kwargs: accepted for call-site compatibility and ignored.

        Raises:
            ValueError: if `up_mode` or `merge_mode` is not one of the
                allowed values, if `up_mode='upsample'` is combined
                with `merge_mode='add'`, or if `depth` is below 1.
        """
        super(UNet, self).__init__()

        if up_mode in ('transpose', 'upsample'):
            self.up_mode = up_mode
        else:
            raise ValueError("\"{}\" is not a valid mode for "
                             "upsampling. Only \"transpose\" and "
                             "\"upsample\" are allowed.".format(up_mode))

        if merge_mode in ('concat', 'add'):
            self.merge_mode = merge_mode
        else:
            # Report the offending merge_mode (not up_mode).
            raise ValueError("\"{}\" is not a valid mode for "
                             "merging up and down paths. "
                             "Only \"concat\" and "
                             "\"add\" are allowed.".format(merge_mode))

        # NOTE: up_mode 'upsample' is incompatible with merge_mode 'add'
        if self.up_mode == 'upsample' and self.merge_mode == 'add':
            raise ValueError("up_mode \"upsample\" is incompatible "
                             "with merge_mode \"add\" at the moment "
                             "because it doesn't make sense to use "
                             "nearest neighbour to reduce "
                             "depth channels (by half).")

        # Without at least one encoder block the channel bookkeeping
        # below has nothing to start from.
        if depth < 1:
            raise ValueError(
                "depth must be at least 1, got {}".format(depth))

        self.num_classes = num_classes
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        down_convs = []
        up_convs = []

        # create the encoder pathway and add to a list
        outs = self.in_channels
        for i in range(depth):
            ins = outs
            outs = self.start_filts * (2 ** i)
            # the deepest block keeps its resolution
            pooling = i < depth - 1
            down_convs.append(DownConv(ins, outs, pooling=pooling))

        # create the decoder pathway and add to a list
        # - careful! decoding only requires depth-1 blocks
        for _ in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_convs.append(UpConv(ins, outs, up_mode=up_mode,
                                   merge_mode=merge_mode))

        # register the lists of modules with the current module
        self.down_convs = nn.ModuleList(down_convs)
        self.up_convs = nn.ModuleList(up_convs)

        self.conv_final = conv1x1(outs, self.num_classes)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        """Xavier-initialise the weight and zero the bias of a Conv2d.

        Modules that are not `nn.Conv2d` are left untouched, and a
        convolution created without a bias is handled gracefully.
        """
        if isinstance(m, nn.Conv2d):
            init.xavier_normal_(m.weight)
            if m.bias is not None:
                init.constant_(m.bias, 0)

    def reset_params(self):
        """Apply `weight_init` to every sub-module of the network."""
        for m in self.modules():
            self.weight_init(m)

    def forward(self, feature_plane):
        """ Forward pass
        Arguments:
            feature_plane: input tensor with `in_channels` channels;
                expected to hold the feature planes concatenated along
                the row dimension (dim 2).
        Returns:
            Tensor with `num_classes` channels.
        """
        x = feature_plane
        encoder_outs = []

        # encoder pathway, save outputs for merging
        for module in self.down_convs:
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # decoder pathway: the deepest encoder output is already `x`,
        # so skip connections start at the second-to-last one.
        for module, before_pool in zip(self.up_convs,
                                       reversed(encoder_outs[:-1])):
            x = module(before_pool, x)

        # No softmax is used. This means you need to use
        # nn.CrossEntropyLoss in your training script,
        # as that loss includes a softmax already.
        x = self.conv_final(x)
        return x
if __name__ == "__main__":
    # Smoke test: push one random batch through the network.
    # Use the GPU when one is available, otherwise fall back to the CPU
    # so the script also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model = UNet(32, depth=5, merge_mode='concat', in_channels=32,
                 start_filts=32).to(device).float()

    # Three 128x128 feature planes stacked along the row dimension.
    row_feature = torch.randn((10, 32, 128 * 3, 128)).to(device).float()
    output = model(row_feature)