from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


class GeneralLPIPSWithDiscriminator(nn.Module):
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                f"calculation, the LPIPS loss will be applied to each frame "
                f"independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

        # output log variance
        self.logvar = nn.Parameter(
            torch.full((), logvar_init), requires_grad=learn_logvar
        )
        self.learn_logvar = learn_logvar

        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(discriminator_config).apply(
            weights_init
        )
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

        self.forward_keys = [
            "optimizer_idx",
            "global_step",
            "last_layer",
            "split",
            "regularization_log",
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        yield from ()

    @torch.no_grad()
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        # calc logits of real/fake
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # non-patch discriminator: nothing spatial to visualize
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h, w)

        # parameters for colormapping
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps["PiYG"]  # diverging colormap

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # truncate alpha channel
            # -> (b, 1, ..., 3)
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, "b 1 ... c -> b c ...")
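
        # `to_colormap` rescales raw logits from [-high, high] to [0, 1] so a
        # logit of 0 lands on the neutral center of the diverging "PiYG" map.
        # The nearest-neighbor upsampling below blows each patch logit up to
        # input resolution, keeping each patch's verdict visible as a block.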
        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode="nearest",
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode="nearest",
            antialias=False,
        )

        # alpha value of logits for overlay
        alpha_real = torch.abs(logits_real) / (2 * high)
        alpha_fake = torch.abs(logits_fake) / (2 * high)
        # -> (b, 1, h, w) in range [0, 0.5]
        # the alpha value of the grid lines doesn't really matter, since the
        # values are the same for both images and logits anyway
        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)
        # -> (1, h, w)

        # blend logits and images together

        # prepare logits for plotting: apply the colormap
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # -> (b, 3, h, w)

        # make some grids
        # add all logits to one plot
        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        # I just love how torchvision calls the number of columns `nrow`
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, h, w)

        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(
            0.5 * reconstructions + 0.5, nrow=4
        )
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)
        # -> (3, h, w) in range [0, 1]

        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

        # Create labeled colorbar
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()
        # manually convert figure to numpy
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)

        # Add colorbar to plot
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            "vis_logits": 2 * annotated_grid[None, ...] - 1,
            "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
        }
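
    # The adaptive weight below is the usual VQGAN-style balancing term: the
    # GAN loss is rescaled so that its gradient norm at the decoder's last
    # layer matches that of the reconstruction/NLL loss,
    #     d_weight = ||grad_L(nll)|| / (||grad_L(g)|| + 1e-4),
    # then clamped to [0, 1e4] and scaled by `discriminator_weight`.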
    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
    ) -> torch.Tensor:
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,  # added because I changed the order here
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = "train",
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = loss + self.regularization_weights[k] * regularization_log[k]
                if k in self.additional_log_keys:
                    log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()

            log.update(
                {
                    f"{split}/loss/total": loss.clone().detach().mean(),
                    f"{split}/loss/nll": nll_loss.detach().mean(),
                    f"{split}/loss/rec": rec_loss.detach().mean(),
                    f"{split}/loss/g": g_loss.detach().mean(),
                    f"{split}/scalars/logvar": self.logvar.detach(),
                    f"{split}/scalars/d_weight": d_weight.detach(),
                }
            )

            return loss, log
        elif optimizer_idx == 1:
            # second pass for discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f"{split}/loss/disc": d_loss.clone().detach().mean(),
                f"{split}/logits/real": logits_real.detach().mean(),
                f"{split}/logits/fake": logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")
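
    # `get_nll_loss` treats the reconstruction error as the negative log
    # likelihood of a Gaussian with a fixed or learnable log variance:
    #     nll = rec_loss / exp(logvar) + logvar,
    # so with `learn_logvar` the model can trade reconstruction sharpness
    # against uncertainty; both returned losses are summed over all elements
    # and divided by the batch size.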
    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        return nll_loss, weighted_nll_loss
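

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the training code):
# how a trainer would typically drive the two optimizer passes. The tensor
# shapes, the toy one-conv "decoder", and `disc_start=0` are assumptions made
# for this example; running it requires the surrounding `sgm` package and the
# pretrained LPIPS/VGG weights to be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    loss_fn = GeneralLPIPSWithDiscriminator(disc_start=0)

    # toy decoder stand-in: a single conv weight, so `last_layer` has a real
    # gradient path into the reconstructions (calculate_adaptive_weight
    # differentiates both losses w.r.t. it)
    last_layer = nn.Parameter(torch.randn(3, 3, 3, 3) * 0.1)
    inputs = torch.rand(2, 3, 64, 64) * 2 - 1  # images in [-1, 1]
    reconstructions = torch.tanh(
        torch.nn.functional.conv2d(inputs, last_layer, padding=1)
    )

    # pass 0: autoencoder/generator loss (NLL + perceptual + adaptive GAN term)
    g_total, g_log = loss_fn(
        inputs,
        reconstructions,
        regularization_log={},
        optimizer_idx=0,
        global_step=0,
        last_layer=last_layer,
        split="train",
    )

    # pass 1: discriminator loss (hinge by default); forward detaches the
    # reconstructions internally, so the same tensors can be reused here
    d_total, d_log = loss_fn(
        inputs,
        reconstructions,
        regularization_log={},
        optimizer_idx=1,
        global_step=0,
        last_layer=last_layer,
        split="train",
    )
    print({k: float(v) for k, v in {**g_log, **d_log}.items()})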