Photo_Colorizer / model.py
matikosowy's picture
description and title
28801c1
import torch.nn as nn
import torch
import torch.nn.functional as F
# Dropout layer that works even in the evaluation mode
class DropoutAlways(nn.Dropout2d):
    """Channel-wise dropout that stays active in both train and eval mode.

    The functional dropout call is always made with ``training=True``, so
    switching the module to ``eval()`` does not disable it. The drop
    probability ``p`` and the ``inplace`` flag are the ones accepted by the
    ``nn.Dropout2d`` constructor.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Drop whole channels of ``x`` with probability ``self.p``.

        Args:
            x: input feature map.

        Returns:
            The tensor after dropout. When the module was built with
            ``inplace=True`` the input tensor itself is modified and returned.
        """
        # ``training=True`` is forced on purpose: ``self.training`` is ignored.
        # ``inplace`` is forwarded so the constructor flag is actually honoured.
        return F.dropout2d(x, self.p, training=True, inplace=self.inplace)
class DownBlock(nn.Module):
    """Encoder stage: strided 4x4 convolution, optional instance norm, LeakyReLU.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        normalize: when true, an affine ``InstanceNorm2d`` follows the
            convolution and the convolution is built without a bias term;
            otherwise the convolution keeps its bias and the normalisation
            slot is filled with ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()

        # Kernel 4 / stride 2 / padding 1 with reflect padding.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            padding_mode='reflect',
            bias=not normalize,
        )

        if normalize:
            norm = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            # Pass-through placeholder so the layer indices inside ``block``
            # are the same whether or not normalisation is enabled.
            norm = nn.Identity()

        self.block = nn.Sequential(conv, norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        """Run ``x`` through convolution -> (norm) -> LeakyReLU and return the result."""
        return self.block(x)
class UpBlock(nn.Module):
    """Decoder stage: 4x4 transposed convolution, optional norm/dropout, activation.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        normalize: when true, an affine ``InstanceNorm2d`` follows the
            transposed convolution, which is then built without a bias term.
        dropout: when true, a ``DropoutAlways`` layer (default arguments) is
            inserted after the normalisation slot.
        activation: ``'relu'`` selects ``nn.ReLU``; any other value selects
            ``nn.Tanh``.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=False, activation='relu'):
        super().__init__()

        # Kernel 4 / stride 2 / padding 1.
        deconv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=not normalize,
        )

        # Disabled stages are replaced by Identity so that ``block`` always
        # holds four entries in the same order.
        if normalize:
            norm = nn.InstanceNorm2d(out_channels, affine=True)
        else:
            norm = nn.Identity()

        if dropout:
            drop = DropoutAlways()
        else:
            drop = nn.Identity()

        if activation == 'relu':
            act = nn.ReLU()
        else:
            act = nn.Tanh()

        self.block = nn.Sequential(deconv, norm, drop, act)

    def forward(self, x):
        """Run ``x`` through deconv -> (norm) -> (dropout) -> activation and return the result."""
        return self.block(x)
class Generator(nn.Module):
    """U-Net style encoder-decoder with eight down and eight up stages.

    Each encoder stage feeds the next one, and its output is also concatenated
    (along the channel axis) with the output of the decoder stage working at
    the same resolution. The last decoder stage uses a tanh activation, so the
    output values lie in (-1, 1).

    The first three decoder stages use ``DropoutAlways``, which means the
    output is stochastic even after ``eval()``.

    Args:
        in_channels: channels of the input image (default 1, the value that
            was previously hard-coded).
        out_channels: channels of the generated output (default 2, the value
            that was previously hard-coded).

    NOTE(review): the 1-in / 2-out defaults look like a luminance -> two
    chroma channels colorization setup; confirm against the training code.
    """

    def __init__(self, in_channels: int = 1, out_channels: int = 2):
        super().__init__()

        # Encoder. Sizes in the comments assume a 256x256 input.
        self.encoder1 = DownBlock(in_channels, 64, normalize=False)  # 256x256 -> 128x128
        self.encoder2 = DownBlock(64, 128)  # 128x128 -> 64x64
        self.encoder3 = DownBlock(128, 256)  # 64x64 -> 32x32
        self.encoder4 = DownBlock(256, 512)  # 32x32 -> 16x16
        self.encoder5 = DownBlock(512, 512)  # 16x16 -> 8x8
        self.encoder6 = DownBlock(512, 512)  # 8x8 -> 4x4
        self.encoder7 = DownBlock(512, 512)  # 4x4 -> 2x2
        self.encoder8 = DownBlock(512, 512, normalize=False)  # 2x2 -> 1x1

        # Decoder. From decoder2 on, the input width is doubled because the
        # matching encoder output is concatenated onto the previous decoder
        # output before the stage runs.
        self.decoder1 = UpBlock(512, 512, dropout=True)  # 1x1 -> 2x2
        self.decoder2 = UpBlock(512 * 2, 512, dropout=True)  # 2x2 -> 4x4
        self.decoder3 = UpBlock(512 * 2, 512, dropout=True)  # 4x4 -> 8x8
        self.decoder4 = UpBlock(512 * 2, 512)  # 8x8 -> 16x16
        self.decoder5 = UpBlock(512 * 2, 256)  # 16x16 -> 32x32
        self.decoder6 = UpBlock(256 * 2, 128)  # 32x32 -> 64x64
        self.decoder7 = UpBlock(128 * 2, 64)  # 64x64 -> 128x128
        self.decoder8 = UpBlock(
            64 * 2, out_channels, normalize=False, activation='tanh'
        )  # 128x128 -> 256x256

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` down to the bottleneck and decode it back with skip connections.

        Args:
            x: input batch; concatenation on ``dim=1`` implies a
                channels-second layout, with ``in_channels`` channels.

        Returns:
            The output of the final tanh decoder stage, with ``out_channels``
            channels.
        """
        # Encoder: keep every stage output, they are reused as skip inputs.
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        e6 = self.encoder6(e5)
        e7 = self.encoder7(e6)
        bottleneck = self.encoder8(e7)

        # Decoder: each stage output is joined with the encoder output of the
        # same resolution (decoder output first, encoder output second).
        d = self.decoder1(bottleneck)
        d = self.decoder2(torch.cat([d, e7], dim=1))
        d = self.decoder3(torch.cat([d, e6], dim=1))
        d = self.decoder4(torch.cat([d, e5], dim=1))
        d = self.decoder5(torch.cat([d, e4], dim=1))
        d = self.decoder6(torch.cat([d, e3], dim=1))
        d = self.decoder7(torch.cat([d, e2], dim=1))
        return self.decoder8(torch.cat([d, e1], dim=1))