File size: 1,285 Bytes
c4c7cee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch
import torch.nn as nn
import numpy as np


class PositionalEmbedding(nn.Module):
    """
    Sinusoidal positional embedding of scalar inputs.

    Adapted from https://github.com/NVlabs/edm

    Each input value ``x[n]`` is mapped to
    ``[cos(x[n] * f_0), ..., cos(x[n] * f_{K-1}), sin(x[n] * f_0), ..., sin(x[n] * f_{K-1})]``
    where ``K = num_channels // 2`` and ``f_i = (1 / max_positions) ** e_i``.

    Args:
        num_channels: Requested embedding width. The output has
            ``2 * (num_channels // 2)`` columns, so an odd value yields one
            column fewer than requested.
        max_positions: Base of the geometric frequency progression; the
            frequencies decay from 1 towards ``1 / max_positions``.
        endpoint: If False, ``e_i = 2 * i / num_channels``, so the lowest
            frequency stays slightly above ``1 / max_positions``. If True,
            ``e_i = i / (K - 1)``, so the lowest frequency is exactly
            ``1 / max_positions``.
    """

    def __init__(self, num_channels, max_positions=10000, endpoint=False):
        super().__init__()
        self.num_channels = num_channels
        self.max_positions = max_positions
        self.endpoint = endpoint
        half = self.num_channels // 2
        exponents = torch.arange(start=0, end=half, dtype=torch.float32)
        if self.endpoint:
            # Spread the exponents over [0, 1] so the last frequency hits
            # 1 / max_positions exactly; max(..., 1) avoids 0/0 when K == 1.
            exponents = exponents / max(half - 1, 1)
        else:
            # Same expression as before, so existing configurations get
            # bit-identical frequencies.
            exponents = 2 * exponents / self.num_channels
        freqs = (1 / self.max_positions) ** exponents
        self.register_buffer("freqs", freqs)

    def forward(self, x):
        """Embed a 1-D tensor of positions.

        Args:
            x: 1-D tensor of length N (``torch.outer`` requires 1-D input).
                Integer and floating dtypes are both accepted.

        Returns:
            Tensor of shape ``(N, 2 * (num_channels // 2))``: cosines in the
            first half of the columns, sines in the second half. Floating
            inputs get their own dtype back; integer inputs yield the dtype
            that ``torch.outer`` promotes to (that of ``self.freqs``).
        """
        # Use a separate name so ``x`` still refers to the input when its
        # dtype is read for the final cast below.
        angles = torch.outer(x, self.freqs)
        out = torch.cat([angles.cos(), angles.sin()], dim=1)
        if x.is_floating_point():
            out = out.to(x.dtype)
        # Integer inputs are not cast back: that would truncate cos/sin.
        return out


# ----------------------------------------------------------------------------
# Timestep embedding used in the NCSN++ architecture.
class FourierEmbedding(nn.Module):
    """
    Random Fourier-feature embedding of scalar inputs.

    Adapted from https://github.com/NVlabs/edm

    Args:
        num_channels: Requested embedding width. The output has
            ``2 * (num_channels // 2)`` columns.
        scale: Multiplier on the standard-normal frequencies. They are drawn
            once at construction from torch's global RNG and stored as a
            buffer, not as a trainable parameter.
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)

    def forward(self, x):
        """Embed a 1-D tensor of scalars.

        Args:
            x: 1-D tensor of length N (``torch.outer`` requires 1-D input).

        Returns:
            Tensor of shape ``(N, 2 * (num_channels // 2))``: cosines of
            ``2*pi*x*freqs`` in the first half of the columns, sines in the
            second half.
        """
        if not x.is_floating_point():
            # Casting the frequencies to an integer dtype would truncate
            # them; promote the input to the buffer's dtype instead.
            x = x.to(self.freqs.dtype)
        # torch.outer replaces the deprecated Tensor.ger (an alias of outer).
        x = torch.outer(x, (2 * np.pi * self.freqs).to(x.dtype))
        x = torch.cat([x.cos(), x.sin()], dim=1)
        return x