Plonk / models /preconditioning.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
import torch
from torch import nn
# ----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative networks" (EDM).
class EDMPrecond(torch.nn.Module):
def __init__(
self,
network,
label_dim=0, # Number of class labels, 0 = unconditional.
sigma_min=0, # Minimum supported noise level.
sigma_max=float("inf"), # Maximum supported noise level.
sigma_data=0.5, # Expected standard deviation of the training data.
):
super().__init__()
self.label_dim = label_dim
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_data = sigma_data
self.network = network
def forward(self, x, sigma, conditioning=None, **network_kwargs):
x = x.to(torch.float32)
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
conditioning = (
None
if self.label_dim == 0
else torch.zeros([1, self.label_dim], device=x.device)
if conditioning is None
else conditioning.to(torch.float32)
)
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
c_noise = sigma.log() / 4
F_x = self.network(
(c_in * x),
c_noise.flatten(),
conditioning=conditioning,
**network_kwargs,
)
D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x
def round_sigma(self, sigma):
return torch.as_tensor(sigma)
class DDPMPrecond(nn.Module):
def __init__(self):
super().__init__()
def forward(self, network, batch):
F_x = network(batch)
return F_x