from typing import Optional
from functools import partial
import math

import torch
import torch.nn as nn


def get_activation(activation: str = "lrelu"):
    """Return an activation-layer constructor by name."""
    actv_layers = {
        "relu": nn.ReLU,
        "lrelu": partial(nn.LeakyReLU, 0.2),
    }
    assert activation in actv_layers, f"activation [{activation}] not implemented"
    return actv_layers[activation]


def get_normalization(normalization: str = "batch_norm"):
    """Return a normalization-layer constructor by name.

    Every constructor takes the channel count as its first positional
    argument, so `nn.GroupNorm` is partially applied with `num_groups`
    bound first (binding it by keyword would clash with `num_features`
    in the BatchNorm/InstanceNorm constructors at the call site).
    """
    norm_layers = {
        "instance_norm": nn.InstanceNorm2d,
        "batch_norm": nn.BatchNorm2d,
        "group_norm": partial(nn.GroupNorm, 8),   # GroupNorm(8, num_channels)
        "layer_norm": partial(nn.GroupNorm, 1),   # one group == LayerNorm over channels
    }
    assert normalization in norm_layers, f"normalization [{normalization}] not implemented"
    return norm_layers[normalization]
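

# Usage sketch for the two factories (illustrative values):
#   get_normalization("group_norm")(64)   -> nn.GroupNorm(8, 64)
#   get_activation("lrelu")(inplace=True) -> nn.LeakyReLU(0.2, inplace=True)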


class ConvLayer(nn.Sequential):
    """Conv2d/ConvTranspose2d followed (or preceded, if `pre_activate`)
    by optional normalization and activation."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 1,
        padding: Optional[int] = 1,
        padding_mode: str = "zeros",
        groups: int = 1,
        bias: bool = True,
        transposed: bool = False,
        normalization: Optional[str] = None,
        activation: Optional[str] = "lrelu",
        pre_activate: bool = False,
    ):
        if transposed:
            # output_padding = stride - 1 makes a stride-2 transposed conv
            # exactly double the spatial size (for odd kernels with
            # padding = kernel_size // 2).
            conv = partial(nn.ConvTranspose2d, output_padding=stride - 1)
            # ConvTranspose2d only supports zero padding.
            padding_mode = "zeros"
        else:
            conv = nn.Conv2d

        layers = [
            conv(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                groups=groups,
                bias=bias,
            )
        ]

        norm_actv = []
        if normalization is not None:
            norm_actv.append(
                # Channel count is passed positionally so it works with every
                # entry returned by `get_normalization`; pre-activation
                # normalizes the input channels instead of the output.
                get_normalization(normalization)(
                    in_channels if pre_activate else out_channels
                )
            )
        if activation is not None:
            norm_actv.append(
                get_activation(activation)(inplace=True)
            )

        if pre_activate:
            layers = norm_actv + layers
        else:
            layers = layers + norm_actv

        super().__init__(*layers)
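

# Output-size sketch for the transposed path (PyTorch's ConvTranspose2d formula):
#   H_out = (H_in - 1) * stride - 2 * padding + kernel_size + output_padding
# e.g. kernel_size=3, stride=2, padding=1, output_padding=1 gives H_out = 2 * H_in.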


class SubspaceLayer(nn.Module):
    """Learned linear subspace from EigenGAN: maps a low-dimensional code z
    to `phi = (L * z) @ U + mu`, where U holds the basis vectors (initialized
    orthogonally), L the per-direction scales, and mu the subspace origin."""

    def __init__(
        self,
        dim: int,
        n_basis: int,
    ):
        super().__init__()
        self.U = nn.Parameter(torch.empty(n_basis, dim))
        nn.init.orthogonal_(self.U)
        # Scales start at 3*n_basis, ..., 6, 3 so earlier basis vectors
        # dominate at initialization.
        self.L = nn.Parameter(torch.FloatTensor([3 * i for i in range(n_basis, 0, -1)]))
        self.mu = nn.Parameter(torch.zeros(dim))

    def forward(self, z):
        return (self.L * z) @ self.U + self.mu
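

# Shape sketch (illustrative numbers): with n_basis=6 and dim=4*4*512,
# z of shape (B, 6) -> (L * z): (B, 6) @ U: (6, 8192) + mu: (8192,) -> (B, 8192).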


class EigenBlock(nn.Module):
    """One generator stage: projects z onto a learned subspace and injects the
    result into the feature map twice, before and after the 2x upsampling."""

    def __init__(
        self,
        width: int,
        height: int,
        in_channels: int,
        out_channels: int,
        n_basis: int,
    ):
        super().__init__()
        self.projection = SubspaceLayer(dim=width * height * in_channels, n_basis=n_basis)

        # 1x1 projection of phi at the input resolution.
        self.subspace_conv1 = ConvLayer(
            in_channels,
            in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            transposed=True,
            activation=None,
            normalization=None,
        )
        # Upsamples phi to match the output of feature_conv1.
        self.subspace_conv2 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            transposed=True,
            activation=None,
            normalization=None,
        )

        self.feature_conv1 = ConvLayer(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=2,
            transposed=True,
            pre_activate=True,
        )
        self.feature_conv2 = ConvLayer(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            transposed=True,
            pre_activate=True,
        )

    def forward(self, z, h):
        phi = self.projection(z).view(h.shape)
        h = self.feature_conv1(h + self.subspace_conv1(phi))
        h = self.feature_conv2(h + self.subspace_conv2(phi))
        return h
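

# Dataflow sketch for one block (illustrative: in_channels=128 at 4x4, out_channels=64):
#   phi = projection(z).view(B, 128, 4, 4)
#   h = feature_conv1(h + subspace_conv1(phi))   # -> (B, 64, 8, 8)
#   h = feature_conv2(h + subspace_conv2(phi))   # phi re-injected at 8x8 -> (B, 64, 8, 8)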


class ConditionalGenerator(nn.Module):
    """Conditional generator.

    Generates images from a one-hot (or multi-label) condition `y` plus noise
    `eps` sampled from N(0, 1), with an explorable per-block z injection space.
    Based on EigenGAN.
    """

    def __init__(self,
                 size: int,
                 y_size: int,
                 z_size: int,
                 out_channels: int = 3,
                 n_basis: int = 6,
                 noise_dim: int = 512,
                 base_channels: int = 16,
                 max_channels: int = 512,
                 y_type: str = 'one_hot'):
        if y_type not in ['one_hot', 'multi_label', 'mixed', 'real']:
            raise ValueError('Unsupported `y_type`')

        super().__init__()
        assert (size & (size - 1) == 0) and size != 0, "img size should be a power of 2"

        self.y_type = y_type
        self.y_size = y_size
        self.eps_size = z_size
        self.noise_dim = noise_dim
        self.n_basis = n_basis
        # One block per spatial doubling, from 4x4 up to size x size.
        self.n_blocks = int(math.log(size, 2)) - 2

        def get_channels(i_block):
            return min(max_channels, base_channels * (2 ** (self.n_blocks - i_block)))

        self.y_fc = nn.Linear(self.y_size, self.y_size)
        self.concat_fc = nn.Linear(self.y_size + self.eps_size, self.noise_dim)
        self.fc = nn.Linear(self.noise_dim, 4 * 4 * get_channels(0))

        self.blocks = nn.ModuleList()
        for i in range(self.n_blocks):
            self.blocks.append(
                EigenBlock(
                    width=4 * (2 ** i),
                    height=4 * (2 ** i),
                    in_channels=get_channels(i),
                    out_channels=get_channels(i + 1),
                    n_basis=self.n_basis,
                )
            )

        self.out = nn.Sequential(
            ConvLayer(base_channels, out_channels, kernel_size=7, stride=1, padding=3, pre_activate=True),
            nn.Tanh(),
        )
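
    # Channel-schedule sketch (illustrative): size=256 gives n_blocks=6 and
    # per-stage channels 512 -> 512 -> 256 -> 128 -> 64 -> 32 -> 16, i.e.
    # resolutions 4x4 up to 256x256 before the final 7x7 output conv.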

    def forward(self,
                y: torch.Tensor,
                eps: Optional[torch.Tensor] = None,
                zs: Optional[torch.Tensor] = None,
                return_eps: bool = False):
        bs = y.size(0)

        if eps is None:
            eps = self.sample_eps(bs)
        if zs is None:
            zs = self.sample_zs(bs)

        y_out = self.y_fc(y)
        concat = torch.cat((y_out, eps), dim=1)
        concat = self.concat_fc(concat)

        out = self.fc(concat).view(bs, -1, 4, 4)
        # zs has shape (bs, n_blocks, n_basis); iterate one z slice per block.
        for block, z in zip(self.blocks, zs.permute(1, 0, 2)):
            out = block(z, out)
        out = self.out(out)

        if return_eps:
            # Note: returns the fused latent actually fed to `fc`, not raw eps.
            return out, concat
        return out

    def sample_zs(self, batch: int, truncation: float = 1.):
        device = self.get_device()
        zs = torch.randn(batch, self.n_blocks, self.n_basis, device=device)
        if truncation < 1.:
            # Truncation trick: interpolate toward the prior mean (zero),
            # which reduces to scaling the sample.
            zs = zs * truncation
        return zs

    def sample_eps(self, batch: int, truncation: float = 1.):
        device = self.get_device()
        eps = torch.randn(batch, self.eps_size, device=device)
        if truncation < 1.:
            eps = eps * truncation
        return eps

    def get_device(self):
        return self.fc.weight.device

    def orthogonal_regularizer(self):
        """Mean squared deviation of each subspace basis from orthonormality,
        i.e. ||U U^T - I||^2 averaged over all SubspaceLayers."""
        reg = []
        for layer in self.modules():
            if isinstance(layer, SubspaceLayer):
                UUT = layer.U @ layer.U.t()
                reg.append(
                    ((UUT - torch.eye(UUT.shape[0], device=UUT.device)) ** 2).mean()
                )
        return sum(reg) / len(reg)
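

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the sizes below are illustrative
    # assumptions, not values from the original training setup).
    g = ConditionalGenerator(size=32, y_size=10, z_size=128)
    y = torch.zeros(2, 10)
    y[:, 3] = 1.0                      # one-hot condition
    img = g(y)                         # eps and zs are sampled internally
    print(img.shape)                   # torch.Size([2, 3, 32, 32])
    print(g.orthogonal_regularizer())  # scalar regularization loss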