from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss
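
# LPIPS + patch-GAN autoencoder loss in the spirit of VQGAN / taming-transformers:
# an L1 + LPIPS reconstruction term wrapped in an NLL with (optionally learned)
# log-variance, plus an adversarial term from a patch discriminator whose
# contribution is balanced by an adaptive gradient-norm weight.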
class GeneralLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                "calculation, the LPIPS loss will be applied to each frame "
                "independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # output log variance
        self.logvar = nn.Parameter(
            torch.full((), logvar_init), requires_grad=learn_logvar
        )
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )
        self.discriminator = instantiate_from_config(discriminator_config).apply(
            weights_init
        )
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

        self.forward_keys = [
            "optimizer_idx",
            "global_step",
            "last_layer",
            "split",
            "regularization_log",
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        yield from ()  # keeps this a generator even when logvar is not learned
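
    # Visualization helper: colors the patch discriminator's logits with a
    # diverging colormap, overlays them on the input/reconstruction grids, and
    # appends a rendered colorbar. Outputs are rescaled to [-1, 1].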
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # non-patch discriminator: no spatial logits to visualize
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h, w)

        # parameters for colormapping
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps["PiYG"]  # diverging colormap

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
            # -> (b, 1, ..., 3)
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, "b 1 ... c -> b c ...")

        # resize logits to the image resolution
        # -> (b, 1, h, w)
        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode="nearest",
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode="nearest",
            antialias=False,
        )

        # alpha value of logits for overlay
        alpha_real = torch.abs(logits_real) / high
        alpha_fake = torch.abs(logits_fake) / high
        # -> (b, 1, h, w) in range [0, 1]
        # the alpha values of the grid padding lines don't really matter, since
        # they are the same for both images and logits anyway
        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
        # -> (3, h, w) in range [0, 0.8]

        # prepare logits for plotting: map them onto the diverging colormap
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # -> (b, 3, h, w)

        # make some grids: add all logits to one plot
        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        # I just love how torchvision calls the number of columns `nrow`
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, h, w)

        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(
            0.5 * reconstructions + 0.5, nrow=4
        )
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
        # -> (3, h, w) in range [0, 1]

        # blend logits and images together
        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

        # Create labeled colorbar
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # manually convert figure to numpy (note that `tostring_rgb` is
        # deprecated in recent Matplotlib releases in favor of `buffer_rgba`)
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
        plt.close(fig)  # avoid leaking a figure on every logging call

        # Add colorbar to plot
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            "vis_logits": 2 * annotated_grid[None, ...] - 1,
            "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
        }
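
    # Adaptive generator weight, as in VQGAN: both gradients are taken at the
    # decoder's last layer, and
    #     d_weight = ||grad nll_loss|| / (||grad g_loss|| + 1e-4),
    # clamped to [0, 1e4], detached, and scaled by `disc_weight`.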
    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
    ) -> torch.Tensor:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight
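
    # Two-pass GAN training: optimizer_idx == 0 returns the autoencoder
    # (generator) loss, optimizer_idx == 1 the discriminator loss. The
    # adversarial term is only active once `global_step` reaches `disc_start`
    # (and always active in eval mode).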
    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,  # added because I changed the order here
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = "train",
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )
        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                if k in self.additional_log_keys:
                    log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()

            log.update(
                {
                    f"{split}/loss/total": loss.clone().detach().mean(),
                    f"{split}/loss/nll": nll_loss.detach().mean(),
                    f"{split}/loss/rec": rec_loss.detach().mean(),
                    f"{split}/loss/g": g_loss.detach().mean(),
                    f"{split}/scalars/logvar": self.logvar.detach(),
                    f"{split}/scalars/d_weight": d_weight.detach(),
                }
            )
            return loss, log
        elif optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f"{split}/loss/disc": d_loss.clone().detach().mean(),
                f"{split}/logits/real": logits_real.detach().mean(),
                f"{split}/logits/fake": logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
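
    # The reconstruction term is treated as a negative log-likelihood with a
    # single learned global scale: nll = rec_loss / exp(logvar) + logvar.
    # Summing and dividing by the batch size sums over pixels while averaging
    # over samples.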
    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        return nll_loss, weighted_nll_loss
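
if __name__ == "__main__":
    # Minimal smoke-test sketch (an illustration, not part of the module's
    # public API). It assumes this file is importable inside the `sgm`
    # package, since the relative imports and the default discriminator
    # target require it, and that the LPIPS/VGG weights can be downloaded on
    # first use. Shapes and hyperparameters below are made up.
    torch.manual_seed(0)

    # A stand-in "last layer" of a decoder, so the adaptive weight has a
    # parameter to differentiate against.
    decoder_head = nn.Conv2d(4, 3, kernel_size=3, padding=1)
    loss_fn = GeneralLPIPSWithDiscriminator(disc_start=0)

    inputs = torch.rand(2, 3, 64, 64) * 2 - 1  # images in [-1, 1]
    reconstructions = decoder_head(torch.randn(2, 4, 64, 64))

    # Pass 1 (optimizer_idx=0): NLL + perceptual + adaptively weighted g_loss.
    g_total, g_log = loss_fn(
        inputs,
        reconstructions,
        regularization_log={},
        optimizer_idx=0,
        global_step=1,
        last_layer=decoder_head.weight,
    )

    # Pass 2 (optimizer_idx=1): discriminator loss on detached tensors.
    d_total, d_log = loss_fn(
        inputs,
        reconstructions,
        regularization_log={},
        optimizer_idx=1,
        global_step=1,
        last_layer=decoder_head.weight,
    )
    print({k: float(v) for k, v in {**g_log, **d_log}.items()})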