import torch
import torch.nn as nn

from typing import Any, Tuple, Union

from utils import (
    ImageType,
    crop_image_part,
)

from layers import (
    SpectralConv2d,
    InitLayer,
    SLEBlock,
    UpsampleBlockT1,
    UpsampleBlockT2,
    DownsampleBlockT1,
    DownsampleBlockT2,
    Decoder,
)

from huggan.pytorch.huggan_mixin import HugGANModelHubMixin


class Generator(nn.Module, HugGANModelHubMixin):
    """FastGAN-style generator: maps a latent code to a 1024x1024 image plus
    an auxiliary 128x128 image consumed by the discriminator's small track."""

    def __init__(self, in_channels: int,
                       out_channels: int):
        super().__init__()

        # Number of feature channels at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }
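
        # Project the latent code to a 4x4 feature map, then grow it stage by
        # stage to 1024x1024, alternating the two upsampling block variants.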
        self._init = InitLayer(
            in_channels=in_channels,
            out_channels=self._channels[4],
        )

        self._upsample_8    = UpsampleBlockT2(in_channels=self._channels[4],   out_channels=self._channels[8])
        self._upsample_16   = UpsampleBlockT1(in_channels=self._channels[8],   out_channels=self._channels[16])
        self._upsample_32   = UpsampleBlockT2(in_channels=self._channels[16],  out_channels=self._channels[32])
        self._upsample_64   = UpsampleBlockT1(in_channels=self._channels[32],  out_channels=self._channels[64])
        self._upsample_128  = UpsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[128])
        self._upsample_256  = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256])
        self._upsample_512  = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512])
        self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])
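
        # Skip-layer excitation: low-resolution feature maps gate the channels
        # of much higher-resolution ones (4->64, 8->128, 16->256, 32->512).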
        self._sle_64  = SLEBlock(in_channels=self._channels[4],  out_channels=self._channels[64])
        self._sle_128 = SLEBlock(in_channels=self._channels[8],  out_channels=self._channels[128])
        self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
        self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])

        # Output heads: a 1x1 conv maps the 128px features to an image, a 3x3
        # conv maps the 1024px features to the full-resolution image.
        self._out_128 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[128],
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

        self._out_1024 = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=out_channels,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, input: torch.Tensor) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Return the 1024px output image and the auxiliary 128px image."""
        size_4  = self._init(input)
        size_8  = self._upsample_8(size_4)
        size_16 = self._upsample_16(size_8)
        size_32 = self._upsample_32(size_16)

        # Each SLE block modulates an upsampled map with a much smaller one.
        size_64  = self._sle_64(size_4,   self._upsample_64(size_32))
        size_128 = self._sle_128(size_8,  self._upsample_128(size_64))
        size_256 = self._sle_256(size_16, self._upsample_256(size_128))
        size_512 = self._sle_512(size_32, self._upsample_512(size_256))

        size_1024 = self._upsample_1024(size_512)

        out_128  = self._out_128(size_128)
        out_1024 = self._out_1024(size_1024)
        return out_1024, out_128


class Discriminator(nn.Module, HugGANModelHubMixin):
    """Two-scale discriminator: a deep track for 1024px images, a shallow
    track for 128px images, and auxiliary decoders that reconstruct real
    images from intermediate features."""

    def __init__(self, in_channels: int):
        super().__init__()

        # Number of feature channels at each spatial resolution.
        self._channels = {
            4:    1024,
            8:    512,
            16:   256,
            32:   128,
            64:   128,
            128:  64,
            256:  32,
            512:  16,
            1024: 8,
        }
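
        # Stem: two stride-2 spectral convolutions shrink the 1024px input
        # before the main downsampling pyramid.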
        self._init = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[1024],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[1024],
                out_channels=self._channels[512],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[512]),
            nn.LeakyReLU(negative_slope=0.2),
        )
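
        # Main pyramid: repeated downsampling, with skip-layer excitation from
        # shallower (higher-resolution) feature maps.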
        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
        self._downsample_64  = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64])
        self._downsample_32  = DownsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[32])
        self._downsample_16  = DownsampleBlockT2(in_channels=self._channels[32],  out_channels=self._channels[16])

        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])

        # Separate shallow track for the auxiliary 128px image.
        self._small_track = nn.Sequential(
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=self._channels[256],
                kernel_size=4,
                stride=2,
                padding=1,
                bias=False,
            ),
            nn.LeakyReLU(negative_slope=0.2),
            DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
            DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64]),
            DownsampleBlockT1(in_channels=self._channels[64],  out_channels=self._channels[32]),
        )
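
        # Heads that reduce each track's final feature map to real/fake logits.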
        self._features_large = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[16],
                out_channels=self._channels[8],
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self._channels[8]),
            nn.LeakyReLU(negative_slope=0.2),
            SpectralConv2d(
                in_channels=self._channels[8],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )

        self._features_small = nn.Sequential(
            SpectralConv2d(
                in_channels=self._channels[32],
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0,
                bias=False,
            ),
        )
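
        # Auxiliary decoders: reconstruct real images (at two scales, plus a
        # crop selected by image_type) from intermediate features, giving the
        # training loop a self-supervised reconstruction signal.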
        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)

    def forward(self, images_1024: torch.Tensor,
                      images_128: torch.Tensor,
                      image_type: ImageType) -> \
            Union[
                torch.Tensor,
                Tuple[torch.Tensor, Tuple[Any, Any, Any]]
            ]:
        """Return real/fake logits; for real images, also return the decoder
        reconstructions used by the reconstruction loss."""
        # Large track: 1024px input through the downsampling pyramid with SLE.
        down_512 = self._init(images_1024)
        down_256 = self._downsample_256(down_512)
        down_128 = self._downsample_128(down_256)

        down_64 = self._downsample_64(down_128)
        down_64 = self._sle_64(down_512, down_64)

        down_32 = self._downsample_32(down_64)
        down_32 = self._sle_32(down_256, down_32)

        down_16 = self._downsample_16(down_32)
        down_16 = self._sle_16(down_128, down_16)

        # Small track: 128px input.
        down_small = self._small_track(images_128)

        # Flatten both logit maps into a single vector of predictions.
        features_large = self._features_large(down_16).view(-1)
        features_small = self._features_small(down_small).view(-1)
        features = torch.cat([features_large, features_small], dim=0)

        # Real images additionally get reconstructions from the auxiliary
        # decoders: the full image at two scales plus a crop chosen by
        # image_type.
        if image_type != ImageType.FAKE:
            dec_large = self._decoder_large(down_16)
            dec_small = self._decoder_small(down_small)
            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
            return features, (dec_large, dec_small, dec_piece)

        return features
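

# Minimal smoke test (a sketch, not part of the original module). It assumes
# the local `layers` and `utils` modules are importable, that images are
# 3-channel RGB, and that the down/upsampling blocks halve/double spatial size
# as their names suggest. The latent shape expected by InitLayer is not fixed
# by this file, so only the discriminator's forward pass is exercised; the
# choice of in_channels=256 for the generator is arbitrary.
if __name__ == '__main__':
    generator = Generator(in_channels=256, out_channels=3)
    discriminator = Discriminator(in_channels=3)

    print('generator params:    ', sum(p.numel() for p in generator.parameters()))
    print('discriminator params:', sum(p.numel() for p in discriminator.parameters()))

    images_1024 = torch.randn(1, 3, 1024, 1024)
    images_128 = torch.randn(1, 3, 128, 128)
    features = discriminator(images_1024, images_128, ImageType.FAKE)
    print('discriminator features:', features.shape)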