File size: 15,610 Bytes
2a00960
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential

from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

from .layers import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    DiTACEBlock,
    T2IFinalLayer
)
from .pos_embed import rope_params


@BACKBONES.register_class()
class DiTACE(BaseModel):
    """DiT-style backbone conditioned on text and optional edit frames.

    For every sample, the target latent (from ``x``) and any
    ``cond['edit']`` frames are each concatenated with a mask channel,
    patch-embedded, and all resulting tokens of the whole batch are packed
    into one flat sequence that is run through a stack of ``DiTACEBlock``s.
    Only the first (target) frame of each sample is decoded in the output.
    """

    para_dict = {
        'PATCH_SIZE': {
            'value': 2,
            'description': ''
        },
        'IN_CHANNELS': {
            'value': 4,
            'description': ''
        },
        'HIDDEN_SIZE': {
            'value': 1152,
            'description': ''
        },
        'DEPTH': {
            'value': 28,
            'description': ''
        },
        'NUM_HEADS': {
            'value': 16,
            'description': ''
        },
        'MLP_RATIO': {
            'value': 4.0,
            'description': ''
        },
        'PRED_SIGMA': {
            'value': True,
            'description': ''
        },
        'DROP_PATH': {
            'value': 0.,
            'description': ''
        },
        'WINDOW_SIZE': {
            'value': 0,
            'description': ''
        },
        'WINDOW_BLOCK_INDEXES': {
            'value': None,
            'description': ''
        },
        'Y_CHANNELS': {
            'value': 4096,
            'description': ''
        },
        'ATTENTION_BACKEND': {
            'value': None,
            'description': ''
        },
        'QK_NORM': {
            'value': True,
            'description': 'Whether to use RMSNorm for query and key.',
        },
    }
    para_dict.update(BaseModel.para_dict)

    def __init__(self, cfg, logger):
        """Build embedders, transformer blocks and the output layer from ``cfg``."""
        super().__init__(cfg, logger=logger)
        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
        if self.window_block_indexes is None:
            self.window_block_indexes = []
        # Window size applied to the blocks listed in WINDOW_BLOCK_INDEXES;
        # it must be set before the blocks are constructed below.
        self.window_size = cfg.get('WINDOW_SIZE', 0)
        self.pred_sigma = cfg.get('PRED_SIGMA', True)
        self.in_channels = cfg.get('IN_CHANNELS', 4)
        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
        self.patch_size = cfg.get('PATCH_SIZE', 2)
        self.num_heads = cfg.get('NUM_HEADS', 16)
        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
        self.y_channels = cfg.get('Y_CHANNELS', 4096)
        self.drop_path = cfg.get('DROP_PATH', 0.)
        self.depth = cfg.get('DEPTH', 28)
        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        # NOTE(review): this fallback (False) differs from para_dict's
        # documented default (True). Kept as-is so configs that omit QK_NORM
        # still build the same architecture; confirm which default is intended.
        self.qk_norm = cfg.get('QK_NORM', False)
        self.ignore_keys = cfg.get('IGNORE_KEYS', [])
        assert (self.hidden_size % self.num_heads
                ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
        d = self.hidden_size // self.num_heads
        # Rotary tables for three positional axes (T/H/W), concatenated along
        # dim 1; the per-head dim d is split as d - 4*(d//6), 2*(d//6), 2*(d//6).
        # This is a plain attribute (not a registered buffer), so forward()
        # moves it to the input's device on demand.
        self.freqs = torch.cat(
            [
                rope_params(self.max_seq_len, d - 4 * (d // 6)),  # T (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6)),  # H (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6))  # W (~1/3)
            ],
            dim=1)

        # init embedder; +1 input channel for the mask concatenated to each frame
        self.x_embedder = PatchEmbed(self.patch_size,
                                     self.in_channels + 1,
                                     self.hidden_size,
                                     bias=True,
                                     flatten=False)
        self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = Mlp(in_features=self.y_channels,
                              hidden_features=self.hidden_size,
                              out_features=self.hidden_size,
                              act_layer=lambda: nn.GELU(approximate='tanh'),
                              drop=0)
        # Projects the timestep embedding to 6 * hidden_size values that are
        # passed to every block as ``t`` (presumably modulation parameters).
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
        # init blocks with a linearly increasing drop-path rate
        drop_path = [
            x.item() for x in torch.linspace(0, self.drop_path, self.depth)
        ]
        self.blocks = nn.ModuleList([
            DiTACEBlock(self.hidden_size,
                        self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        drop_path=drop_path[i],
                        window_size=self.window_size
                        if i in self.window_block_indexes else 0,
                        backend=self.attention_backend,
                        use_condition=True,
                        qk_norm=self.qk_norm) for i in range(self.depth)
        ])
        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
                                         self.out_channels)
        self.initialize_weights()

    def _is_ignored_key(self, key):
        """Return True if checkpoint ``key`` matches ``self.ignore_keys``.

        A string ``ignore_keys`` is treated as a regex matched at the start of
        the key; a list is an exact-name blacklist; anything else ignores nothing.
        """
        ignore = self.ignore_keys
        if isinstance(ignore, str):
            return re.match(ignore, key) is not None
        if isinstance(ignore, list):
            return key in ignore
        return False

    def _expand_patch_embed_weight(self, weight):
        """Fit a checkpoint ``x_embedder.proj.weight`` to this model's shape.

        If the shapes already match, ``weight`` is returned unchanged.
        Otherwise a zero tensor shaped like the model's weight is returned with
        ``weight`` copied into its leading input channels (dim 1). This covers
        checkpoints trained without the extra mask channel. It assumes every
        other dim matches and the checkpoint has no more input channels than
        the model.
        """
        target = self.x_embedder.proj.weight
        if weight.shape == target.shape:
            return weight
        # Fresh tensor: do not zero the live parameter in place.
        expanded = torch.zeros_like(target.detach())
        expanded[:, :weight.shape[1]].copy_(weight)
        return expanded

    def load_pretrained_model(self, pretrained_model):
        """Load weights from ``pretrained_model``, remapping legacy key names.

        Fused ``.attn.qkv.`` / ``.cross_attn.kv_linear.`` tensors are split
        along dim 0 into separate q/k/v tensors. ``.proj.`` attention outputs
        become ``.o.``, ``.cross_attn.q_linear.`` becomes ``.cross_attn.q.``
        and ``y_embedder.y_proj.`` becomes ``y_embedder.``. The patch-embed
        weight is zero-padded if it has fewer input channels. Loading is
        non-strict; counts and names of missing/unexpected keys are printed.

        Args:
            pretrained_model: location passed to ``FS.get_from``; a falsy value
                makes this a no-op.
        """
        if not pretrained_model:
            return
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        with FS.get_from(pretrained_model, wait_finish=True) as local_path:
            model = torch.load(local_path, map_location='cpu')
        if 'state_dict' in model:
            model = model['state_dict']

        new_ckpt = OrderedDict()
        for k, v in model.items():
            if self._is_ignored_key(k):
                continue
            k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
            k = k.replace('.cross_attn.proj.',
                          '.cross_attn.o.').replace('.attn.proj.', '.attn.o.')
            if '.cross_attn.kv_linear.' in k:
                k_p, v_p = torch.split(v, v.shape[0] // 2)
                new_ckpt[k.replace('.cross_attn.kv_linear.',
                                   '.cross_attn.k.')] = k_p
                new_ckpt[k.replace('.cross_attn.kv_linear.',
                                   '.cross_attn.v.')] = v_p
            elif '.attn.qkv.' in k:
                q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
                new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
                new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
                new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
            elif 'y_embedder.y_proj.' in k:
                new_ckpt[k.replace('y_embedder.y_proj.', 'y_embedder.')] = v
            elif k == 'x_embedder.proj.weight':
                # Exact match: the old `k in ('...')` was a substring test.
                new_ckpt[k] = self._expand_patch_embed_weight(v)
            else:
                new_ckpt[k] = v

        missing, unexpected = self.load_state_dict(new_ckpt, strict=False)
        print(
            f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
        )
        if len(missing) > 0:
            print(f'Missing Keys:\n {missing}')
        if len(unexpected) > 0:
            print(f'\nUnexpected Keys:\n {unexpected}')

    @staticmethod
    def _add_frame_position(u, position_embeddings, frame_id, num_frames):
        """Add the per-frame position embedding to ``u`` when one is provided.

        Frame 0 uses row ``num_frames - 1``; frame ``i > 0`` uses row ``i - 1``.
        Returns ``u`` unchanged when ``position_embeddings`` is None.
        """
        if position_embeddings is None:
            return u
        index = num_frames - 1 if frame_id == 0 else frame_id - 1
        return u + position_embeddings[index]

    @staticmethod
    def _with_mask_channel(u, m):
        """Concatenate mask ``m`` (or an all-ones channel) to ``u`` on dim 0.

        The result gets a new leading batch dim of size 1.
        """
        m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
        return torch.cat([u, m], dim=0).unsqueeze(0)

    def forward(self,
                x,
                t=None,
                cond=None,
                mask=None,
                text_position_embeddings=None,
                gc_seg=-1,
                **kwargs):
        """Predict the output for a batch of variable-size latents.

        Args:
            x: per-sample latents. Each item ``u`` is sliced as
                ``u[:, :H*W]`` and viewed to ``(-1, H, W)`` with
                ``(H, W) = cond['x_shapes'][i]``. This presumably means each
                item is ``(C, L)`` with ``L >= H*W``; confirm with callers.
            t: per-sample timesteps, indexed as ``t[i]``.
            cond: dict with ``'crossattn'`` (per-sample lists of per-frame text
                features), ``'x_shapes'``, ``'x_mask'`` and optionally
                ``'edit'`` / ``'edit_mask'``. A non-dict value is used directly
                as the text context.
            mask: list of per-sample lists of text masks; each mask's sum gives
                the number of text tokens kept (assumes a 0/1 mask).
            text_position_embeddings: optional table projected by
                ``y_embedder`` and added per frame to text and image tokens.
            gc_seg: with ``USE_GRAD_CHECKPOINT``, values >= 0 enable
                ``checkpoint_sequential`` (0 means one segment per block).
            **kwargs: accepted and ignored.

        Returns:
            Tensor of shape ``(B, C, L_max)``, padded over samples, holding
            the target frame of each sample. When ``pred_sigma`` is set, only
            the first half of the channels is returned.

        Raises:
            TypeError: if ``mask`` is not a list.
        """
        if cond is None:
            cond = {}
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if isinstance(cond, dict):
            context = cond.get('crossattn', None)
        else:
            context = cond
        if text_position_embeddings is not None:
            # default use the text_position_embeddings in state_dict
            # if state_dict doesn't including this key, use the arg: text_position_embeddings
            proj_position_embeddings = self.y_embedder(
                text_position_embeddings)
        else:
            proj_position_embeddings = None

        # --- text tokens: trim each frame's context to its mask length ---
        if not isinstance(mask, list):
            raise TypeError(
                'DiTACE.forward expects `mask` to be a list of per-sample text masks'
            )
        ctx_batch, txt_lens = [], []
        for ctx, ctx_mask in zip(context, mask):
            for frame_id, (u, m) in enumerate(zip(ctx, ctx_mask)):
                t_len = m.flatten().sum()
                u = self.y_embedder(u[:t_len])
                u = self._add_frame_position(u, proj_position_embeddings,
                                             frame_id, len(ctx))
                ctx_batch.append(u)
                txt_lens.append(t_len)
        y = torch.cat(ctx_batch, dim=0)
        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)

        # --- image frames: target frame first, then optional edit frames ---
        batch_frames = []
        for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            batch_frames.append([self._with_mask_channel(u, m)])
        if 'edit' in cond:
            for i, (edit, edit_mask) in enumerate(
                    zip(cond['edit'], cond['edit_mask'])):
                if edit is None:
                    continue
                for u, m in zip(edit, edit_mask):
                    batch_frames[i].append(
                        self._with_mask_channel(u.squeeze(0), m))

        # --- patchify and pack every frame of every sample into one sequence ---
        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
        for frames in batch_frames:
            patch_shapes = []
            self_x_len.append(0)
            for frame_id, u in enumerate(frames):
                u = self.x_embedder(u)
                h, w = u.size(2), u.size(3)
                u = rearrange(u, '1 c h w -> (h w) c')
                u = self._add_frame_position(u, proj_position_embeddings,
                                             frame_id, len(frames))
                patch_batch.append(u)
                patch_shapes.append([h, w])
                cross_x_len.append(h * w)  # tokens per frame
                self_x_len[-1] += h * w  # tokens per sample
            shape_batch.append(
                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
        # repeat t so there is one timestep per token of its sample
        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
        self_x_len = torch.LongTensor(self_x_len).to(x.device,
                                                     non_blocking=True)
        cross_x_len = torch.LongTensor(cross_x_len).to(x.device,
                                                       non_blocking=True)
        x = torch.cat(patch_batch, dim=0)
        x_shapes = pad_sequence(tuple(shape_batch),
                                batch_first=True)  # b, max(len(frames)), 2
        t = self.t_embedder(t)  # (N, D)
        t0 = self.t_block(t)

        block_kwargs = dict(y=y,
                            t=t0,
                            x_shapes=x_shapes,
                            self_x_len=self_x_len,
                            cross_x_len=cross_x_len,
                            freqs=self.freqs,
                            txt_lens=txt_lens)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[
                    partial(block, **block_kwargs) for block in self.blocks
                ],
                segments=gc_seg if gc_seg > 0 else len(self.blocks),
                input=x,
                use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, **block_kwargs)
        x = self.final_layer(x, t)

        # --- unpatchify: keep only each sample's first (target) frame ---
        outs, cur_length = [], 0
        p = self.patch_size
        for seq_length, shape in zip(self_x_len, shape_batch):
            x_i = x[cur_length:cur_length + seq_length]
            h, w = shape[0].tolist()
            u = x_i[:h * w].view(h, w, p, p, -1)
            # dump into sequence for following tensor ops
            u = rearrange(u, 'h w p q c -> (h p w q) c')
            cur_length = cur_length + seq_length
            outs.append(u)
        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
        if self.pred_sigma:
            return x.chunk(2, dim=1)[0]
        return x

    def initialize_weights(self):
        """Initialize parameters; called once at the end of ``__init__``.

        Linear layers get Xavier-uniform weights and zero biases. The patch
        embedding is Xavier-initialized as if it were a Linear layer. The
        timestep, ``t_block`` and caption MLP weights use N(0, 0.02). Each
        block's ``cross_attn.o`` projection and the final linear layer are
        zeroed.
        """
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        # Initialize caption embedding MLP:
        if hasattr(self, 'y_embedder'):
            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
        # Zero-out each block's cross-attention output projection:
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.o.weight, 0)
            nn.init.constant_(block.cross_attn.o.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        """dtype of the first parameter (taken as the model's dtype)."""
        return next(self.parameters()).dtype

    @staticmethod
    def get_config_template():
        """Return the YAML config template built from ``para_dict``."""
        return dict_to_yaml('BACKBONE',
                            __class__.__name__,
                            DiTACE.para_dict,
                            set_name=True)