Spaces:
Runtime error
Runtime error
from typing import Optional, Union, Sequence | |
import jax | |
import jax.numpy as jnp | |
import flax.linen as nn | |
import einops | |
class ConvPseudo3D(nn.Module):
    """Factorised ("pseudo-3D") convolution: a spatial ``nn.Conv`` followed by
    a width-3 1-D ``nn.Conv`` along the frame axis.

    5-D inputs are treated as video laid out ``(b, f, h, w, c)``. Any other
    rank goes only through the spatial convolution.
    """

    # Output channels of both the spatial and the temporal convolution.
    features: int
    # Spatial kernel size, forwarded unchanged to the spatial ``nn.Conv``.
    kernel_size: Sequence[int]
    # Spatial strides, forwarded unchanged to the spatial ``nn.Conv``.
    strides: Union[None, int, Sequence[int]] = 1
    # Spatial padding, forwarded unchanged to the spatial ``nn.Conv``.
    padding: nn.linear.PaddingLike = 'SAME'
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.spatial_conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            dtype=self.dtype,
        )
        # Kernel of width 3 over frames with 'SAME' padding and a zero bias.
        # TODO: dirac delta (identity) kernel initialisation, i.e. a JAX/lax
        # counterpart of torch.nn.init.dirac_.
        self.temporal_conv = nn.Conv(
            features=self.features,
            kernel_size=(3,),
            padding='SAME',
            dtype=self.dtype,
            bias_init=nn.initializers.zeros_init(),
        )

    def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
        """Apply the spatial conv per frame and, optionally, the temporal conv.

        Args:
            x: video batch ``(b, f, h, w, c)`` when 5-D; otherwise passed
                directly to the spatial convolution.
            convolve_across_time: if False, skip the temporal convolution.
                Ignored (treated as False) for non-video input.

        Returns:
            The convolved array. Video input keeps the ``(b, f, h, w, c)``
            layout; ``h``/``w``/``c`` are those produced by the spatial conv.
        """
        if x.ndim != 5:
            # Not a video: there is no frame axis to convolve over.
            return self.spatial_conv(x)

        batch = x.shape[0]
        # Fold frames into the batch so the 2-D conv sees independent images.
        per_frame = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        spatial = einops.rearrange(
            self.spatial_conv(per_frame), '(b f) h w c -> b f h w c', b=batch
        )
        if not convolve_across_time:
            return spatial

        # Read h/w after the spatial conv, since its strides may change them.
        _, _, height, width, _ = spatial.shape
        # Fold pixels into the batch so the 1-D conv runs along frames only.
        per_pixel = einops.rearrange(spatial, 'b f h w c -> (b h w) f c')
        temporal = self.temporal_conv(per_pixel)
        return einops.rearrange(
            temporal, '(b h w) f c -> b f h w c', h=height, w=width
        )
class UpsamplePseudo3D(nn.Module):
    """Double height and width by nearest-neighbour resizing, then apply a
    3x3, stride-1, padding-1 :class:`ConvPseudo3D`.

    Accepts image batches ``(b, h, w, c)`` or video batches ``(b, f, h, w, c)``.
    """

    # Channel count produced by the convolution.
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.conv = ConvPseudo3D(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Upsample ``hidden_states`` 2x spatially and convolve the result."""
        video_batch = None
        if hidden_states.ndim == 5:
            # Fold frames into the batch so the resize sees a 4-D image stack.
            video_batch = hidden_states.shape[0]
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')

        n, height, width, channels = hidden_states.shape
        upsampled = jax.image.resize(
            image=hidden_states,
            shape=(n, 2 * height, 2 * width, channels),
            method='nearest',
        )

        if video_batch is not None:
            upsampled = einops.rearrange(
                upsampled, '(b f) h w c -> b f h w c', b=video_batch
            )
        return self.conv(upsampled)
class DownsamplePseudo3D(nn.Module):
    """Spatially downsample with a 3x3, stride-2, padding-1 :class:`ConvPseudo3D`.

    With that kernel/stride/padding, each spatial size ``s`` becomes
    ``ceil(s / 2)``. Image ``(b, h, w, c)`` and video ``(b, f, h, w, c)``
    inputs are both accepted; the resolution change happens entirely in the
    convolution.
    """

    # Channel count produced by the convolution.
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.conv = ConvPseudo3D(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Return the strided convolution of ``hidden_states``."""
        return self.conv(hidden_states)
class ResnetBlockPseudo3D(nn.Module):
    """Residual block built from :class:`ConvPseudo3D` layers.

    ``norm1 -> SiLU -> conv1 -> + projected time embedding -> norm2 -> SiLU
    -> conv2``, then added to the input (projected by a 1x1 conv when a
    shortcut projection is used). Accepts image batches ``(b, h, w, c)`` or
    video batches ``(b, f, h, w, c)``.
    """

    # Channel count of the incoming feature map.
    in_channels: int
    # Output channel count; ``None`` means "same as in_channels".
    out_channels: Optional[int] = None
    # Whether the residual goes through a 1x1 projection. ``None`` enables it
    # exactly when the input and output channel counts differ.
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        out_channels = self.in_channels if self.out_channels is None else self.out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv1 = ConvPseudo3D(
            features=out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )
        self.time_emb_proj = nn.Dense(out_channels, dtype=self.dtype)
        self.norm2 = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv2 = ConvPseudo3D(
            features=out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

        use_nin_shortcut = (
            self.in_channels != out_channels
            if self.use_nin_shortcut is None
            else self.use_nin_shortcut
        )
        if use_nin_shortcut:
            # Use the resolved ``out_channels``: ``self.out_channels`` is None
            # when the shortcut is forced on without an explicit output width.
            self.conv_shortcut = ConvPseudo3D(
                features=out_channels,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding='VALID',
                dtype=self.dtype,
            )
        else:
            self.conv_shortcut = None

    def __call__(self, hidden_states: jax.Array, temb: jax.Array) -> jax.Array:
        """Apply the residual block.

        Args:
            hidden_states: image ``(b, h, w, c)`` or video ``(b, f, h, w, c)``
                feature map with ``in_channels`` channels.
            temb: time embedding with the batch on axis 0 and features on the
                last axis; presumably ``(b, d)`` — TODO confirm with callers.
                It is projected to ``out_channels`` and added at every spatial
                position (and every frame for video input).

        Returns:
            Feature map with the input's layout and ``out_channels`` channels.
        """
        residual = hidden_states

        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.silu(temb))
        # Insert singleton axes after the batch axis: (b, c) -> (b, 1, 1, c)
        # for images, (b, 1, 1, 1, c) for videos. One embedding per sample is
        # then broadcast over all frames and positions, with no reshaping of
        # hidden_states.
        temb = jnp.expand_dims(temb, tuple(range(1, hidden_states.ndim - 1)))
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return hidden_states + residual