File size: 2,122 Bytes
149cc2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

import jax
import jax.numpy as jnp
import flax.linen as nn


def get_sinusoidal_embeddings(
        timesteps: jax.Array,
        embedding_dim: int,
        freq_shift: float = 1,
        min_timescale: float = 1,
        max_timescale: float = 1.0e4,
        flip_sin_to_cos: bool = False,
        scale: float = 1.0,
        dtype: jnp.dtype = jnp.float32
) -> jax.Array:
    """Return fixed sinusoidal embeddings for a 1-D array of timesteps.

    With ``half = embedding_dim // 2`` and ``k = 0 .. half - 1``, the k-th
    frequency is::

        inv_k = min_timescale * exp(-k * log(max_timescale / min_timescale) / (half - freq_shift))

    Each timestep ``t`` is mapped to ``scale * t * inv_k``. The sines of those
    angles fill one half of the output row and the cosines fill the other half.

    Args:
        timesteps: 1-D array of timesteps, one embedding row per entry.
        embedding_dim: Width of each embedding row; must be even.
        freq_shift: Subtracted from ``half`` in the frequency spacing
            denominator. It must not equal ``embedding_dim // 2``.
        min_timescale: Multiplies every frequency. It also sets, together
            with ``max_timescale``, the log-range the frequencies span.
        max_timescale: Upper end of that log-range.
        flip_sin_to_cos: If True, the row is ``[cos | sin]``; otherwise it is
            ``[sin | cos]``.
        scale: Multiplier applied to every angle before sin/cos.
        dtype: dtype of the frequency index grid (``jnp.arange``). The
            returned array's dtype follows JAX type promotion of
            ``timesteps`` with that grid. It is not cast to ``dtype``.

    Returns:
        Array of shape ``(len(timesteps), embedding_dim)``.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D or ``embedding_dim`` is odd.
        ValueError: If ``embedding_dim // 2 == freq_shift``. Those values would
            make the frequency spacing divide by zero and yield NaNs.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
    num_timescales = float(embedding_dim // 2)

    # jnp.log returns an array, so a zero denominator would not raise here.
    # It would silently produce inf, and then 0 * -inf = NaN in the first column.
    denominator = num_timescales - freq_shift
    if denominator == 0:
        raise ValueError(
            f"freq_shift ({freq_shift}) must differ from embedding_dim // 2 "
            f"({int(num_timescales)}); otherwise the frequency spacing divides by zero"
        )
    log_timescale_increment = jnp.log(max_timescale / min_timescale) / denominator

    # NOTE(review): min_timescale multiplies (rather than divides) the frequencies;
    # this is irrelevant at the default of 1 — confirm intent before changing it.
    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype = dtype) * -log_timescale_increment)

    # Outer product: (N, 1) * (1, half) -> (N, half) angles.
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)

    # scale embeddings
    scaled_time = scale * emb

    # Concatenating two (N, half) blocks already yields (N, embedding_dim),
    # because embedding_dim is asserted even; no reshape is needed.
    if flip_sin_to_cos:
        return jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis = 1)
    return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis = 1)


class TimestepEmbedding(nn.Module):
    """Project a timestep embedding through a small MLP.

    The MLP is Dense -> SiLU -> Dense. Both Dense layers have
    ``time_embed_dim`` output features. Their submodule names are
    ``linear_1`` and ``linear_2``; those names determine the parameter-tree keys.
    """

    # Output width of both Dense layers.
    time_embed_dim: int = 32
    # Computation dtype handed to each nn.Dense.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, temb: jax.Array) -> jax.Array:
        first = nn.Dense(features=self.time_embed_dim, dtype=self.dtype, name="linear_1")
        second = nn.Dense(features=self.time_embed_dim, dtype=self.dtype, name="linear_2")
        hidden = nn.silu(first(temb))
        return second(hidden)


class Timesteps(nn.Module):
    """Flax module wrapping ``get_sinusoidal_embeddings``.

    The module has no parameters. It forwards its configuration fields to
    ``get_sinusoidal_embeddings`` and leaves every other argument of that
    function at its default.
    """

    # Embedding width (forwarded as ``embedding_dim``; must be even).
    dim: int = 32
    # Emit [cos | sin] instead of [sin | cos].
    flip_sin_to_cos: bool = False
    # Offset in the frequency-spacing denominator.
    freq_shift: float = 1
    # dtype forwarded to get_sinusoidal_embeddings.
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, timesteps: jax.Array) -> jax.Array:
        embedding = get_sinusoidal_embeddings(
            timesteps,
            self.dim,
            freq_shift=self.freq_shift,
            flip_sin_to_cos=self.flip_sin_to_cos,
            dtype=self.dtype,
        )
        return embedding