from typing import Tuple, Union

import flax
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.models.modeling_flax_utils import FlaxModelMixin
from diffusers.utils import BaseOutput

from .flax_unet_pseudo3d_blocks import (
        CrossAttnDownBlockPseudo3D,
        CrossAttnUpBlockPseudo3D,
        DownBlockPseudo3D,
        UpBlockPseudo3D,
        UNetMidBlockPseudo3DCrossAttn
)
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .flax_resnet_pseudo3d import ConvPseudo3D


@flax.struct.dataclass
class UNetPseudo3DConditionOutput(BaseOutput):
    # Registered as a flax struct dataclass so that BaseOutput's dataclass
    # machinery works and instances are valid JAX pytrees.
    sample: jax.Array


@flax_register_to_config
class UNetPseudo3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
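    """Conditional UNet assembled from pseudo-3D blocks.

    Consumes video latents laid out as (batch, channels, frames, height,
    width) and is conditioned on encoder_hidden_states (e.g. text encoder
    states of width cross_attention_dim) through cross-attention.
    """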
    sample_size: Union[int, Tuple[int, int]] = (64, 64)
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlockPseudo3D",
            "CrossAttnDownBlockPseudo3D",
            "CrossAttnDownBlockPseudo3D",
            "DownBlockPseudo3D"
    )
    up_block_types: Tuple[str, ...] = (
            "UpBlockPseudo3D",
            "CrossAttnUpBlockPseudo3D",
            "CrossAttnUpBlockPseudo3D",
            "CrossAttnUpBlockPseudo3D"
    )
    block_out_channels: Tuple[int, ...] = (
            320,
            640,
            1280,
            1280
    )
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    cross_attention_dim: int = 768
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    param_dtype: str = 'float32'

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        if self.param_dtype == 'bfloat16':
            param_dtype = jnp.bfloat16
        elif self.param_dtype == 'float16':
            param_dtype = jnp.float16
        elif self.param_dtype == 'float32':
            param_dtype = jnp.float32
        else:
            raise ValueError(f'unknown parameter type: {self.param_dtype}')
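        # Dummy inputs used only to trace shapes: a single one-frame latent,
        # one timestep and a single conditioning token.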
        sample_size = self.sample_size
        if isinstance(sample_size, int):
            sample_size = (sample_size, sample_size)
        sample_shape = (1, self.in_channels, 1, *sample_size)
        sample = jnp.zeros(sample_shape, dtype = param_dtype)
        timesteps = jnp.ones((1, ), dtype = jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype = param_dtype)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = { "params": params_rng, "dropout": dropout_rng }
        return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]

    def setup(self) -> None:
        if isinstance(self.attention_head_dim, int):
            attention_head_dim = (self.attention_head_dim, ) * len(self.down_block_types)
        else:
            attention_head_dim = self.attention_head_dim
        time_embed_dim = self.block_out_channels[0] * 4
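        # Input stem: pseudo-3D 3x3 convolution into the first block width.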
        self.conv_in = ConvPseudo3D(
                features = self.block_out_channels[0],
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )
        self.time_proj = FlaxTimesteps(
                dim = self.block_out_channels[0],
                flip_sin_to_cos = self.flip_sin_to_cos,
                freq_shift = self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(
                time_embed_dim = time_embed_dim,
                dtype = self.dtype
        )
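        # Down path: each block widens channels to its block_out_channels
        # entry and downsamples, except the final block which keeps its
        # resolution.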
        down_blocks = []
        output_channels = self.block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channels = output_channels
            output_channels = self.block_out_channels[i]
            is_final_block = i == len(self.block_out_channels) - 1
            # Accept the old 2D block type names as aliases so that configs
            # written for 2D UNets (eg. lxj's timelapse model) still load.
            if down_block_type in ['CrossAttnDownBlockPseudo3D', 'CrossAttnDownBlock2D']:
                down_block = CrossAttnDownBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        num_layers = self.layers_per_block,
                        attn_num_head_channels = attention_head_dim[i],
                        add_downsample = not is_final_block,
                        use_memory_efficient_attention = self.use_memory_efficient_attention,
                        dtype = self.dtype
                )
            elif down_block_type in ['DownBlockPseudo3D', 'DownBlock2D']:
                down_block = DownBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        num_layers = self.layers_per_block,
                        add_downsample = not is_final_block,
                        dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block_type}')
            down_blocks.append(down_block)
        self.down_blocks = down_blocks
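        # Bottleneck: cross-attention mid block at the lowest resolution.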
        self.mid_block = UNetMidBlockPseudo3DCrossAttn(
                in_channels = self.block_out_channels[-1],
                attn_num_head_channels = attention_head_dim[-1],
                use_memory_efficient_attention = self.use_memory_efficient_attention,
                dtype = self.dtype
        )
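        # Up path: mirror of the down path. prev_output_channels tracks the
        # previous up block's width; input_channels tracks the width of the
        # skip connections arriving from the down path.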
        up_blocks = []
        reversed_block_out_channels = list(reversed(self.block_out_channels))
        reversed_attention_head_dim = list(reversed(attention_head_dim))
        output_channels = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channels = output_channels
            output_channels = reversed_block_out_channels[i]
            input_channels = reversed_block_out_channels[min(i + 1, len(self.block_out_channels) - 1)]
            is_final_block = i == len(self.block_out_channels) - 1
            if up_block_type in ['CrossAttnUpBlockPseudo3D', 'CrossAttnUpBlock2D']:
                up_block = CrossAttnUpBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        prev_output_channels = prev_output_channels,
                        num_layers = self.layers_per_block + 1,
                        attn_num_head_channels = reversed_attention_head_dim[i],
                        add_upsample = not is_final_block,
                        use_memory_efficient_attention = self.use_memory_efficient_attention,
                        dtype = self.dtype
                )
            elif up_block_type in ['UpBlockPseudo3D', 'UpBlock2D']:
                up_block = UpBlockPseudo3D(
                        in_channels = input_channels,
                        out_channels = output_channels,
                        prev_output_channels = prev_output_channels,
                        num_layers = self.layers_per_block + 1,
                        add_upsample = not is_final_block,
                        dtype = self.dtype
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block_type}')
            up_blocks.append(up_block)
        self.up_blocks = up_blocks
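        # Output head: group norm -> SiLU -> pseudo-3D conv back to
        # out_channels (applied at the end of __call__).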
        self.conv_norm_out = nn.GroupNorm(
                num_groups = 32,
                epsilon = 1e-5
        )
        self.conv_out = ConvPseudo3D(
                features = self.out_channels,
                kernel_size = (3, 3),
                strides = (1, 1),
                padding = ((1, 1), (1, 1)),
                dtype = self.dtype
        )

    def __call__(self,
            sample: jax.Array,
            timesteps: jax.Array,
            encoder_hidden_states: jax.Array,
            return_dict: bool = True
    ) -> Union[UNetPseudo3DConditionOutput, Tuple[jax.Array]]:
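        # Accept 0-d and integer timesteps: the sinusoidal projection below
        # expects a float32 1-D array.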
        if timesteps.dtype != jnp.float32:
            timesteps = timesteps.astype(dtype = jnp.float32)
        if len(timesteps.shape) == 0:
            timesteps = jnp.expand_dims(timesteps, 0)
        # b,c,f,h,w -> b,f,h,w,c
        sample = sample.transpose((0, 2, 3, 4, 1))

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)
        sample = self.conv_in(sample)
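        # Stash every intermediate activation; the up path consumes them in
        # reverse order as skip connections.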
        down_block_res_samples = (sample, )
        for down_block in self.down_blocks:
            if isinstance(down_block, CrossAttnDownBlockPseudo3D):
                sample, res_samples = down_block(
                        hidden_states = sample,
                        temb = t_emb,
                        encoder_hidden_states = encoder_hidden_states
                )
            elif isinstance(down_block, DownBlockPseudo3D):
                sample, res_samples = down_block(
                        hidden_states = sample,
                        temb = t_emb
                )
            else:
                raise NotImplementedError(f'Unimplemented down block type: {down_block.__class__.__name__}')
            down_block_res_samples += res_samples
        sample = self.mid_block(
                hidden_states = sample,
                temb = t_emb,
                encoder_hidden_states = encoder_hidden_states
        )
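        # Each up block consumes (layers_per_block + 1) residuals, popped off
        # the end of the stack so that resolutions line up.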
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(self.layers_per_block + 1):]
            down_block_res_samples = down_block_res_samples[:-(self.layers_per_block + 1)]
            if isinstance(up_block, CrossAttnUpBlockPseudo3D):
                sample = up_block(
                        hidden_states = sample,
                        temb = t_emb,
                        encoder_hidden_states = encoder_hidden_states,
                        res_hidden_states_tuple = res_samples
                )
            elif isinstance(up_block, UpBlockPseudo3D):
                sample = up_block(
                        hidden_states = sample,
                        temb = t_emb,
                        res_hidden_states_tuple = res_samples
                )
            else:
                raise NotImplementedError(f'Unimplemented up block type: {up_block.__class__.__name__}')
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)

        # b,f,h,w,c -> b,c,f,h,w
        sample = sample.transpose((0, 4, 1, 2, 3))
        if not return_dict:
            return (sample, )
        return UNetPseudo3DConditionOutput(sample = sample)
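

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; nothing in this module runs it). Assumes
# the sibling block modules are importable and that the shapes below fit in
# memory; depending on the block implementations, `apply` may also need a
# "dropout" rng.
#
#   model = UNetPseudo3DConditionModel(sample_size = (64, 64))
#   params = model.init_weights(jax.random.PRNGKey(0))
#   sample = jnp.zeros((1, 4, 8, 64, 64))     # b, c, f, h, w
#   timesteps = jnp.array([10], dtype = jnp.int32)
#   text_emb = jnp.zeros((1, 77, 768))        # e.g. CLIP text states
#   out = model.apply({ "params": params }, sample, timesteps, text_emb)
#   assert out.sample.shape == sample.shape
# ---------------------------------------------------------------------------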