from typing import Optional, Dict
from functools import partial
import math

import torch
import torch.nn as nn


def get_activation(activation: str = "lrelu"):
    """Return the activation layer constructor for the given name."""
    actv_layers = {
        "relu": nn.ReLU,
        "lrelu": partial(nn.LeakyReLU, 0.2),
    }
    assert activation in actv_layers, f"activation [{activation}] not implemented"
    return actv_layers[activation]


def get_normalization(normalization: str = "batch_norm"):
    """Return a constructor that builds the requested normalization layer from a `num_channels` keyword."""
    norm_layers = {
        # BatchNorm2d / InstanceNorm2d expect `num_features`; wrap them so that every
        # entry accepts the same `num_channels` keyword used by ConvLayer below.
        "instance_norm": lambda num_channels: nn.InstanceNorm2d(num_features=num_channels),
        "batch_norm": lambda num_channels: nn.BatchNorm2d(num_features=num_channels),
        "group_norm": partial(nn.GroupNorm, num_groups=8),
        "layer_norm": partial(nn.GroupNorm, num_groups=1),
    }
    assert normalization in norm_layers, f"normalization [{normalization}] not implemented"
    return norm_layers[normalization]


class ConvLayer(nn.Sequential):
    """(Transposed) convolution optionally wrapped with normalization and activation.

    With `pre_activate=True` the normalization/activation are applied before the
    convolution (pre-activation ordering); otherwise they follow it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = 1,
        padding_mode: str = "zeros",
        groups: int = 1,
        bias: bool = True,
        transposed: bool = False,
        normalization: Optional[str] = None,
        activation: Optional[str] = "lrelu",
        pre_activate: bool = False,
    ):
        if transposed:
            # output_padding = stride - 1 makes the upsampling factor exactly equal to stride
            conv = partial(nn.ConvTranspose2d, output_padding=stride - 1)
            padding_mode = "zeros"  # ConvTranspose2d supports only zero padding
        else:
            conv = nn.Conv2d

        layers = [
            conv(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                groups=groups,
                bias=bias,
            )
        ]

        norm_actv = []
        if normalization is not None:
            norm_actv.append(
                get_normalization(normalization)(
                    num_channels=in_channels if pre_activate else out_channels
                )
            )
        if activation is not None:
            norm_actv.append(
                get_activation(activation)(inplace=True)
            )

        if pre_activate:
            layers = norm_actv + layers
        else:
            layers = layers + norm_actv

        super().__init__(*layers)


class SubspaceLayer(nn.Module):
    """EigenGAN subspace: maps a low-dimensional code z to a `dim`-dimensional offset
    via an orthonormal basis U, per-direction scales L, and a learned mean mu."""

    def __init__(
        self,
        dim: int,
        n_basis: int,
    ):
        super().__init__()

        self.U = nn.Parameter(torch.empty(n_basis, dim))
        nn.init.orthogonal_(self.U)
        # Scales initialised in descending order: 3 * n_basis, ..., 3
        self.L = nn.Parameter(torch.FloatTensor([3 * i for i in range(n_basis, 0, -1)]))
        self.mu = nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        return (self.L * z) @ self.U + self.mu


class EigenBlock(nn.Module):
    """Generator block: injects a subspace projection of z into the feature map
    and upsamples it by a factor of 2."""

    def __init__(
        self,
        width: int,
        height: int,
        in_channels: int,
        out_channels: int,
        n_basis: int,
    ):
        super().__init__()

        self.projection = SubspaceLayer(dim=width * height * in_channels, n_basis=n_basis)
        self.subspace_conv1 = ConvLayer(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            transposed=True,
            activation=None,
            normalization=None,
        )
        self.subspace_conv2 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            transposed=True,
            activation=None,
            normalization=None,
        )
        self.feature_conv1 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            transposed=True,
            pre_activate=True,
        )
        self.feature_conv2 = ConvLayer(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            transposed=True,
            pre_activate=True,
        )

    def forward(self, z, h):
        # Project z onto the learned subspace and reshape to the current feature map shape
        phi = self.projection(z).view(h.shape)
        h = self.feature_conv1(h + self.subspace_conv1(phi))
        h = self.feature_conv2(h + self.subspace_conv2(phi))
        return h


class ConditionalGenerator(nn.Module):
    """Conditional generator based on EigenGAN.

    Generates images from a one-hot label plus noise sampled from N(0, 1), with an
    explorable z injection space (one z code per resolution block). A minimal usage
    sketch is provided at the end of this module.
    """

    def __init__(
        self,
        size: int,
        y_size: int,
        z_size: int,
        out_channels: int = 3,
        n_basis: int = 6,
        noise_dim: int = 512,
        base_channels: int = 16,
        max_channels: int = 512,
        y_type: str = "one_hot",
    ):
        if y_type not in ["one_hot", "multi_label", "mixed", "real"]:
            raise ValueError("Unsupported `y_type`")

        super().__init__()

        assert (size & (size - 1) == 0) and size != 0, "img size should be a power of 2"

        self.y_type = y_type
        self.y_size = y_size
        self.eps_size = z_size
        self.noise_dim = noise_dim
        self.n_basis = n_basis
        self.n_blocks = int(math.log(size, 2)) - 2

        def get_channels(i_block):
            # Channel count halves with each block, capped at max_channels
            return min(max_channels, base_channels * (2 ** (self.n_blocks - i_block)))

        self.y_fc = nn.Linear(self.y_size, self.y_size)
        self.concat_fc = nn.Linear(self.y_size + self.eps_size, self.noise_dim)

        self.fc = nn.Linear(self.noise_dim, 4 * 4 * get_channels(0))
        self.blocks = nn.ModuleList()
        for i in range(self.n_blocks):
            self.blocks.append(
                EigenBlock(
                    width=4 * (2 ** i),
                    height=4 * (2 ** i),
                    in_channels=get_channels(i),
                    out_channels=get_channels(i + 1),
                    n_basis=self.n_basis,
                )
            )

        self.out = nn.Sequential(
            ConvLayer(base_channels, out_channels, kernel_size=7, stride=1, padding=3, pre_activate=True),
            nn.Tanh(),
        )

    def forward(
        self,
        y: torch.Tensor,
        eps: Optional[torch.Tensor] = None,
        zs: Optional[torch.Tensor] = None,
        return_eps: bool = False,
    ):
        bs = y.size(0)

        if eps is None:
            eps = self.sample_eps(bs)
        if zs is None:
            zs = self.sample_zs(bs)

        # Embed the label, fuse it with eps, and map to the initial 4x4 feature map
        y_out = self.y_fc(y)
        concat = torch.cat((y_out, eps), dim=1)
        concat = self.concat_fc(concat)

        out = self.fc(concat).view(len(eps), -1, 4, 4)
        # zs has shape (batch, n_blocks, n_basis): feed one z code to each block
        for block, z in zip(self.blocks, zs.permute(1, 0, 2)):
            out = block(z, out)
        out = self.out(out)

        if return_eps:
            return out, concat
        return out

    def sample_zs(self, batch: int, truncation: float = 1.0):
        device = self.get_device()
        zs = torch.randn(batch, self.n_blocks, self.n_basis, device=device)
        if truncation < 1.0:
            # Truncation trick: shrink samples towards the zero mean
            zs = zs * truncation
        return zs

    def sample_eps(self, batch: int, truncation: float = 1.0):
        device = self.get_device()
        eps = torch.randn(batch, self.eps_size, device=device)
        if truncation < 1.0:
            eps = eps * truncation
        return eps

    def get_device(self):
        return self.fc.weight.device

    def orthogonal_regularizer(self):
        """Mean squared deviation of U U^T from identity, averaged over all SubspaceLayer modules."""
        reg = []
        for layer in self.modules():
            if isinstance(layer, SubspaceLayer):
                UUT = layer.U @ layer.U.t()
                reg.append(
                    ((UUT - torch.eye(UUT.shape[0], device=UUT.device)) ** 2).mean()
                )
        return sum(reg) / len(reg)
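

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original training code).
# The hyperparameters below (size=128, y_size=10, z_size=512) are assumed values
# chosen for demonstration; replace them with the settings of your dataset and
# training configuration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    generator = ConditionalGenerator(
        size=128,    # output resolution, must be a power of 2
        y_size=10,   # label dimensionality (assumed: 10 one-hot classes)
        z_size=512,  # dimensionality of the eps noise vector (assumed)
    )

    batch = 4
    # Random one-hot labels for the batch (assumed 10 classes)
    labels = torch.eye(10)[torch.randint(0, 10, (batch,))]

    # eps and zs are sampled internally when not provided
    with torch.no_grad():
        images = generator(labels)
    print(images.shape)  # torch.Size([4, 3, 128, 128])

    # The orthogonality penalty on the subspace bases can be added to the
    # generator loss during training.
    print(float(generator.orthogonal_regularizer()))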