import torch
import numpy as np
import tqdm
import tops
from ..layers import Module
from ..layers.sg2_layers import FullyConnectedLayer


class BaseGenerator(Module):
    def __init__(self, z_channels: int):
        super().__init__()
        self.z_channels = z_channels
        self.latent_space = "Z"
    @torch.no_grad()
    def get_z(
            self,
            x: torch.Tensor = None,
            z: torch.Tensor = None,
            truncation_value: float = None,
            batch_size: int = None,
            dtype=None, device=None) -> torch.Tensor:
        """Generates a latent variable for the generator.

        If x is given, the batch size, dtype and device are inferred from it.
        truncation_value bounds the magnitude of each element of z.
        """
        if z is not None:
            return z
        if x is not None:
            batch_size = x.shape[0]
            dtype = x.dtype
            device = x.device
        if device is None:
            device = tops.get_device()
        if truncation_value == 0:
            return torch.zeros((batch_size, self.z_channels), device=device, dtype=dtype)
        z = torch.randn((batch_size, self.z_channels), device=device, dtype=dtype)
        if truncation_value is None:
            return z
        # Rejection resampling: redraw every out-of-bound element from the
        # normal distribution until all elements lie within the bound.
        while z.abs().max() > truncation_value:
            m = z.abs() > truncation_value
            z[m] = torch.randn_like(z)[m]
        return z

    def sample(self, truncation_value, z=None, **kwargs):
        """Samples by interpolating z towards the mean (0)."""
        if truncation_value is None:
            return self.forward(**kwargs)
        # Clamp to [0, 1]: 0 collapses z to the mean, 1 leaves z untouched.
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        if z is None:
            z = self.get_z(kwargs["condition"])
        z = z * truncation_value
        return self.forward(**kwargs, z=z)
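

# A minimal standalone sketch of the two truncation strategies above, using
# only torch. This helper is illustrative and not part of the generator API.
def _demo_z_truncation(batch_size: int = 4, z_channels: int = 8, bound: float = 0.7):
    # Rejection resampling, as in BaseGenerator.get_z: redraw until bounded.
    z = torch.randn(batch_size, z_channels)
    while z.abs().max() > bound:
        m = z.abs() > bound
        z[m] = torch.randn_like(z)[m]
    assert z.abs().max() <= bound
    # Interpolation towards the mean (0), as in BaseGenerator.sample.
    z_interp = torch.randn(batch_size, z_channels) * bound
    return z, z_interp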


class SG2StyleNet(torch.nn.Module):
    def __init__(
            self,
            z_dim,               # Input latent (Z) dimensionality.
            w_dim,               # Intermediate latent (W) dimensionality.
            num_layers=2,        # Number of mapping layers.
            lr_multiplier=0.01,  # Learning rate multiplier for the mapping layers.
            w_avg_beta=0.998,    # Decay for tracking the moving average of W during training.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.num_layers = num_layers
        self.w_avg_beta = w_avg_beta
        # Construct layers: z_dim -> w_dim -> ... -> w_dim.
        features = [self.z_dim] + [self.w_dim] * self.num_layers
        for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
            layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
            setattr(self, f'fc{idx}', layer)
        self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, update_emas=False, **kwargs):
        tops.assert_shape(z, [None, self.z_dim])
        # Normalize z to unit RMS per sample.
        x = z.to(torch.float32)
        x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
        # Execute layers.
        for idx in range(self.num_layers):
            x = getattr(self, f'fc{idx}')(x)
        # Update the exponential moving average of W.
        if update_emas:
            self.w_avg.copy_(x.float().detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
        return x

    def extra_repr(self):
        return f'z_dim={self.z_dim:d}, w_dim={self.w_dim:d}'

    def update_w(self, n=10_000, batch_size=32):
        """Re-estimate w_avg from n random samples.

        Useful in cases where w_avg was tracked incorrectly during training.
        """
        n_batches = n // batch_size
        for _ in tqdm.trange(n_batches, desc="Updating w"):
            z = torch.randn((batch_size, self.z_dim), device=tops.get_device())
            self(z, update_emas=True)

    def get_truncated(self, truncation_value, condition, z=None, **kwargs):
        if z is None:
            z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
        w = self(z)
        # Clamp to [0, 1] and interpolate from w_avg (0) towards w (1).
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        return self.w_avg.to(w.dtype).lerp(w, truncation_value)

    def multi_modal_truncate(self, truncation_value, condition, w_indices, z=None, **kwargs):
        truncation_value = max(0, truncation_value)
        truncation_value = min(truncation_value, 1)
        if z is None:
            z = torch.randn((condition.shape[0], self.z_dim), device=tops.get_device())
        w = self(z)
        # Truncate towards one of several cluster centers in W instead of the
        # single global mean. self.w_centers is expected to be registered on
        # the module elsewhere.
        if w_indices is None:
            w_indices = np.random.randint(0, len(self.w_centers), size=(len(w)))
        w_centers = self.w_centers[w_indices].to(w.device)
        w = w_centers.to(w.dtype).lerp(w, truncation_value)
        return w
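

# Sketch of the W-space truncation trick used above, assuming plain torch
# tensors. Tensor.lerp(end, weight) returns self + weight * (end - self), so
# w_avg.lerp(w, psi) moves from the average style (psi=0) towards the raw
# style (psi=1). Illustrative helper only.
def _demo_w_truncation(psi: float = 0.5, w_dim: int = 8):
    w_avg = torch.zeros(w_dim)
    w = torch.randn(2, w_dim)
    w_trunc = w_avg.lerp(w, psi)
    assert torch.allclose(w_trunc, w_avg + psi * (w - w_avg))
    return w_trunc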


class BaseStyleGAN(BaseGenerator):
    def __init__(self, z_channels: int, w_dim: int):
        super().__init__(z_channels)
        self.style_net = SG2StyleNet(z_channels, w_dim)
        self.latent_space = "W"

    def get_w(self, z, update_emas):
        return self.style_net(z, update_emas=update_emas)

    @torch.no_grad()
    def sample(self, truncation_value, **kwargs):
        if truncation_value is None:
            return self.forward(**kwargs)
        w = self.style_net.get_truncated(truncation_value, **kwargs)
        return self.forward(**kwargs, w=w)

    def update_w(self, *args, **kwargs):
        self.style_net.update_w(*args, **kwargs)

    @torch.no_grad()
    def multi_modal_truncate(self, truncation_value, w_indices=None, **kwargs):
        w = self.style_net.multi_modal_truncate(truncation_value, w_indices=w_indices, **kwargs)
        return self.forward(**kwargs, w=w)
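

# Minimal smoke test: a sketch assuming the relative imports above resolve
# (e.g. when run as a module inside the package). SG2StyleNet is
# self-contained; BaseGenerator and BaseStyleGAN require a subclass that
# implements forward().
if __name__ == "__main__":
    net = SG2StyleNet(z_dim=64, w_dim=64)
    z = torch.randn(4, 64)
    w = net(z, update_emas=True)
    print(w.shape)          # torch.Size([4, 64])
    print(net.w_avg.shape)  # torch.Size([64])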