Spaces:
Running
Running
import torch | |
from torch import nn | |
# ---------------------------------------------------------------------------- | |
# Improved preconditioning proposed in the paper "Elucidating the Design | |
# Space of Diffusion-Based Generative networks" (EDM). | |
class EDMPrecond(torch.nn.Module): | |
def __init__( | |
self, | |
network, | |
label_dim=0, # Number of class labels, 0 = unconditional. | |
sigma_min=0, # Minimum supported noise level. | |
sigma_max=float("inf"), # Maximum supported noise level. | |
sigma_data=0.5, # Expected standard deviation of the training data. | |
): | |
super().__init__() | |
self.label_dim = label_dim | |
self.sigma_min = sigma_min | |
self.sigma_max = sigma_max | |
self.sigma_data = sigma_data | |
self.network = network | |
def forward(self, x, sigma, conditioning=None, **network_kwargs): | |
x = x.to(torch.float32) | |
sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1) | |
conditioning = ( | |
None | |
if self.label_dim == 0 | |
else torch.zeros([1, self.label_dim], device=x.device) | |
if conditioning is None | |
else conditioning.to(torch.float32) | |
) | |
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) | |
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() | |
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() | |
c_noise = sigma.log() / 4 | |
F_x = self.network( | |
(c_in * x), | |
c_noise.flatten(), | |
conditioning=conditioning, | |
**network_kwargs, | |
) | |
D_x = c_skip * x + c_out * F_x.to(torch.float32) | |
return D_x | |
def round_sigma(self, sigma): | |
return torch.as_tensor(sigma) | |
class DDPMPrecond(nn.Module): | |
def __init__(self): | |
super().__init__() | |
def forward(self, network, batch): | |
F_x = network(batch) | |
return F_x | |