Spaces:
Runtime error
Runtime error
File size: 6,017 Bytes
149cc2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
from typing import Optional, Union, Sequence
import jax
import jax.numpy as jnp
import flax.linen as nn
import einops
def _dirac_kernel_init(key, shape, dtype = jnp.float32) -> jax.Array:
    """Return a kernel that makes a 'SAME'-padded convolution the identity.

    JAX counterpart of ``torch.nn.init.dirac_`` for a channels-last kernel
    laid out as ``(*window, in_features, out_features)``: every entry is
    zero except the centre tap of the window, which maps input channel ``i``
    to output channel ``i``.

    Args:
        key: PRNG key required by the initializer signature; unused because
            the result is deterministic.
        shape: Kernel shape ``(*window, in_features, out_features)``.
        dtype: Dtype of the returned kernel.

    Returns:
        The identity-preserving kernel of the requested shape and dtype.
    """
    del key
    *window, in_features, out_features = shape
    centre = tuple(size // 2 for size in window)
    # Only the first min(in, out) channels can be mapped one-to-one.
    diagonal = jnp.arange(min(in_features, out_features))
    kernel = jnp.zeros(shape, dtype)
    return kernel.at[centre + (diagonal, diagonal)].set(1)


class ConvPseudo3D(nn.Module):
    """Spatial convolution optionally followed by a 1D convolution over frames.

    A 5D input is read as ``(b, f, h, w, c)``: the frame axis is folded into
    the batch axis for the spatial convolution, then every spatial location
    is convolved along the frame axis with a size-3 kernel. Inputs of any
    other rank only go through the spatial convolution.

    The temporal convolution starts as the identity (Dirac kernel, zero
    bias), so a freshly initialised layer returns exactly the output of the
    spatial convolution.
    """
    # Number of output channels of both convolutions.
    features: int
    # Window of the spatial convolution.
    kernel_size: Sequence[int]
    # Strides of the spatial convolution.
    strides: Union[None, int, Sequence[int]] = 1
    # Padding of the spatial convolution.
    padding: nn.linear.PaddingLike = 'SAME'
    # Computation dtype handed to both convolutions.
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the spatial and the temporal convolution."""
        self.spatial_conv = nn.Conv(
            features = self.features,
            kernel_size = self.kernel_size,
            strides = self.strides,
            padding = self.padding,
            dtype = self.dtype
        )
        self.temporal_conv = nn.Conv(
            features = self.features,
            kernel_size = (3,),
            padding = 'SAME',
            dtype = self.dtype,
            bias_init = nn.initializers.zeros_init(),
            # Identity start: the temporal layer must not disturb the
            # spatial features until it has been trained.
            kernel_init = _dirac_kernel_init
        )

    def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
        """Apply the spatial and, for 5D input, the temporal convolution.

        Args:
            x: Input array. A 5D array is treated as ``(b, f, h, w, c)``;
                any other rank is passed to the spatial convolution as is
                (presumably ``(b, h, w, c)`` - confirm with callers).
            convolve_across_time: If False, skip the temporal convolution.
                Ignored for non-5D input, which never reaches it.

        Returns:
            The convolved array, with the same rank as ``x``.
        """
        is_video = x.ndim == 5
        if is_video:
            batch = x.shape[0]
            x = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        x = self.spatial_conv(x)
        if not is_video:
            return x
        x = einops.rearrange(x, '(b f) h w c -> b f h w c', b = batch)
        if not convolve_across_time:
            return x
        # Read h and w after the spatial conv: strides may have changed them.
        _, _, h, w, _ = x.shape
        x = einops.rearrange(x, 'b f h w c -> (b h w) f c')
        x = self.temporal_conv(x)
        return einops.rearrange(x, '(b h w) f c -> b f h w c', h = h, w = w)
class UpsamplePseudo3D(nn.Module):
    """Double height and width by nearest-neighbour resize, then convolve.

    The resize is followed by a 3x3, stride-1 ``ConvPseudo3D`` producing
    ``out_channels`` channels.
    """
    # Number of channels produced by the trailing convolution.
    out_channels: int
    # Computation dtype handed to the convolution.
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the convolution applied after the resize."""
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Upsample ``hidden_states`` by a factor of two and convolve it.

        Args:
            hidden_states: A 5D array read as ``(b, f, h, w, c)``, or a 4D
                array unpacked as ``(batch, h, w, c)``.

        Returns:
            The resized and convolved array, with the same rank as the input.
        """
        is_video = hidden_states.ndim == 5
        if is_video:
            # Fold frames into the batch axis so the resize sees 4D images.
            num_videos = hidden_states.shape[0]
            frames = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        else:
            frames = hidden_states
        count, height, width, channels = frames.shape
        doubled_shape = (count, height * 2, width * 2, channels)
        frames = jax.image.resize(frames, doubled_shape, method = 'nearest')
        if is_video:
            frames = einops.rearrange(frames, '(b f) h w c -> b f h w c', b = num_videos)
        return self.conv(frames)
class DownsamplePseudo3D(nn.Module):
    """Downsample with a single 3x3, stride-2 ``ConvPseudo3D``."""
    # Number of channels produced by the convolution.
    out_channels: int
    # Computation dtype handed to the convolution.
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the strided convolution."""
        one_pixel = (1, 1)
        self.conv = ConvPseudo3D(
            features = self.out_channels,
            kernel_size = (3, 3),
            strides = (2, 2),
            padding = (one_pixel, one_pixel),
            dtype = self.dtype
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Return ``hidden_states`` after the strided convolution.

        Args:
            hidden_states: Input array, forwarded unchanged to the
                convolution.

        Returns:
            The output of the convolution.
        """
        return self.conv(hidden_states)
class ResnetBlockPseudo3D(nn.Module):
    """Residual block of two ``ConvPseudo3D`` layers with a time embedding.

    Data flow: GroupNorm -> SiLU -> conv1 -> add projected time embedding ->
    GroupNorm -> SiLU -> conv2 -> add the (optionally projected) input.

    NOTE(review): both GroupNorm layers use 32 groups, so channel counts are
    presumably expected to be multiples of 32 - confirm with callers.
    """
    # Number of channels of the incoming hidden states.
    in_channels: int
    # Number of channels produced by the block; None means ``in_channels``.
    out_channels: Optional[int] = None
    # Force (True) or forbid (False) the 1x1 shortcut convolution; None
    # enables it exactly when the channel count changes.
    use_nin_shortcut: Optional[bool] = None
    # Computation dtype handed to the convolutions and the dense projection.
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        """Create the norms, convolutions and time-embedding projection."""
        out_channels = self.in_channels if self.out_channels is None else self.out_channels
        self.norm1 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv1 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        self.time_emb_proj = nn.Dense(
            out_channels,
            dtype = self.dtype
        )
        self.norm2 = nn.GroupNorm(
            num_groups = 32,
            epsilon = 1e-5
        )
        self.conv2 = ConvPseudo3D(
            features = out_channels,
            kernel_size = (3, 3),
            strides = (1, 1),
            padding = ((1, 1), (1, 1)),
            dtype = self.dtype
        )
        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
        self.conv_shortcut = None
        if use_nin_shortcut:
            self.conv_shortcut = ConvPseudo3D(
                # Use the resolved channel count: self.out_channels may be
                # None while the shortcut is explicitly requested.
                features = out_channels,
                kernel_size = (1, 1),
                strides = (1, 1),
                padding = 'VALID',
                dtype = self.dtype
            )

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    ) -> jax.Array:
        """Run the residual block.

        Args:
            hidden_states: A 5D array read as ``(b, f, h, w, c)``; any other
                rank takes the non-video path (presumably ``(b, h, w, c)`` -
                confirm with callers).
            temb: Time embedding, presumably ``(b, embed_dim)`` - confirm
                with callers. It is projected to the block's output channel
                count and broadcast over the spatial axes.

        Returns:
            The block output: the processed hidden states plus the residual.
        """
        is_video = hidden_states.ndim == 5
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)
        temb = nn.silu(temb)
        temb = self.time_emb_proj(temb)
        # Insert two singleton axes after the batch axis so the embedding
        # broadcasts over height and width.
        temb = jnp.expand_dims(temb, 1)
        temb = jnp.expand_dims(temb, 1)
        if is_video:
            b, f, *_ = hidden_states.shape
            hidden_states = einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
            # Repeat each sample's embedding f times in a row, matching the
            # '(b f)' ordering of the folded batch axis.
            hidden_states = hidden_states + temb.repeat(f, 0)
            hidden_states = einops.rearrange(hidden_states, '(b f) h w c -> b f h w c', b = b)
        else:
            hidden_states = hidden_states + temb
        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.conv_shortcut is not None:
            # Project the input so its channel count matches the output.
            residual = self.conv_shortcut(residual)
        hidden_states = hidden_states + residual
        return hidden_states
|