Spaces:
Runtime error
Runtime error
import jax | |
import jax.numpy as jnp | |
import flax.linen as nn | |
def get_sinusoidal_embeddings(
    timesteps: jax.Array,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
    dtype: jnp.dtype = jnp.float32,
) -> jax.Array:
    """Return sinusoidal embeddings for a 1-D array of timesteps.

    ``embedding_dim // 2`` inverse timescales are built as a geometric
    sequence starting at ``min_timescale``. Each timestep is multiplied by
    every inverse timescale and by ``scale``. The sine and cosine of those
    products are then concatenated along axis 1.

    Args:
        timesteps: 1-D array of timesteps; one output row per element.
        embedding_dim: Width of each embedding row; must be even.
        freq_shift: Subtracted from ``embedding_dim // 2`` to form the
            divisor of the log-timescale range.
        min_timescale: Lower end of the timescale ratio used for the log
            spacing; also multiplies every inverse timescale.
        max_timescale: Upper end of the timescale ratio used for the log
            spacing.
        flip_sin_to_cos: If True, the cosine half comes first; otherwise the
            sine half comes first.
        scale: Multiplier applied to ``timestep * inverse_timescale`` before
            taking sine/cosine.
        dtype: dtype of the frequency-index range used to build the inverse
            timescales. It is not applied to the output directly; the output
            dtype follows JAX type promotion with ``timesteps``.

    Returns:
        Array of shape ``(timesteps.shape[0], embedding_dim)``.

    Raises:
        AssertionError: If ``timesteps`` is not 1-D or ``embedding_dim`` is odd.
    """
    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"

    num_timescales = float(embedding_dim // 2)

    # The log spacing divides by (num_timescales - freq_shift). When that is
    # zero (e.g. embedding_dim == 2 with the default freq_shift == 1), the
    # division yields inf. Frequency index 0 would then evaluate
    # 0 * -inf = NaN and produce NaN embeddings. Use a zero increment instead,
    # so every inverse timescale equals min_timescale.
    # NOTE(review): assumes freq_shift is a static Python number (as typed),
    # not a traced value.
    denominator = num_timescales - freq_shift
    if denominator == 0:
        log_timescale_increment = 0.0
    else:
        log_timescale_increment = jnp.log(max_timescale / min_timescale) / denominator

    inv_timescales = min_timescale * jnp.exp(
        jnp.arange(num_timescales, dtype=dtype) * -log_timescale_increment
    )
    # Outer product: (N, 1) * (1, num_timescales) -> (N, num_timescales).
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
    scaled_time = scale * emb

    if flip_sin_to_cos:
        halves = (jnp.cos(scaled_time), jnp.sin(scaled_time))
    else:
        halves = (jnp.sin(scaled_time), jnp.cos(scaled_time))
    # Two (N, embedding_dim // 2) halves concatenate to (N, embedding_dim),
    # so no trailing reshape is needed.
    return jnp.concatenate(halves, axis=1)
class TimestepEmbedding(nn.Module):
    """Two-layer MLP applied to a timestep embedding: Dense -> SiLU -> Dense.

    Attributes:
        time_embed_dim: Output feature count of both Dense layers.
        dtype: Forwarded as ``dtype`` to both Dense layers.
    """

    # Output feature count of both Dense layers.
    time_embed_dim: int = 32
    # Forwarded as `dtype` to both Dense layers.
    dtype: jnp.dtype = jnp.float32

    # The Dense submodules are declared inline below. Flax linen only allows
    # that inside an `@nn.compact` method; without the decorator, constructing
    # the first nn.Dense raises a Flax error.
    @nn.compact
    def __call__(self, temb: jax.Array) -> jax.Array:
        """Project ``temb`` through linear_1 -> SiLU -> linear_2.

        Args:
            temb: Input embedding; the Dense layers act on its last axis.

        Returns:
            Array whose last axis has size ``time_embed_dim``.
        """
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_1")(temb)
        temb = nn.silu(temb)
        temb = nn.Dense(self.time_embed_dim, dtype=self.dtype, name="linear_2")(temb)
        return temb
class Timesteps(nn.Module):
    """Flax wrapper that embeds a 1-D array of timesteps sinusoidally.

    The module holds no parameters. It forwards its configuration to
    ``get_sinusoidal_embeddings``.

    Attributes:
        dim: Passed as ``embedding_dim``.
        flip_sin_to_cos: Passed as ``flip_sin_to_cos``.
        freq_shift: Passed as ``freq_shift``.
        dtype: Passed as ``dtype``.
    """

    # Embedding width, forwarded as `embedding_dim`.
    dim: int = 32
    # Forwarded as `flip_sin_to_cos` (cosine half first when True).
    flip_sin_to_cos: bool = False
    # Forwarded as `freq_shift`.
    freq_shift: float = 1
    # Forwarded as `dtype`.
    dtype: jnp.dtype = jnp.float32

    def __call__(self, timesteps: jax.Array) -> jax.Array:
        """Return the sinusoidal embedding of ``timesteps`` for this config."""
        settings = dict(
            embedding_dim=self.dim,
            freq_shift=self.freq_shift,
            flip_sin_to_cos=self.flip_sin_to_cos,
            dtype=self.dtype,
        )
        return get_sinusoidal_embeddings(timesteps, **settings)