# LITnet / network.py
# Repository page residue: JungminChung, "first commit" (3617b5f), 12.9 kB
import torch
import torch.nn as nn
# Conv Layer
class ConvLayer(nn.Module):
    """2-D convolution preceded by reflection padding.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then convolved without any further padding, so an odd kernel with
    ``stride=1`` keeps the spatial size.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel size (also determines the padding).
        stride: convolution stride.
        groups: passed to ``nn.Conv2d``; ``groups == in_channels`` gives a
            depthwise convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super().__init__()
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
        )

    def forward(self, x):
        """Pad ``x`` by reflection, then convolve it."""
        return self.conv2d(self.reflection_pad(x))
class ConvLayer_dpws(nn.Module):
    """Depthwise-separable convolution block.

    Pipeline: depthwise ``ConvLayer`` (``groups == in_channels``) ->
    InstanceNorm -> pointwise 1x1 ``ConvLayer`` -> InstanceNorm -> ReLU.
    There is no activation between the depthwise and pointwise stages.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of the pointwise stage.
        kernel_size: kernel size of the depthwise stage.
        stride: stride of the depthwise stage (the pointwise stage uses 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Depthwise stage: one spatial filter per input channel.
        self.conv1 = ConvLayer(
            in_channels, in_channels, kernel_size, stride=stride, groups=in_channels
        )
        self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
        # Pointwise stage: 1x1 convolution that mixes channels.
        self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
        self.in_2d = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the depthwise stage, then the pointwise stage with ReLU."""
        depthwise = self.in_1d(self.conv1(x))
        pointwise = self.in_2d(self.conv2(depthwise))
        return self.relu(pointwise)
class ConvLayer_dpws_last(nn.Module):
    """Depthwise-separable convolution block with a raw (linear) output.

    Pipeline: depthwise ``ConvLayer`` (``groups == in_channels``) ->
    InstanceNorm -> pointwise 1x1 ``ConvLayer``. Unlike ``ConvLayer_dpws``
    the pointwise output is neither normalised nor activated.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of the pointwise stage.
        kernel_size: kernel size of the depthwise stage.
        stride: stride of the depthwise stage (the pointwise stage uses 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Depthwise stage: one spatial filter per input channel.
        self.conv1 = ConvLayer(
            in_channels, in_channels, kernel_size, stride=stride, groups=in_channels
        )
        self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
        # Pointwise stage: 1x1 convolution, deliberately left without norm/ReLU.
        self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        """Apply the normalised depthwise stage, then the bare pointwise conv."""
        depthwise = self.in_1d(self.conv1(x))
        return self.conv2(depthwise)
# Upsample Conv Layer
class UpsampleConvLayer(nn.Module):
    """Optional nearest-neighbour upsampling followed by a padded convolution.

    When ``upsample`` is truthy the input is first enlarged by that scale
    factor with ``nn.Upsample(mode='nearest')``; the result is then
    reflection-padded by ``kernel_size // 2`` and convolved.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel size (also determines the padding).
        stride: convolution stride.
        upsample: scale factor for the upsampling step; a falsy value
            (``None`` by default) disables upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        # A falsy factor is stored unchanged and acts as the "no upsampling" marker.
        self.upsample = (
            nn.Upsample(scale_factor=upsample, mode='nearest') if upsample else upsample
        )
        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        """Upsample ``x`` if configured, then pad and convolve it."""
        enlarged = self.upsample(x) if self.upsample else x
        return self.conv2d(self.reflection_pad(enlarged))
class UpsampleConvLayer_dpws(nn.Module):
    """Optional nearest-neighbour upsampling followed by a depthwise-separable conv.

    Pipeline: (upsample) -> depthwise ``ConvLayer`` -> InstanceNorm -> ReLU ->
    pointwise 1x1 ``nn.Conv2d`` -> InstanceNorm. Note that the ReLU sits
    after the depthwise stage and the block output is not activated.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels of the pointwise stage.
        kernel_size: kernel size of the depthwise stage.
        stride: stride of the depthwise stage.
        upsample: scale factor for the upsampling step; a falsy value
            (``None`` by default) disables upsampling.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        # A falsy factor is stored unchanged and acts as the "no upsampling" marker.
        self.upsample = (
            nn.Upsample(scale_factor=upsample, mode='nearest') if upsample else upsample
        )
        # Depthwise stage: one spatial filter per input channel.
        self.conv1 = ConvLayer(
            in_channels, in_channels, kernel_size, stride, groups=in_channels
        )
        self.in1 = nn.InstanceNorm2d(in_channels, affine=True)
        # Pointwise stage: plain 1x1 convolution (needs no padding).
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.in2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Upsample ``x`` if configured, then run the separable convolution."""
        enlarged = self.upsample(x) if self.upsample else x
        depthwise = self.relu(self.in1(self.conv1(enlarged)))
        return self.in2(self.conv2(depthwise))
class DeConvLayer(nn.Module):
    """Transposed convolution with ``padding=1``.

    ``output_padding`` is 1 along every axis whose stride is greater than 1
    (the original behaviour) and 0 along axes with stride 1. PyTorch requires
    ``output_padding`` to be smaller than the stride (or dilation), so a
    hard-coded value of 1 made ``forward`` fail for stride 1.

    With ``kernel_size=3`` and ``stride=2`` the layer doubles the spatial
    size; with ``kernel_size=3`` and ``stride=1`` it preserves it.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the transposed convolution.
        stride: int or per-axis sequence of ints.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        # Derive a legal output_padding from the stride, per axis if needed.
        if isinstance(stride, int):
            output_padding = 1 if stride > 1 else 0
        else:
            output_padding = tuple(1 if s > 1 else 0 for s in stride)
        self.deconv2d = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding=1,
            output_padding=output_padding,
        )

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        return self.deconv2d(x)
# Residual Block
# Adapted from the PyTorch tutorial:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main.py
class ResidualBlock(nn.Module):
    """Residual block of two 3x3 reflection-padded convolutions.

    Branch: conv -> InstanceNorm -> ReLU -> conv -> InstanceNorm. The branch
    output is added to the block input; no activation follows the sum.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        """Return the residual branch applied to ``x`` plus ``x`` itself."""
        branch = self.relu(self.in1(self.conv1(x)))
        branch = self.in2(self.conv2(branch))
        return branch + x
class ResidualBlock_depthwise(nn.Module):
    """Residual block built from two depthwise-separable convolutions.

    Branch layout:

        depthwise 3x3 -> IN -> pointwise 1x1 -> IN -> ReLU
        -> depthwise 3x3 -> IN -> pointwise 1x1 -> IN

    The branch output is added to the block input. The only activation is
    the ReLU after the first pointwise convolution; nothing is applied after
    the sum.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        # First separable convolution.
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        # Second separable convolution.
        self.conv3 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
        self.in3 = nn.InstanceNorm2d(channels, affine=True)
        self.conv4 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
        self.in4 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        # Not used in forward(); kept because PReLU owns a learnable weight and
        # dropping the attribute would change the module's state_dict keys.
        self.prelu = nn.PReLU()

    def forward(self, x):
        """Return the separable-convolution branch applied to ``x`` plus ``x``."""
        first_dw = self.in1(self.conv1(x))
        first_pw = self.relu(self.in2(self.conv2(first_dw)))
        second_dw = self.in3(self.conv3(first_pw))
        second_pw = self.in4(self.conv4(second_dw))
        return second_pw + x
# Image Transform Network
class ImageTransformNet(nn.Module):
    """Full-width image transformation network (encoder / residual / decoder).

    Encoder: a 9x9 conv to 32 channels followed by two stride-2 3x3 convs to
    64 and 128 channels, each with InstanceNorm and ReLU.
    Body: five ``ResidualBlock(128)``.
    Decoder: two nearest-neighbour x2 upsample + 3x3 conv stages (to 64 and
    32 channels) with InstanceNorm and ReLU, then a 9x9 conv to 3 channels
    whose output is returned without normalisation or activation.
    """

    def __init__(self):
        super().__init__()

        # nonlinearity
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # not used in forward()

        # encoding layers
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1_e = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2_e = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3_e = nn.InstanceNorm2d(128, affine=True)

        # residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # decoding layers (upsample + conv instead of transposed convolution)
        self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in3_d = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in2_d = nn.InstanceNorm2d(32, affine=True)
        self.deconv1 = ConvLayer(32, 3, kernel_size=9, stride=1)
        # Not used in forward(); kept because its affine parameters are part
        # of the module's state_dict.
        self.in1_d = nn.InstanceNorm2d(3, affine=True)

    def forward(self, x, return_features=False):
        """Transform an image batch.

        Args:
            x: input tensor with 3 channels.
            return_features: when true, also return the intermediate feature
                maps (presumably for feature distillation -- confirm against
                the training code).

        Returns:
            The 3-channel output ``y``; or, if ``return_features`` is true,
            the tuple ``(y, y_downsample, y_upsample)`` where ``y_downsample``
            is the encoder output fed to the residual blocks and
            ``y_upsample`` is the output of the last residual block.
        """
        # encode
        y = self.relu(self.in1_e(self.conv1(x)))
        y = self.relu(self.in2_e(self.conv2(y)))
        y = self.relu(self.in3_e(self.conv3(y)))
        y_downsample = y

        # residual layers
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(y)
        y_upsample = y

        # decode
        y = self.relu(self.in3_d(self.deconv3(y)))
        y = self.relu(self.in2_d(self.deconv2(y)))
        y = self.deconv1(y)

        if return_features:
            return y, y_downsample, y_upsample
        return y
# Width multipliers for the depthwise networks in this file: ALAPHA_1 scales
# the 32- and 64-channel layers of ImageTransformNet_dpws, ALAPHA_2 scales its
# 128-channel layers and the output width of the distiller modules.
ALAPHA_1 = 0.25
ALAPHA_2 = 0.25
# Alternative setting (disabled):
# ALAPHA_1 = 0.5
# ALAPHA_2 = 0.5
class ImageTransformNet_dpws(nn.Module):
    """Reduced-width image transformation network built from depthwise-separable layers.

    Channel widths are the nominal 32 / 64 / 128 scaled by the module-level
    multipliers ``ALAPHA_1`` (32 and 64) and ``ALAPHA_2`` (128).

    Encoder: three ``ConvLayer_dpws`` blocks (9x9 stride 1, then two 3x3
    stride 2). Body: five ``ResidualBlock_depthwise``. Decoder: two
    ``UpsampleConvLayer_dpws`` stages with x2 upsampling, then a 9x9
    ``ConvLayer_dpws_last`` to 3 channels.
    """

    def __init__(self):
        super().__init__()

        # Scaled channel widths, computed once.
        width_32 = int(32 * ALAPHA_1)
        width_64 = int(64 * ALAPHA_1)
        width_128 = int(128 * ALAPHA_2)

        # nonlinearity (neither is used in forward(); the blocks own their ReLUs)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

        # encoding layers
        self.conv1 = ConvLayer_dpws(3, width_32, kernel_size=9, stride=1)
        self.conv2 = ConvLayer_dpws(width_32, width_64, kernel_size=3, stride=2)
        self.conv3 = ConvLayer_dpws(width_64, width_128, kernel_size=3, stride=2)

        # residual layers
        self.res1 = ResidualBlock_depthwise(width_128)
        self.res2 = ResidualBlock_depthwise(width_128)
        self.res3 = ResidualBlock_depthwise(width_128)
        self.res4 = ResidualBlock_depthwise(width_128)
        self.res5 = ResidualBlock_depthwise(width_128)

        # decoding layers (upsample + conv instead of transposed convolution)
        self.deconv3 = UpsampleConvLayer_dpws(
            width_128, width_64, kernel_size=3, stride=1, upsample=2
        )
        self.deconv2 = UpsampleConvLayer_dpws(
            width_64, width_32, kernel_size=3, stride=1, upsample=2
        )
        self.deconv1 = ConvLayer_dpws_last(width_32, 3, kernel_size=9, stride=1)

    def forward(self, x, return_features=False):
        """Transform an image batch.

        Args:
            x: input tensor with 3 channels.
            return_features: when true, also return the intermediate feature
                maps (presumably for feature distillation -- confirm against
                the training code).

        Returns:
            The 3-channel output ``y``; or, if ``return_features`` is true,
            the tuple ``(y, y_downsample, y_upsample)`` where ``y_downsample``
            is the encoder output fed to the residual blocks and
            ``y_upsample`` is the output of the last residual block.
        """
        # encode
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y_downsample = y

        # residual layers
        for block in (self.res1, self.res2, self.res3, self.res4, self.res5):
            y = block(y)
        y_upsample = y

        # decode
        y = self.deconv3(y)
        y = self.deconv2(y)
        y = self.deconv1(y)

        if return_features:
            return y, y_downsample, y_upsample
        return y
class distiller_1(nn.Module):
    """1x1 convolution mapping 128 channels to ``int(128 * ALAPHA_2)`` channels.

    NOTE(review): presumably projects 128-channel features to the width used
    by the reduced network for distillation -- confirm against the training
    code.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=128,
            out_channels=int(128 * ALAPHA_2),
            kernel_size=1,
            stride=1,
        )

    def forward(self, x):
        """Return the 1x1 projection of ``x``."""
        return self.conv(x)
class distiller_2(nn.Module):
    """1x1 convolution mapping 128 channels to ``int(128 * ALAPHA_2)`` channels.

    Structurally identical to ``distiller_1`` but a separate module with its
    own weights.
    NOTE(review): presumably used for a second distillation point -- confirm
    against the training code.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=128,
            out_channels=int(128 * ALAPHA_2),
            kernel_size=1,
            stride=1,
        )

    def forward(self, x):
        """Return the 1x1 projection of ``x``."""
        return self.conv(x)