# File: makeavid-sd-jax/makeavid_sd/flax_impl/flax_resnet_pseudo3d.py
# Last commit by lopho (b2f876f): "forgot about the nested package structure"
# Size: 6.02 kB
from typing import Optional, Union, Sequence
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
def _dirac_kernel_init(key: jax.Array, shape: Sequence[int], dtype: jnp.dtype = jnp.float32) -> jax.Array:
    """Dirac-delta (identity) kernel initializer for Flax convolutions.

    Flax conv kernels are laid out as ``(*spatial, in_features, out_features)``.
    The returned kernel is zero everywhere except at the central spatial tap,
    which holds an ``in_features x out_features`` identity matrix. This is the
    counterpart of ``torch.nn.init.dirac_`` for that layout.

    Args:
        key: PRNG key; unused, accepted to match the initializer signature.
        shape: Kernel shape ``(*spatial, in_features, out_features)``.
        dtype: dtype of the returned kernel.

    Returns:
        The identity-initialised kernel.
    """
    *spatial, in_features, out_features = shape
    center = tuple(size // 2 for size in spatial)
    identity = jnp.eye(in_features, out_features, dtype = dtype)
    return jnp.zeros(shape, dtype).at[center].set(identity)


class ConvPseudo3D(nn.Module):
    """Factorised "pseudo 3D" convolution: a spatial conv followed by a 1D temporal conv.

    The spatial convolution is applied to every frame independently. For
    5-dimensional input ``(b, f, h, w, c)`` a kernel-size-3 convolution is then
    applied along the frame axis, independently at every spatial position.
    Input with any other rank is passed through the spatial convolution only.

    Attributes:
        features: Number of output channels of both convolutions.
        kernel_size: Kernel size of the spatial convolution.
        strides: Strides of the spatial convolution.
        padding: Padding of the spatial convolution.
        dtype: Computation dtype passed to both convolutions.
    """
    features: int
    kernel_size: Sequence[int]
    strides: Union[None, int, Sequence[int]] = 1
    padding: nn.linear.PaddingLike = 'SAME'
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.spatial_conv = nn.Conv(
            features = self.features,
            kernel_size = self.kernel_size,
            strides = self.strides,
            padding = self.padding,
            dtype = self.dtype
        )
        # Identity kernel + zero bias + 'SAME' padding: at initialisation the
        # temporal convolution returns its input unchanged (its input already
        # has `features` channels, coming out of `spatial_conv`).
        self.temporal_conv = nn.Conv(
            features = self.features,
            kernel_size = (3,),
            padding = 'SAME',
            dtype = self.dtype,
            kernel_init = _dirac_kernel_init,
            bias_init = nn.initializers.zeros_init()
        )

    def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
        """Apply the spatial and, for video input, the temporal convolution.

        Args:
            x: Input of shape ``(b, f, h, w, c)`` (video) or any shape accepted
                by the spatial convolution (e.g. ``(b, h, w, c)``).
            convolve_across_time: If ``False`` the temporal convolution is
                skipped. It is always skipped for non-5D input.

        Returns:
            The convolved array, with the same rank as ``x``.
        """
        if x.ndim != 5:
            # Not a video: only the spatial convolution applies.
            return self.spatial_conv(x)

        batch = x.shape[0]
        # Fold frames into the batch axis so each frame is convolved independently.
        x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        x = self.spatial_conv(x)
        x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = batch)
        if not convolve_across_time:
            return x

        # The spatial conv may have changed the spatial size (strides/padding),
        # so read it from the current shape.
        _, _, h, w, _ = x.shape
        # Fold spatial positions into the batch axis and convolve along frames.
        x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
        x = self.temporal_conv(x)
        return einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
class UpsamplePseudo3D(nn.Module):
    """Double the spatial resolution, then apply a 3x3 pseudo-3D convolution.

    Height and width are doubled with nearest-neighbour interpolation; for
    5-dimensional input every frame is resized independently. The frame and
    channel counts of the resized array are unchanged; the convolution then
    maps to ``out_channels`` channels.

    Attributes:
        out_channels: Number of output channels of the convolution.
        dtype: Computation dtype passed to the convolution.
    """
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Upsample ``hidden_states`` by a factor of two and convolve.

        Args:
            hidden_states: Array of shape ``(b, f, h, w, c)`` or ``(b, h, w, c)``.

        Returns:
            The output of the convolution applied to the resized array.
        """
        # `None` marks image input; otherwise remember the video batch size so
        # the frame axis can be restored after resizing.
        video_batch = hidden_states.shape[0] if hidden_states.ndim == 5 else None
        if video_batch is not None:
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')

        n_images, height, width, channels = hidden_states.shape
        target_shape = (n_images, height * 2, width * 2, channels)
        hidden_states = jax.image.resize(
            image = hidden_states,
            shape = target_shape,
            method = 'nearest'
        )

        if video_batch is not None:
            hidden_states = einops.rearrange(
                hidden_states,
                '(b f) h w c -> b f h w c',
                b = video_batch
            )
        return self.conv(hidden_states)
class DownsamplePseudo3D(nn.Module):
    """Spatial downsampling via a stride-2, 3x3 pseudo-3D convolution.

    Attributes:
        out_channels: Number of output channels of the convolution.
        dtype: Computation dtype passed to the convolution.
    """
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Explicit padding of one pixel on each side of both spatial axes.
        one_pixel = (1, 1)
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (2, 2),
            padding = (one_pixel, one_pixel),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Return ``hidden_states`` passed through the strided convolution."""
        return self.conv(hidden_states)
class ResnetBlockPseudo3D(nn.Module):
    """Residual block built from pseudo-3D convolutions with a time-embedding bias.

    Computes ``norm -> silu -> conv -> (+ projected time embedding) -> norm ->
    silu -> conv`` and adds the (optionally 1x1-convolved) input as a residual.

    Attributes:
        in_channels: Number of channels of the input.
        out_channels: Number of output channels; defaults to ``in_channels``
            when ``None``.
        use_nin_shortcut: Whether to pass the residual through a 1x1
            convolution. When ``None`` the shortcut convolution is used exactly
            when the input and output channel counts differ.
        dtype: Computation dtype passed to the convolutions and the dense layer.
    """
    in_channels: int
    out_channels: Optional[int] = None
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        out_channels = self.in_channels if self.out_channels is None else self.out_channels
        self.norm1 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv1 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        self.time_emb_proj = nn.Dense(
            out_channels,
            dtype = self.dtype
        )
        self.norm2 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv2 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        if self.use_nin_shortcut is None:
            use_nin_shortcut = self.in_channels != out_channels
        else:
            use_nin_shortcut = self.use_nin_shortcut
        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = ConvPseudo3D(
                # Use the resolved channel count: `self.out_channels` may be
                # None when the shortcut is explicitly requested.
                features = out_channels,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
            )

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    ) -> jax.Array:
        """Apply the residual block.

        Args:
            hidden_states: Input of shape ``(b, f, h, w, c)`` (video) or
                ``(b, h, w, c)`` (image).
            temb: Time embedding; assumed to be of shape ``(b, d)`` so that its
                projection can be broadcast over all non-batch, non-channel
                axes of ``hidden_states``.

        Returns:
            Array with the same leading axes as ``hidden_states`` and the
            block's output channel count.
        """
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = self.time_emb_proj(nn.silu(temb))
        # Insert singleton axes between batch and channel so the embedding
        # broadcasts over (h, w) for images and (f, h, w) for video. This gives
        # the same result as repeating it once per frame on a flattened
        # (b f) batch, without the reshapes.
        temb = jnp.expand_dims(temb, axis = tuple(range(1, hidden_states.ndim - 1)))
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        return hidden_states + residual