# Timestep-embedding modules (positional / Fourier), adapted from https://github.com/NVlabs/edm
import torch
import torch.nn as nn
import numpy as np
class PositionalEmbedding(nn.Module):
"""
Taken from https://github.com/NVlabs/edm
"""
def __init__(self, num_channels, max_positions=10000, endpoint=False):
super().__init__()
self.num_channels = num_channels
self.max_positions = max_positions
self.endpoint = endpoint
freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32)
freqs = 2 * freqs / self.num_channels
freqs = (1 / self.max_positions) ** freqs
self.register_buffer("freqs", freqs)
def forward(self, x):
x = torch.outer(x, self.freqs)
out = torch.cat([x.cos(), x.sin()], dim=1)
return out.to(x.dtype)
# ----------------------------------------------------------------------------
# Timestep embedding used in the NCSN++ architecture.
class FourierEmbedding(nn.Module):
    """
    Random Fourier-feature timestep embedding (NCSN++ style).

    Taken from https://github.com/NVlabs/edm

    Projects a 1-D input ``x`` of shape ``(N,)`` onto ``num_channels // 2``
    fixed Gaussian random frequencies and returns the concatenated
    cosine/sine features of shape ``(N, num_channels)``.

    Args:
        num_channels: Output embedding width.
        scale: Standard deviation of the sampled frequencies.
    """

    def __init__(self, num_channels, scale=16):
        super().__init__()
        # Frequencies are sampled once at construction and frozen as a buffer.
        self.register_buffer("freqs", torch.randn(num_channels // 2) * scale)

    def forward(self, x):
        # Outer product of positions with the angular frequencies (2*pi*f),
        # cast to the input dtype so the features match it.
        angles = torch.outer(x, (2 * np.pi * self.freqs).to(x.dtype))
        return torch.cat((angles.cos(), angles.sin()), dim=1)