import torch
import torch.nn as nn


# Conv Layer: reflection-padded convolution, so spatial size is preserved
# for stride 1 (padding = kernel_size // 2).
class ConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
        super(ConvLayer, self).__init__()
        padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride, groups=groups)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

# Depthwise-separable conv block: depthwise conv + instance norm, then
# 1x1 pointwise conv + instance norm + ReLU.
class ConvLayer_dpws(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer_dpws, self).__init__()
        self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride=stride, groups=in_channels)
        self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)
        self.in_2d = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        out = self.in_1d(self.conv1(x))
        out = self.relu(self.in_2d(self.conv2(out)))
        return out

# Same depthwise-separable block, but the final pointwise conv has no
# normalization or activation, so it can serve as the output layer.
class ConvLayer_dpws_last(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvLayer_dpws_last, self).__init__()
        self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride=stride, groups=in_channels)
        self.in_1d = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = ConvLayer(in_channels, out_channels, kernel_size=1, stride=1)

    def forward(self, x):
        out = self.in_1d(self.conv1(x))
        out = self.conv2(out)
        return out
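
# A minimal sketch (not part of the original file) showing the motivation
# for the _dpws blocks: for the same in/out channels and kernel size, the
# depthwise-separable block needs far fewer weights than a standard
# convolution (roughly 73.9k vs 9.3k here, norm parameters included).
def _compare_conv_params():
    full = ConvLayer(64, 128, kernel_size=3, stride=1)
    dpws = ConvLayer_dpws(64, 128, kernel_size=3, stride=1)
    print(sum(p.numel() for p in full.parameters()),  # 73,856
          sum(p.numel() for p in dpws.parameters()))  # 9,344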

# Upsample Conv Layer: optional nearest-neighbor upsampling followed by a
# reflection-padded convolution, a common alternative to transposed
# convolutions that avoids checkerboard artifacts.
class UpsampleConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer, self).__init__()
        self.upsample = upsample
        if upsample:
            self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
        reflection_padding = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size, stride)

    def forward(self, x):
        if self.upsample:
            x = self.upsample(x)
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out

# Depthwise-separable version of the upsampling block: upsample, then
# depthwise conv + norm + ReLU, then 1x1 pointwise conv + norm.
class UpsampleConvLayer_dpws(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super(UpsampleConvLayer_dpws, self).__init__()
        self.upsample = upsample
        if upsample:
            self.upsample = nn.Upsample(scale_factor=upsample, mode='nearest')
        self.conv1 = ConvLayer(in_channels, in_channels, kernel_size, stride, groups=in_channels)
        self.in1 = nn.InstanceNorm2d(in_channels, affine=True)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
        self.in2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        if self.upsample:
            x = self.upsample(x)
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        return out

# Transposed-convolution alternative to UpsampleConvLayer (kept for
# experiments; unused in the default networks below).
class DeConvLayer(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(DeConvLayer, self).__init__()
        self.deconv2d = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding=1, output_padding=1)

    def forward(self, x):
        out = self.deconv2d(x)
        return out

# Residual Block
# adapted from the PyTorch tutorial:
# https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main.py
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        # No activation after the second norm or after the skip connection
        # (variants with extra ReLUs were tried and disabled).
        out = self.in2(self.conv2(out))
        out = out + residual
        return out
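
# Quick shape check (a minimal sketch, not part of the original file):
# reflection padding makes each 3x3/stride-1 conv size-preserving, so a
# residual block maps (N, C, H, W) to (N, C, H, W).
def _check_residual_shape():
    block = ResidualBlock(128)
    x = torch.randn(1, 128, 64, 64)
    assert block(x).shape == x.shape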

# Depthwise-separable residual block: two depthwise(3x3) + pointwise(1x1)
# pairs with instance norm. Several activation placements were tried:
# ReLU after every norm; ReLU only after the first pointwise norm plus a
# final ReLU after the skip (v1); no norm on the depthwise convs (v3);
# an extra ReLU after the last norm (v4). The active version ("v2")
# applies ReLU only after the first pointwise norm and adds the skip
# connection with no final activation.
class ResidualBlock_depthwise(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock_depthwise, self).__init__()
        self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.conv3 = ConvLayer(channels, channels, kernel_size=3, stride=1, groups=channels)
        self.in3 = nn.InstanceNorm2d(channels, affine=True)
        self.conv4 = nn.Conv2d(channels, channels, kernel_size=1, stride=1)
        self.in4 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        self.prelu = nn.PReLU()  # unused in the active variant

    def forward(self, x):
        # ################## v2 #################### √
        residual = x
        out = self.in1(self.conv1(x))
        out = self.relu(self.in2(self.conv2(out)))
        out = self.in3(self.conv3(out))
        out = self.in4(self.conv4(out))
        out = out + residual
        return out

# Image Transform Network: an encoder (two stride-2 downsampling convs),
# five residual blocks, and a decoder (upsample + conv instead of
# transposed conv), in the style of fast neural style transfer generators.
class ImageTransformNet(nn.Module):
    def __init__(self):
        super(ImageTransformNet, self).__init__()
        # nonlinearity
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # unused: the tanh output variant is disabled

        # encoding layers
        self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
        self.in1_e = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
        self.in2_e = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
        self.in3_e = nn.InstanceNorm2d(128, affine=True)

        # residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)

        # decoding layers (a DeConvLayer variant was tried and disabled)
        self.deconv3 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in3_d = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in2_d = nn.InstanceNorm2d(32, affine=True)
        self.deconv1 = ConvLayer(32, 3, kernel_size=9, stride=1)
        self.in1_d = nn.InstanceNorm2d(3, affine=True)

    def forward(self, x):
        # encode
        y = self.relu(self.in1_e(self.conv1(x)))
        y = self.relu(self.in2_e(self.conv2(y)))
        y = self.relu(self.in3_e(self.conv3(y)))
        y_downsample = y

        # residual layers
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y_upsample = y

        # decode; the raw (un-tanh'd) output is returned
        y = self.relu(self.in3_d(self.deconv3(y)))
        y = self.relu(self.in2_d(self.deconv2(y)))
        y = self.deconv1(y)
        # return y, y_downsample, y_upsample  # enable to expose features for distillation
        return y
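
# Smoke test (a minimal sketch, not part of the original file): the encoder
# halves the spatial size twice and the decoder restores it, so the output
# shape matches the input shape.
def _check_transform_net():
    net = ImageTransformNet()
    x = torch.randn(1, 3, 256, 256)
    assert net(x).shape == (1, 3, 256, 256)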

# Width multipliers for the slimmed (depthwise-separable) network;
# 0.5 was also tried for both.
ALPHA_1 = 0.25
ALPHA_2 = 0.25

# Slimmed Image Transform Network built from depthwise-separable blocks,
# with channel counts scaled by ALPHA_1 / ALPHA_2. Normalization and
# activations live inside the _dpws blocks, so forward is a plain chain.
class ImageTransformNet_dpws(nn.Module):
    def __init__(self):
        super(ImageTransformNet_dpws, self).__init__()
        # nonlinearity
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()  # unused: the tanh output variant is disabled

        # encoding layers
        self.conv1 = ConvLayer_dpws(3, int(32 * ALPHA_1), kernel_size=9, stride=1)
        self.conv2 = ConvLayer_dpws(int(32 * ALPHA_1), int(64 * ALPHA_1), kernel_size=3, stride=2)
        self.conv3 = ConvLayer_dpws(int(64 * ALPHA_1), int(128 * ALPHA_2), kernel_size=3, stride=2)

        # residual layers
        self.res1 = ResidualBlock_depthwise(int(128 * ALPHA_2))
        self.res2 = ResidualBlock_depthwise(int(128 * ALPHA_2))
        self.res3 = ResidualBlock_depthwise(int(128 * ALPHA_2))
        self.res4 = ResidualBlock_depthwise(int(128 * ALPHA_2))
        self.res5 = ResidualBlock_depthwise(int(128 * ALPHA_2))

        # decoding layers
        self.deconv3 = UpsampleConvLayer_dpws(int(128 * ALPHA_2), int(64 * ALPHA_1), kernel_size=3, stride=1, upsample=2)
        self.deconv2 = UpsampleConvLayer_dpws(int(64 * ALPHA_1), int(32 * ALPHA_1), kernel_size=3, stride=1, upsample=2)
        self.deconv1 = ConvLayer_dpws_last(int(32 * ALPHA_1), 3, kernel_size=9, stride=1)

    def forward(self, x):
        # encode
        y = self.conv1(x)
        y = self.conv2(y)
        y = self.conv3(y)
        y_downsample = y

        # residual layers
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y_upsample = y

        # decode
        y = self.deconv3(y)
        y = self.deconv2(y)
        y = self.deconv1(y)
        # return y, y_downsample, y_upsample  # enable to expose features for distillation
        return y
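
# A minimal sketch (not part of the original file) comparing the two
# networks; with ALPHA_1 = ALPHA_2 = 0.25 the depthwise-separable student
# has only a small fraction of the full network's parameters.
def _compare_sizes():
    teacher = ImageTransformNet()
    student = ImageTransformNet_dpws()
    n_t = sum(p.numel() for p in teacher.parameters())
    n_s = sum(p.numel() for p in student.parameters())
    print(f"teacher: {n_t:,} params, student: {n_s:,} params")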

# 1x1 projections that map the full network's 128-channel features down
# to the slimmed network's width, one per matching point (after the
# encoder and after the residual blocks).
class distiller_1(nn.Module):
    def __init__(self):
        super(distiller_1, self).__init__()
        self.conv = nn.Conv2d(128, int(128 * ALPHA_2), kernel_size=1, stride=1)

    def forward(self, x):
        return self.conv(x)


class distiller_2(nn.Module):
    def __init__(self):
        super(distiller_2, self).__init__()
        self.conv = nn.Conv2d(128, int(128 * ALPHA_2), kernel_size=1, stride=1)

    def forward(self, x):
        return self.conv(x)
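
# One plausible training-time use of the distillers (an assumption, not
# confirmed by this file): enable the commented-out three-value returns in
# both networks, project the full network's 128-channel features to the
# student's width, and penalize the feature mismatch with an MSE loss.
def _feature_distillation_loss(teacher, student, d1, d2, x):
    # hypothetical helper; assumes forward() returns (y, y_downsample, y_upsample)
    _, t_down, t_up = teacher(x)
    s_out, s_down, s_up = student(x)
    mse = nn.MSELoss()
    return mse(s_down, d1(t_down)) + mse(s_up, d2(t_up)), s_out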