File size: 1,167 Bytes
f7a5cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn


# ----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).


class RnEDMPrecond(nn.Module):
    """EDM preconditioning wrapper around a raw denoiser network.

    The forward pass returns the preconditioned denoiser

        D(x; sigma) = c_skip(sigma) * x
                      + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))

    where ``F`` is the wrapped ``module``. The scaling factors are:

        c_skip  = sigma_data^2 / (sigma^2 + sigma_data^2)
        c_out   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
        c_in    = 1 / sqrt(sigma^2 + sigma_data^2)
        c_noise = log(sigma) / 4
    """

    def __init__(self, sigma_data: float = 0.5, module: nn.Module = None, **kwargs):
        """Wrap ``module`` with EDM preconditioning.

        Args:
            sigma_data: Standard deviation of the data distribution used in
                the scaling factors.
            module: The network to wrap. Required despite the ``None``
                default (kept for call-site compatibility). It must expose
                ``num_rawfeats``, ``num_feats`` and ``num_cams`` attributes,
                which are copied onto this wrapper.
            **kwargs: Accepted and ignored. Presumably this lets the wrapper
                be built from a shared config dict; TODO confirm.

        Raises:
            ValueError: If ``module`` is ``None``.
        """
        super().__init__()
        if module is None:
            raise ValueError("RnEDMPrecond requires a `module` to wrap")
        self.sigma_data = sigma_data

        self.model = module
        # Copy the wrapped network's feature metadata onto the wrapper,
        # presumably so callers can query it without reaching into `.model`.
        self.num_rawfeats = module.num_rawfeats
        self.num_feats = module.num_feats
        self.num_cams = module.num_cams

    def forward(self, x, sigma, y=None, mask=None):
        """Apply the preconditioned denoiser.

        Args:
            x: Noisy input x_t with the batch on the first axis, e.g.
                ``[batch_size, num_feats, max_frames]``. Any rank is accepted.
            sigma: Continuous noise level(s), shape ``[batch_size]`` (or a
                single value shared by the batch). Must be strictly positive,
                since ``log(sigma)`` is taken.
            y: Optional conditioning, forwarded unchanged to the wrapped model.
            mask: Optional mask, forwarded unchanged to the wrapped model.

        Returns:
            ``c_skip * x + c_out * F_x``. This has the shape of ``x`` provided
            the wrapped model returns a tensor shaped like its input.
        """
        # Reshape sigma to [batch, 1, ..., 1] so it broadcasts over every
        # non-batch axis of x, whatever x's rank (previously fixed to 3-D).
        sigma = sigma.reshape(-1, *([1] * (x.ndim - 1)))
        # sigma^2 + sigma_data^2 is shared by c_skip, c_out and c_in.
        var_total = sigma**2 + self.sigma_data**2
        c_skip = self.sigma_data**2 / var_total
        c_out = sigma * self.sigma_data / var_total.sqrt()
        c_in = 1 / var_total.sqrt()
        c_noise = sigma.log() / 4

        # The wrapped model receives the noise embedding as a flat [batch] tensor.
        F_x = self.model(c_in * x, c_noise.flatten(), y=y, mask=mask)
        D_x = c_skip * x + c_out * F_x

        return D_x