File size: 6,017 Bytes
149cc2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176

from typing import Optional, Union, Sequence

import jax
import jax.numpy as jnp
import flax.linen as nn

import einops


class ConvPseudo3D(nn.Module):
    """Factorised ("pseudo-3D") convolution: a 2D conv per frame, then a 1D conv over time.

    Rank-4 inputs (``b h w c``) only go through the spatial convolution.
    Rank-5 inputs (``b f h w c``) are flattened to a batch of frames for the
    spatial convolution and, unless ``convolve_across_time`` is False, then go
    through a kernel-3 convolution along the frame axis at every pixel.

    Attributes:
        features: Output channels of both the spatial and temporal convolutions.
        kernel_size: Spatial kernel size, forwarded to the spatial ``nn.Conv``.
        strides: Spatial strides, forwarded to the spatial ``nn.Conv``.
        padding: Spatial padding, forwarded to the spatial ``nn.Conv``.
        dtype: Computation dtype for both convolutions.
    """
    features: int
    kernel_size: Sequence[int]
    strides: Union[None, int, Sequence[int]] = 1
    padding: nn.linear.PaddingLike = 'SAME'
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Per-frame 2D convolution carrying every user-supplied hyper-parameter.
        self.spatial_conv = nn.Conv(
            features=self.features,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            dtype=self.dtype,
        )
        # 1D convolution along the frame axis; 'SAME' padding keeps the frame
        # count unchanged and the bias starts at zero.
        # TODO: Dirac delta (identity) kernel initialisation, the JAX/lax
        # counterpart of torch.nn.init.dirac_.
        self.temporal_conv = nn.Conv(
            features=self.features,
            kernel_size=(3,),
            padding='SAME',
            dtype=self.dtype,
            bias_init=nn.initializers.zeros_init(),
        )

    def __call__(self, x: jax.Array, convolve_across_time: bool = True) -> jax.Array:
        """Apply the spatial conv and, for videos, optionally the temporal conv.

        Args:
            x: Channels-last images ``(b, h, w, c)`` or videos ``(b, f, h, w, c)``;
                the rank alone decides which path is taken.
            convolve_across_time: When False, skip the temporal convolution.
                Ignored for rank-4 inputs, which never use it.

        Returns:
            Array of the same rank as ``x`` with ``features`` channels.
        """
        if x.ndim != 5:
            # Image batch: there is no time axis to convolve over.
            return self.spatial_conv(x)

        batch = x.shape[0]
        frames = einops.rearrange(x, 'b f h w c -> (b f) h w c')
        video = einops.rearrange(
            self.spatial_conv(frames), '(b f) h w c -> b f h w c', b = batch
        )
        if not convolve_across_time:
            return video

        # Spatial dims are read after the spatial conv because strides may shrink them.
        _, _, height, width, _ = video.shape
        per_pixel = einops.rearrange(video, 'b f h w c -> (b h w) f c')
        per_pixel = self.temporal_conv(per_pixel)
        return einops.rearrange(
            per_pixel, '(b h w) f c -> b f h w c', h = height, w = width
        )


class UpsamplePseudo3D(nn.Module):
    """2x nearest-neighbour spatial upsampling followed by a 3x3 pseudo-3D conv.

    Attributes:
        out_channels: Channels produced by the trailing convolution.
        dtype: Computation dtype for the convolution.
    """
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Stride-1, pad-1 3x3 conv: keeps the (already doubled) spatial size.
        self.conv = ConvPseudo3D(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Double height and width, then convolve.

        Args:
            hidden_states: Channels-last ``(b, h, w, c)`` images or
                ``(b, f, h, w, c)`` videos.

        Returns:
            Array of the same rank with spatial dims doubled and
            ``out_channels`` channels.
        """
        video_batch = hidden_states.shape[0] if hidden_states.ndim == 5 else None

        # Resize works on a flat batch of frames, so fold the frame axis away first.
        frames = (
            hidden_states
            if video_batch is None
            else einops.rearrange(hidden_states, 'b f h w c -> (b f) h w c')
        )
        n, height, width, channels = frames.shape
        enlarged = jax.image.resize(
            image=frames,
            shape=(n, 2 * height, 2 * width, channels),
            method='nearest',
        )

        if video_batch is not None:
            enlarged = einops.rearrange(
                enlarged, '(b f) h w c -> b f h w c', b = video_batch
            )
        return self.conv(enlarged)


class DownsamplePseudo3D(nn.Module):
    """Spatial downsampling by a stride-2, 3x3 pseudo-3D convolution.

    Attributes:
        out_channels: Channels produced by the convolution.
        dtype: Computation dtype for the convolution.
    """
    out_channels: int
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # The stride of 2 with symmetric padding of 1 does the downsampling.
        self.conv = ConvPseudo3D(
            features=self.out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=((1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states: jax.Array) -> jax.Array:
        """Convolve ``hidden_states`` (rank 4 images or rank 5 videos) with stride 2."""
        return self.conv(hidden_states)


class ResnetBlockPseudo3D(nn.Module):
    """Pre-activation residual block built from pseudo-3D convolutions.

    Computes
    ``conv2(silu(norm2(conv1(silu(norm1(x))) + proj(silu(temb))))) + shortcut(x)``,
    where ``shortcut`` is the identity or a 1x1 :class:`ConvPseudo3D`.

    Attributes:
        in_channels: Channel count of the input.
        out_channels: Channel count of the output; ``None`` means ``in_channels``.
        use_nin_shortcut: Whether to project the residual with a 1x1 conv;
            ``None`` means "only when the input and output channel counts differ".
        dtype: Computation dtype for the convolutions and the time projection.
    """
    in_channels: int
    out_channels: Optional[int] = None
    use_nin_shortcut: Optional[bool] = None
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        # Resolve the effective output width once; every sub-layer must use this,
        # never the raw (possibly None) attribute.
        out_channels = self.in_channels if self.out_channels is None else self.out_channels
        self.norm1 = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.conv1 = ConvPseudo3D(
                features = out_channels,
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )
        self.time_emb_proj = nn.Dense(
                out_channels,
                dtype = self.dtype
        )
        self.norm2 = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.conv2 = ConvPseudo3D(
                features = out_channels,
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )
        use_nin_shortcut = self.in_channels != out_channels if self.use_nin_shortcut is None else self.use_nin_shortcut
        self.conv_shortcut = None
        if use_nin_shortcut:
            # Previously used self.out_channels here, which is None when the caller
            # forces use_nin_shortcut=True without giving out_channels.
            self.conv_shortcut = ConvPseudo3D(
                    features = out_channels,
                    kernel_size = (1, 1),
                    strides = (1, 1),
                    padding = 'VALID',
                    dtype = self.dtype
            )

    def __call__(self,
            hidden_states: jax.Array,
            temb: jax.Array
    ) -> jax.Array:
        """Apply the residual block.

        Args:
            hidden_states: Channels-last ``(b, h, w, c)`` images or
                ``(b, f, h, w, c)`` videos; the rank selects the path.
            temb: Timestep embedding. Assumed to be ``(b, d)`` (TODO confirm
                against callers); its batch must equal or broadcast to ``b``.

        Returns:
            Array of the same rank as ``hidden_states`` with the block's
            output channel count.
        """
        is_video = hidden_states.ndim == 5
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        temb = nn.silu(temb)
        temb = self.time_emb_proj(temb)
        # (b, c) -> (b, 1, 1, c): broadcast over the spatial axes.
        temb = jnp.expand_dims(temb, 1)
        temb = jnp.expand_dims(temb, 1)
        if is_video:
            # (b, 1, 1, c) -> (b, 1, 1, 1, c): add the same embedding to every
            # frame of its sample by broadcasting over the frame axis. This is
            # equivalent to repeating it f times on a flattened (b f) axis,
            # without the copy and the two rearranges.
            temb = jnp.expand_dims(temb, 1)
        hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)
        hidden_states = nn.silu(hidden_states)
        hidden_states = self.conv2(hidden_states)
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)
        hidden_states = hidden_states + residual
        return hidden_states