from typing import Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchvision
from einops import rearrange
from matplotlib import colormaps
from matplotlib import pyplot as plt

from ....util import default, instantiate_from_config
from ..lpips.loss.lpips import LPIPS
from ..lpips.model.model import weights_init
from ..lpips.vqperceptual import hinge_d_loss, vanilla_d_loss


class GeneralLPIPSWithDiscriminator(nn.Module):
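    """L1 + LPIPS reconstruction loss combined with an adversarial (patch)
    discriminator term and optional weighted regularization losses."""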
    def __init__(
        self,
        disc_start: int,
        logvar_init: float = 0.0,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        perceptual_weight: float = 1.0,
        disc_loss: str = "hinge",
        scale_input_to_tgt_size: bool = False,
        dims: int = 2,
        learn_logvar: bool = False,
        regularization_weights: Union[None, Dict[str, float]] = None,
        additional_log_keys: Optional[List[str]] = None,
        discriminator_config: Optional[Dict] = None,
    ):
        super().__init__()
        self.dims = dims
        if self.dims > 2:
            print(
                f"running with dims={dims}. This means that for perceptual loss "
                f"calculation, the LPIPS loss will be applied to each frame "
                f"independently."
            )
        self.scale_input_to_tgt_size = scale_input_to_tgt_size
        assert disc_loss in ["hinge", "vanilla"]
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight

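        # Scalar log-variance of the reconstruction likelihood (see
        # get_nll_loss); a parameter so it can optionally be learned.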
        self.logvar = nn.Parameter(
            torch.full((), logvar_init), requires_grad=learn_logvar
        )
        self.learn_logvar = learn_logvar

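        # Fall back to the PatchGAN-style NLayerDiscriminator when no
        # discriminator config is supplied.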
        discriminator_config = default(
            discriminator_config,
            {
                "target": "sgm.modules.autoencoding.lpips.model.model.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(discriminator_config).apply(
            weights_init
        )
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.regularization_weights = default(regularization_weights, {})

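        # Names of the extra keyword arguments forward() expects besides
        # inputs and reconstructions.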
        self.forward_keys = [
            "optimizer_idx",
            "global_step",
            "last_layer",
            "split",
            "regularization_log",
        ]

        self.additional_log_keys = set(default(additional_log_keys, []))
        self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def get_trainable_autoencoder_parameters(self) -> Iterator[nn.Parameter]:
        if self.learn_logvar:
            yield self.logvar
        # Keep this a generator even when logvar is not learned.
        yield from ()

    @torch.no_grad()
    def log_images(
        self, inputs: torch.Tensor, reconstructions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
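        """Visualize the discriminator's patch logits as colormapped overlays
        on the input and reconstructed images, with a colorbar legend."""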
        logits_real = self.discriminator(inputs.contiguous().detach())
        if len(logits_real.shape) < 4:
            # Not a patch discriminator: no spatial logits to visualize.
            return dict()
        logits_fake = self.discriminator(reconstructions.contiguous().detach())
        # -> (b, 1, h', w')

        # Symmetric value range so the diverging colormap is centered at zero.
        high = max(logits_fake.abs().max(), logits_real.abs().max()).item()
        cmap = colormaps["PiYG"]

        def to_colormap(logits: torch.Tensor) -> torch.Tensor:
            """(b, 1, ...) -> (b, 3, ...)"""
            logits = (logits + high) / (2 * high)  # normalize to [0, 1]
            logits_np = cmap(logits.cpu().numpy())[..., :3]  # drop alpha channel
            # -> (b, 1, ..., 3); move back to the original device
            logits = torch.from_numpy(logits_np).to(logits.device)
            return rearrange(logits, "b 1 ... c -> b c ...")

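        # Upsample the coarse patch logits to the image resolution.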
        logits_real = torch.nn.functional.interpolate(
            logits_real,
            size=inputs.shape[-2:],
            mode="nearest",
            antialias=False,
        )
        logits_fake = torch.nn.functional.interpolate(
            logits_fake,
            size=reconstructions.shape[-2:],
            mode="nearest",
            antialias=False,
        )

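        # Use the logit magnitude as the overlay opacity.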
        alpha_real = torch.abs(logits_real) / high
        alpha_fake = torch.abs(logits_fake) / high
        # -> (b, 1, h, w) in [0, 1]

        grid_alpha_real = torchvision.utils.make_grid(alpha_real, nrow=4)
        grid_alpha_fake = torchvision.utils.make_grid(alpha_fake, nrow=4)
        grid_alpha = 0.8 * torch.cat((grid_alpha_real, grid_alpha_fake), dim=1)

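        # Color-map the logits for display.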
        logits_real = to_colormap(logits_real)
        logits_fake = to_colormap(logits_fake)
        # -> (b, 3, h, w)

        logits_real = torchvision.utils.make_grid(logits_real, nrow=4)
        logits_fake = torchvision.utils.make_grid(logits_fake, nrow=4)
        grid_logits = torch.cat((logits_real, logits_fake), dim=1)
        # -> (3, H, W) with real on top, fake below

        # Map images from [-1, 1] to [0, 1] and lay them out the same way.
        grid_images_real = torchvision.utils.make_grid(0.5 * inputs + 0.5, nrow=4)
        grid_images_fake = torchvision.utils.make_grid(
            0.5 * reconstructions + 0.5, nrow=4
        )
        grid_images = torch.cat((grid_images_real, grid_images_fake), dim=1)

        # Blend logits over images using the alpha mask.
        grid_blend = grid_alpha * grid_logits + (1 - grid_alpha) * grid_images

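        # Render a horizontal colorbar legend for the logit colormap with
        # matplotlib, then convert it to a tensor.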
        dpi = 100
        height = 128 / dpi
        width = grid_logits.shape[2] / dpi
        fig, ax = plt.subplots(figsize=(width, height), dpi=dpi)
        img = ax.imshow(np.array([[-high, high]]), cmap=cmap)
        plt.colorbar(
            img,
            cax=ax,
            orientation="horizontal",
            fraction=0.9,
            aspect=width / height,
            pad=0.0,
        )
        img.set_visible(False)
        fig.tight_layout()
        fig.canvas.draw()

        # NOTE: tostring_rgb() was deprecated in matplotlib 3.8 and removed in
        # 3.10; on newer versions use fig.canvas.buffer_rgba() instead.
        cbar_np = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        cbar_np = cbar_np.reshape(fig.canvas.get_width_height()[::-1] + (3,))
        cbar = torch.from_numpy(cbar_np.copy()).to(grid_logits.dtype) / 255.0
        cbar = rearrange(cbar, "h w c -> c h w").to(grid_logits.device)
        plt.close(fig)  # avoid accumulating open figures across log calls

        # Stack the colorbar below the grids and rescale outputs to [-1, 1].
        annotated_grid = torch.cat((grid_logits, cbar), dim=1)
        blended_grid = torch.cat((grid_blend, cbar), dim=1)
        return {
            "vis_logits": 2 * annotated_grid[None, ...] - 1,
            "vis_logits_blended": 2 * blended_grid[None, ...] - 1,
        }

    def calculate_adaptive_weight(
        self, nll_loss: torch.Tensor, g_loss: torch.Tensor, last_layer: torch.Tensor
    ) -> torch.Tensor:
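        """Balance the adversarial loss against the reconstruction loss by
        matching gradient norms at the decoder's last layer (VQGAN-style)."""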
        nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(
        self,
        inputs: torch.Tensor,
        reconstructions: torch.Tensor,
        *,
        regularization_log: Dict[str, torch.Tensor],
        optimizer_idx: int,
        global_step: int,
        last_layer: torch.Tensor,
        split: str = "train",
        weights: Union[None, float, torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
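        """Compute the generator loss (optimizer_idx == 0) or the
        discriminator loss (optimizer_idx == 1), returning (loss, log_dict)."""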
        if self.scale_input_to_tgt_size:
            inputs = torch.nn.functional.interpolate(
                inputs, reconstructions.shape[2:], mode="bicubic", antialias=True
            )

        if self.dims > 2:
            # Fold the temporal axis into the batch so losses act per frame.
            inputs, reconstructions = map(
                lambda x: rearrange(x, "b c t h w -> (b t) c h w"),
                (inputs, reconstructions),
            )

        rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(
                inputs.contiguous(), reconstructions.contiguous()
            )
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss, weighted_nll_loss = self.get_nll_loss(rec_loss, weights)

        if optimizer_idx == 0:
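            # Generator update: reconstruction NLL plus adaptively weighted
            # adversarial term (active once the discriminator has started).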
            if global_step >= self.discriminator_iter_start or not self.training:
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = self.calculate_adaptive_weight(
                        nll_loss, g_loss, last_layer=last_layer
                    )
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)

            loss = weighted_nll_loss + d_weight * self.disc_factor * g_loss
            log = dict()
            for k in regularization_log:
                if k in self.regularization_weights:
                    loss = (
                        loss + self.regularization_weights[k] * regularization_log[k]
                    )
                if k in self.additional_log_keys:
                    log[f"{split}/{k}"] = regularization_log[k].detach().float().mean()

            log.update(
                {
                    f"{split}/loss/total": loss.clone().detach().mean(),
                    f"{split}/loss/nll": nll_loss.detach().mean(),
                    f"{split}/loss/rec": rec_loss.detach().mean(),
                    f"{split}/loss/g": g_loss.detach().mean(),
                    f"{split}/scalars/logvar": self.logvar.detach(),
                    f"{split}/scalars/d_weight": d_weight.detach(),
                }
            )

            return loss, log
        elif optimizer_idx == 1:
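            # Discriminator update: score detached real and reconstructed
            # images; the loss only becomes active after disc_start steps.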
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(reconstructions.contiguous().detach())

            if global_step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            log = {
                f"{split}/loss/disc": d_loss.clone().detach().mean(),
                f"{split}/logits/real": logits_real.detach().mean(),
                f"{split}/logits/fake": logits_fake.detach().mean(),
            }
            return d_loss, log
        else:
            raise NotImplementedError(f"Unknown optimizer_idx {optimizer_idx}")

    def get_nll_loss(
        self,
        rec_loss: torch.Tensor,
        weights: Optional[Union[float, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
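        """Gaussian NLL with (optionally learned) logvar: rec_loss is scaled
        by exp(-logvar) with a +logvar penalty, summed per sample and averaged
        over the batch; returns (nll_loss, weighted_nll_loss)."""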
        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss
        if weights is not None:
            weighted_nll_loss = weights * nll_loss
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]

        return nll_loss, weighted_nll_loss