pan-yl commited on
Commit
9f254e0
1 Parent(s): c7b412c

modify somefiles

Browse files
infer.py DELETED
@@ -1,364 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import copy
4
- import math
5
- import random
6
- import numpy as np
7
- from PIL import Image
8
-
9
- import torch
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- import torchvision.transforms.functional as TF
13
-
14
- from scepter.modules.model.registry import DIFFUSIONS
15
- from scepter.modules.utils.distribute import we
16
- from scepter.modules.utils.logger import get_logger
17
- from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
18
-
19
- from modules.model.utils.basic_utils import (
20
- check_list_of_list,
21
- pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
22
- to_device,
23
- unpack_tensor_into_imagelist
24
- )
25
-
26
-
27
- def process_edit_image(images,
28
- masks,
29
- tasks,
30
- max_seq_len=1024,
31
- max_aspect_ratio=4,
32
- d=16,
33
- **kwargs):
34
-
35
- if not isinstance(images, list):
36
- images = [images]
37
- if not isinstance(masks, list):
38
- masks = [masks]
39
- if not isinstance(tasks, list):
40
- tasks = [tasks]
41
-
42
- img_tensors = []
43
- mask_tensors = []
44
- for img, mask, task in zip(images, masks, tasks):
45
- if mask is None or mask == '':
46
- mask = Image.new('L', img.size, 0)
47
- W, H = img.size
48
- if H / W > max_aspect_ratio:
49
- img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
50
- mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
51
- elif W / H > max_aspect_ratio:
52
- img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
53
- mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
54
-
55
- H, W = img.height, img.width
56
- scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
57
- rH = int(H * scale) // d * d # ensure divisible by self.d
58
- rW = int(W * scale) // d * d
59
-
60
- img = TF.resize(img, (rH, rW),
61
- interpolation=TF.InterpolationMode.BICUBIC)
62
- mask = TF.resize(mask, (rH, rW),
63
- interpolation=TF.InterpolationMode.NEAREST_EXACT)
64
-
65
- mask = np.asarray(mask)
66
- mask = np.where(mask > 128, 1, 0)
67
- mask = mask.astype(
68
- np.float32) if np.any(mask) else np.ones_like(mask).astype(
69
- np.float32)
70
-
71
- img_tensor = TF.to_tensor(img).to(we.device_id)
72
- img_tensor = TF.normalize(img_tensor,
73
- mean=[0.5, 0.5, 0.5],
74
- std=[0.5, 0.5, 0.5])
75
- mask_tensor = TF.to_tensor(mask).to(we.device_id)
76
- if task in ['inpainting', 'Try On', 'Inpainting']:
77
- mask_indicator = mask_tensor.repeat(3, 1, 1)
78
- img_tensor[mask_indicator == 1] = -1.0
79
- img_tensors.append(img_tensor)
80
- mask_tensors.append(mask_tensor)
81
- return img_tensors, mask_tensors
82
-
83
-
84
- class TextEmbedding(nn.Module):
85
- def __init__(self, embedding_shape):
86
- super().__init__()
87
- self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
88
-
89
-
90
- class ACEInference(DiffusionInference):
91
- def __init__(self, logger=None):
92
- if logger is None:
93
- logger = get_logger(name='scepter')
94
- self.logger = logger
95
- self.loaded_model = {}
96
- self.loaded_model_name = [
97
- 'diffusion_model', 'first_stage_model', 'cond_stage_model'
98
- ]
99
-
100
- def init_from_cfg(self, cfg):
101
- self.name = cfg.NAME
102
- self.is_default = cfg.get('IS_DEFAULT', False)
103
- module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
104
- assert cfg.have('MODEL')
105
-
106
- self.diffusion_model = self.infer_model(
107
- cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
108
- 'DIFFUSION_MODEL',
109
- None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
110
- self.first_stage_model = self.infer_model(
111
- cfg.MODEL.FIRST_STAGE_MODEL,
112
- module_paras.get(
113
- 'FIRST_STAGE_MODEL',
114
- None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
115
- self.cond_stage_model = self.infer_model(
116
- cfg.MODEL.COND_STAGE_MODEL,
117
- module_paras.get(
118
- 'COND_STAGE_MODEL',
119
- None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
120
- self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
121
- logger=self.logger)
122
-
123
- self.interpolate_func = lambda x: (F.interpolate(
124
- x.unsqueeze(0),
125
- scale_factor=1 / self.size_factor,
126
- mode='nearest-exact') if x is not None else None)
127
- self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
128
- self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
129
- False)
130
- if self.use_text_pos_embeddings:
131
- self.text_position_embeddings = TextEmbedding(
132
- (10, 4096)).eval().requires_grad_(False).to(we.device_id)
133
- else:
134
- self.text_position_embeddings = None
135
-
136
- self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
137
- self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
138
- self.size_factor = cfg.get('SIZE_FACTOR', 8)
139
- self.decoder_bias = cfg.get('DECODER_BIAS', 0)
140
- self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
141
-
142
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
143
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
144
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
145
-
146
- @torch.no_grad()
147
- def encode_first_stage(self, x, **kwargs):
148
- _, dtype = self.get_function_info(self.first_stage_model, 'encode')
149
- with torch.autocast('cuda',
150
- enabled=(dtype != 'float32'),
151
- dtype=getattr(torch, dtype)):
152
- z = [
153
- self.scale_factor * get_model(self.first_stage_model)._encode(
154
- i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
155
- ]
156
- return z
157
-
158
- @torch.no_grad()
159
- def decode_first_stage(self, z):
160
- _, dtype = self.get_function_info(self.first_stage_model, 'decode')
161
- with torch.autocast('cuda',
162
- enabled=(dtype != 'float32'),
163
- dtype=getattr(torch, dtype)):
164
- x = [
165
- get_model(self.first_stage_model)._decode(
166
- 1. / self.scale_factor * i.to(getattr(torch, dtype)))
167
- for i in z
168
- ]
169
- return x
170
-
171
- @torch.no_grad()
172
- def __call__(self,
173
- image=None,
174
- mask=None,
175
- prompt='',
176
- task=None,
177
- negative_prompt='',
178
- output_height=512,
179
- output_width=512,
180
- sampler='ddim',
181
- sample_steps=20,
182
- guide_scale=4.5,
183
- guide_rescale=0.5,
184
- seed=-1,
185
- history_io=None,
186
- tar_index=0,
187
- **kwargs):
188
- input_image, input_mask = image, mask
189
- g = torch.Generator(device=we.device_id)
190
- seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
191
- g.manual_seed(int(seed))
192
-
193
- if input_image is not None:
194
- assert isinstance(input_image, list) and isinstance(
195
- input_mask, list)
196
- if task is None:
197
- task = [''] * len(input_image)
198
- if not isinstance(prompt, list):
199
- prompt = [prompt] * len(input_image)
200
- if history_io is not None and len(history_io) > 0:
201
- his_image, his_maks, his_prompt, his_task = history_io[
202
- 'image'], history_io['mask'], history_io[
203
- 'prompt'], history_io['task']
204
- assert len(his_image) == len(his_maks) == len(
205
- his_prompt) == len(his_task)
206
- input_image = his_image + input_image
207
- input_mask = his_maks + input_mask
208
- task = his_task + task
209
- prompt = his_prompt + [prompt[-1]]
210
- prompt = [
211
- pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
212
- for i, pp in enumerate(prompt)
213
- ]
214
-
215
- edit_image, edit_image_mask = process_edit_image(
216
- input_image, input_mask, task, max_seq_len=self.max_seq_len)
217
-
218
- image, image_mask = edit_image[tar_index], edit_image_mask[
219
- tar_index]
220
- edit_image, edit_image_mask = [edit_image], [edit_image_mask]
221
-
222
- else:
223
- edit_image = edit_image_mask = [[]]
224
- image = torch.zeros(
225
- size=[3, int(output_height),
226
- int(output_width)])
227
- image_mask = torch.ones(
228
- size=[1, int(output_height),
229
- int(output_width)])
230
- if not isinstance(prompt, list):
231
- prompt = [prompt]
232
-
233
- image, image_mask, prompt = [image], [image_mask], [prompt]
234
- assert check_list_of_list(prompt) and check_list_of_list(
235
- edit_image) and check_list_of_list(edit_image_mask)
236
- # Assign Negative Prompt
237
- if isinstance(negative_prompt, list):
238
- negative_prompt = negative_prompt[0]
239
- assert isinstance(negative_prompt, str)
240
-
241
- n_prompt = copy.deepcopy(prompt)
242
- for nn_p_id, nn_p in enumerate(n_prompt):
243
- assert isinstance(nn_p, list)
244
- n_prompt[nn_p_id][-1] = negative_prompt
245
-
246
- ctx, null_ctx = {}, {}
247
-
248
- # Get Noise Shape
249
- image = to_device(image)
250
- x = self.encode_first_stage(image)
251
- noise = [
252
- torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
253
- for i in x
254
- ]
255
- noise, x_shapes = pack_imagelist_into_tensor(noise)
256
- ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
257
-
258
- image_mask = to_device(image_mask, strict=False)
259
- cond_mask = [self.interpolate_func(i) for i in image_mask
260
- ] if image_mask is not None else [None] * len(image)
261
- ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
262
-
263
- # Encode Prompt
264
-
265
- function_name, dtype = self.get_function_info(self.cond_stage_model)
266
- cont, cont_mask = getattr(get_model(self.cond_stage_model),
267
- function_name)(prompt)
268
- cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
269
- cont_mask)
270
- null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
271
- function_name)(n_prompt)
272
- null_cont, null_cont_mask = self.cond_stage_embeddings(
273
- prompt, edit_image, null_cont, null_cont_mask)
274
- ctx['crossattn'] = cont
275
- null_ctx['crossattn'] = null_cont
276
-
277
- # Encode Edit Images
278
- edit_image = [to_device(i, strict=False) for i in edit_image]
279
- edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
280
- e_img, e_mask = [], []
281
- for u, m in zip(edit_image, edit_image_mask):
282
- if u is None:
283
- continue
284
- if m is None:
285
- m = [None] * len(u)
286
- e_img.append(self.encode_first_stage(u, **kwargs))
287
- e_mask.append([self.interpolate_func(i) for i in m])
288
-
289
- null_ctx['edit'] = ctx['edit'] = e_img
290
- null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
291
-
292
- # Diffusion Process
293
- function_name, dtype = self.get_function_info(self.diffusion_model)
294
- with torch.autocast('cuda',
295
- enabled=dtype in ('float16', 'bfloat16'),
296
- dtype=getattr(torch, dtype)):
297
- latent = self.diffusion.sample(
298
- noise=noise,
299
- sampler=sampler,
300
- model=get_model(self.diffusion_model),
301
- model_kwargs=[{
302
- 'cond':
303
- ctx,
304
- 'mask':
305
- cont_mask,
306
- 'text_position_embeddings':
307
- self.text_position_embeddings.pos if hasattr(
308
- self.text_position_embeddings, 'pos') else None
309
- }, {
310
- 'cond':
311
- null_ctx,
312
- 'mask':
313
- null_cont_mask,
314
- 'text_position_embeddings':
315
- self.text_position_embeddings.pos if hasattr(
316
- self.text_position_embeddings, 'pos') else None
317
- }] if guide_scale is not None and guide_scale > 1 else {
318
- 'cond':
319
- null_ctx,
320
- 'mask':
321
- cont_mask,
322
- 'text_position_embeddings':
323
- self.text_position_embeddings.pos if hasattr(
324
- self.text_position_embeddings, 'pos') else None
325
- },
326
- steps=sample_steps,
327
- show_progress=True,
328
- seed=seed,
329
- guide_scale=guide_scale,
330
- guide_rescale=guide_rescale,
331
- return_intermediate=None,
332
- **kwargs)
333
-
334
- # Decode to Pixel Space
335
- samples = unpack_tensor_into_imagelist(latent, x_shapes)
336
- x_samples = self.decode_first_stage(samples)
337
-
338
- imgs = [
339
- torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
340
- min=0.0,
341
- max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
342
- for x_i in x_samples
343
- ]
344
- imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
345
- return imgs
346
-
347
- def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
348
- if self.use_text_pos_embeddings and not torch.sum(
349
- self.text_position_embeddings.pos) > 0:
350
- identifier_cont, _ = getattr(get_model(self.cond_stage_model),
351
- 'encode')(self.text_indentifers,
352
- return_mask=True)
353
- self.text_position_embeddings.load_state_dict(
354
- {'pos': identifier_cont[:, 0, :]})
355
-
356
- cont_, cont_mask_ = [], []
357
- for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
358
- if isinstance(pp, list):
359
- cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
360
- cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
361
- else:
362
- raise NotImplementedError
363
-
364
- return cont_, cont_mask_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import model
 
 
modules/model/__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import backbone, embedder, diffusion, network
 
 
modules/model/backbone/__init__.py DELETED
@@ -1,3 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- from .ace import DiTACE
 
 
 
 
modules/model/backbone/ace.py DELETED
@@ -1,373 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import re
4
- from collections import OrderedDict
5
- from functools import partial
6
-
7
- import torch
8
- import torch.nn as nn
9
- from einops import rearrange
10
- from torch.nn.utils.rnn import pad_sequence
11
- from torch.utils.checkpoint import checkpoint_sequential
12
-
13
- from scepter.modules.model.base_model import BaseModel
14
- from scepter.modules.model.registry import BACKBONES
15
- from scepter.modules.utils.config import dict_to_yaml
16
- from scepter.modules.utils.file_system import FS
17
-
18
- from .layers import (
19
- Mlp,
20
- TimestepEmbedder,
21
- PatchEmbed,
22
- DiTACEBlock,
23
- T2IFinalLayer
24
- )
25
- from .pos_embed import rope_params
26
-
27
-
28
- @BACKBONES.register_class()
29
- class DiTACE(BaseModel):
30
-
31
- para_dict = {
32
- 'PATCH_SIZE': {
33
- 'value': 2,
34
- 'description': ''
35
- },
36
- 'IN_CHANNELS': {
37
- 'value': 4,
38
- 'description': ''
39
- },
40
- 'HIDDEN_SIZE': {
41
- 'value': 1152,
42
- 'description': ''
43
- },
44
- 'DEPTH': {
45
- 'value': 28,
46
- 'description': ''
47
- },
48
- 'NUM_HEADS': {
49
- 'value': 16,
50
- 'description': ''
51
- },
52
- 'MLP_RATIO': {
53
- 'value': 4.0,
54
- 'description': ''
55
- },
56
- 'PRED_SIGMA': {
57
- 'value': True,
58
- 'description': ''
59
- },
60
- 'DROP_PATH': {
61
- 'value': 0.,
62
- 'description': ''
63
- },
64
- 'WINDOW_SIZE': {
65
- 'value': 0,
66
- 'description': ''
67
- },
68
- 'WINDOW_BLOCK_INDEXES': {
69
- 'value': None,
70
- 'description': ''
71
- },
72
- 'Y_CHANNELS': {
73
- 'value': 4096,
74
- 'description': ''
75
- },
76
- 'ATTENTION_BACKEND': {
77
- 'value': None,
78
- 'description': ''
79
- },
80
- 'QK_NORM': {
81
- 'value': True,
82
- 'description': 'Whether to use RMSNorm for query and key.',
83
- },
84
- }
85
- para_dict.update(BaseModel.para_dict)
86
-
87
- def __init__(self, cfg, logger):
88
- super().__init__(cfg, logger=logger)
89
- self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
90
- if self.window_block_indexes is None:
91
- self.window_block_indexes = []
92
- self.pred_sigma = cfg.get('PRED_SIGMA', True)
93
- self.in_channels = cfg.get('IN_CHANNELS', 4)
94
- self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
95
- self.patch_size = cfg.get('PATCH_SIZE', 2)
96
- self.num_heads = cfg.get('NUM_HEADS', 16)
97
- self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
98
- self.y_channels = cfg.get('Y_CHANNELS', 4096)
99
- self.drop_path = cfg.get('DROP_PATH', 0.)
100
- self.depth = cfg.get('DEPTH', 28)
101
- self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
102
- self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
103
- self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
104
- self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
105
- self.qk_norm = cfg.get('QK_NORM', False)
106
- self.ignore_keys = cfg.get('IGNORE_KEYS', [])
107
- assert (self.hidden_size % self.num_heads
108
- ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
109
- d = self.hidden_size // self.num_heads
110
- self.freqs = torch.cat(
111
- [
112
- rope_params(self.max_seq_len, d - 4 * (d // 6)), # T (~1/3)
113
- rope_params(self.max_seq_len, 2 * (d // 6)), # H (~1/3)
114
- rope_params(self.max_seq_len, 2 * (d // 6)) # W (~1/3)
115
- ],
116
- dim=1)
117
-
118
- # init embedder
119
- self.x_embedder = PatchEmbed(self.patch_size,
120
- self.in_channels + 1,
121
- self.hidden_size,
122
- bias=True,
123
- flatten=False)
124
- self.t_embedder = TimestepEmbedder(self.hidden_size)
125
- self.y_embedder = Mlp(in_features=self.y_channels,
126
- hidden_features=self.hidden_size,
127
- out_features=self.hidden_size,
128
- act_layer=lambda: nn.GELU(approximate='tanh'),
129
- drop=0)
130
- self.t_block = nn.Sequential(
131
- nn.SiLU(),
132
- nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
133
- # init blocks
134
- drop_path = [
135
- x.item() for x in torch.linspace(0, self.drop_path, self.depth)
136
- ]
137
- self.blocks = nn.ModuleList([
138
- DiTACEBlock(self.hidden_size,
139
- self.num_heads,
140
- mlp_ratio=self.mlp_ratio,
141
- drop_path=drop_path[i],
142
- window_size=self.window_size
143
- if i in self.window_block_indexes else 0,
144
- backend=self.attention_backend,
145
- use_condition=True,
146
- qk_norm=self.qk_norm) for i in range(self.depth)
147
- ])
148
- self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
149
- self.out_channels)
150
- self.initialize_weights()
151
-
152
- def load_pretrained_model(self, pretrained_model):
153
- if pretrained_model:
154
- with FS.get_from(pretrained_model, wait_finish=True) as local_path:
155
- model = torch.load(local_path, map_location='cpu')
156
- if 'state_dict' in model:
157
- model = model['state_dict']
158
- new_ckpt = OrderedDict()
159
- for k, v in model.items():
160
- if self.ignore_keys is not None:
161
- if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
162
- (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
163
- continue
164
- k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
165
- k = k.replace('.cross_attn.proj.',
166
- '.cross_attn.o.').replace(
167
- '.attn.proj.', '.attn.o.')
168
- if '.cross_attn.kv_linear.' in k:
169
- k_p, v_p = torch.split(v, v.shape[0] // 2)
170
- new_ckpt[k.replace('.cross_attn.kv_linear.',
171
- '.cross_attn.k.')] = k_p
172
- new_ckpt[k.replace('.cross_attn.kv_linear.',
173
- '.cross_attn.v.')] = v_p
174
- elif '.attn.qkv.' in k:
175
- q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
176
- new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
177
- new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
178
- new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
179
- elif 'y_embedder.y_proj.' in k:
180
- new_ckpt[k.replace('y_embedder.y_proj.',
181
- 'y_embedder.')] = v
182
- elif k in ('x_embedder.proj.weight'):
183
- model_p = self.state_dict()[k]
184
- if v.shape != model_p.shape:
185
- model_p.zero_()
186
- model_p[:, :4, :, :].copy_(v)
187
- new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
188
- else:
189
- new_ckpt[k] = v
190
- elif k in ('x_embedder.proj.bias'):
191
- new_ckpt[k] = v
192
- else:
193
- new_ckpt[k] = v
194
- missing, unexpected = self.load_state_dict(new_ckpt,
195
- strict=False)
196
- print(
197
- f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
198
- )
199
- if len(missing) > 0:
200
- print(f'Missing Keys:\n {missing}')
201
- if len(unexpected) > 0:
202
- print(f'\nUnexpected Keys:\n {unexpected}')
203
-
204
- def forward(self,
205
- x,
206
- t=None,
207
- cond=dict(),
208
- mask=None,
209
- text_position_embeddings=None,
210
- gc_seg=-1,
211
- **kwargs):
212
- if self.freqs.device != x.device:
213
- self.freqs = self.freqs.to(x.device)
214
- if isinstance(cond, dict):
215
- context = cond.get('crossattn', None)
216
- else:
217
- context = cond
218
- if text_position_embeddings is not None:
219
- # default use the text_position_embeddings in state_dict
220
- # if state_dict doesn't including this key, use the arg: text_position_embeddings
221
- proj_position_embeddings = self.y_embedder(
222
- text_position_embeddings)
223
- else:
224
- proj_position_embeddings = None
225
-
226
- ctx_batch, txt_lens = [], []
227
- if mask is not None and isinstance(mask, list):
228
- for ctx, ctx_mask in zip(context, mask):
229
- for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
230
- u, m = one_ctx
231
- t_len = m.flatten().sum() # l
232
- u = u[:t_len]
233
- u = self.y_embedder(u)
234
- if frame_id == 0:
235
- u = u + proj_position_embeddings[
236
- len(ctx) -
237
- 1] if proj_position_embeddings is not None else u
238
- else:
239
- u = u + proj_position_embeddings[
240
- frame_id -
241
- 1] if proj_position_embeddings is not None else u
242
- ctx_batch.append(u)
243
- txt_lens.append(t_len)
244
- else:
245
- raise TypeError
246
- y = torch.cat(ctx_batch, dim=0)
247
- txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)
248
-
249
- batch_frames = []
250
- for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
251
- u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
252
- m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
253
- batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
254
- if 'edit' in cond:
255
- for i, (edit, edit_mask) in enumerate(
256
- zip(cond['edit'], cond['edit_mask'])):
257
- if edit is None:
258
- continue
259
- for u, m in zip(edit, edit_mask):
260
- u = u.squeeze(0)
261
- m = torch.ones_like(
262
- u[[0], :, :]) if m is None else m.squeeze(0)
263
- batch_frames[i].append(
264
- torch.cat([u, m], dim=0).unsqueeze(0))
265
-
266
- patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
267
- for frames in batch_frames:
268
- patches, patch_shapes = [], []
269
- self_x_len.append(0)
270
- for frame_id, u in enumerate(frames):
271
- u = self.x_embedder(u)
272
- h, w = u.size(2), u.size(3)
273
- u = rearrange(u, '1 c h w -> (h w) c')
274
- if frame_id == 0:
275
- u = u + proj_position_embeddings[
276
- len(frames) -
277
- 1] if proj_position_embeddings is not None else u
278
- else:
279
- u = u + proj_position_embeddings[
280
- frame_id -
281
- 1] if proj_position_embeddings is not None else u
282
- patches.append(u)
283
- patch_shapes.append([h, w])
284
- cross_x_len.append(h * w) # b*s, 1
285
- self_x_len[-1] += h * w # b, 1
286
- # u = torch.cat(patches, dim=0)
287
- patch_batch.extend(patches)
288
- shape_batch.append(
289
- torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
290
- # repeat t to align with x
291
- t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
292
- self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
293
- x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
294
- x.device, non_blocking=True))
295
- # x = pad_sequence(tuple(patch_batch), batch_first=True) # b, s*max(cl), c
296
- x = torch.cat(patch_batch, dim=0)
297
- x_shapes = pad_sequence(tuple(shape_batch),
298
- batch_first=True) # b, max(len(frames)), 2
299
- t = self.t_embedder(t) # (N, D)
300
- t0 = self.t_block(t)
301
- # y = self.y_embedder(context)
302
-
303
- kwargs = dict(y=y,
304
- t=t0,
305
- x_shapes=x_shapes,
306
- self_x_len=self_x_len,
307
- cross_x_len=cross_x_len,
308
- freqs=self.freqs,
309
- txt_lens=txt_lens)
310
- if self.use_grad_checkpoint and gc_seg >= 0:
311
- x = checkpoint_sequential(
312
- functions=[partial(block, **kwargs) for block in self.blocks],
313
- segments=gc_seg if gc_seg > 0 else len(self.blocks),
314
- input=x,
315
- use_reentrant=False)
316
- else:
317
- for block in self.blocks:
318
- x = block(x, **kwargs)
319
- x = self.final_layer(x, t) # b*s*n, d
320
- outs, cur_length = [], 0
321
- p = self.patch_size
322
- for seq_length, shape in zip(self_x_len, shape_batch):
323
- x_i = x[cur_length:cur_length + seq_length]
324
- h, w = shape[0].tolist()
325
- u = x_i[:h * w].view(h, w, p, p, -1)
326
- u = rearrange(u, 'h w p q c -> (h p w q) c'
327
- ) # dump into sequence for following tensor ops
328
- cur_length = cur_length + seq_length
329
- outs.append(u)
330
- x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
331
- if self.pred_sigma:
332
- return x.chunk(2, dim=1)[0]
333
- else:
334
- return x
335
-
336
- def initialize_weights(self):
337
- # Initialize transformer layers:
338
- def _basic_init(module):
339
- if isinstance(module, nn.Linear):
340
- torch.nn.init.xavier_uniform_(module.weight)
341
- if module.bias is not None:
342
- nn.init.constant_(module.bias, 0)
343
-
344
- self.apply(_basic_init)
345
- # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
346
- w = self.x_embedder.proj.weight.data
347
- nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
348
- # Initialize timestep embedding MLP:
349
- nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
350
- nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
351
- nn.init.normal_(self.t_block[1].weight, std=0.02)
352
- # Initialize caption embedding MLP:
353
- if hasattr(self, 'y_embedder'):
354
- nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
355
- nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
356
- # Zero-out adaLN modulation layers
357
- for block in self.blocks:
358
- nn.init.constant_(block.cross_attn.o.weight, 0)
359
- nn.init.constant_(block.cross_attn.o.bias, 0)
360
- # Zero-out output layers:
361
- nn.init.constant_(self.final_layer.linear.weight, 0)
362
- nn.init.constant_(self.final_layer.linear.bias, 0)
363
-
364
- @property
365
- def dtype(self):
366
- return next(self.parameters()).dtype
367
-
368
- @staticmethod
369
- def get_config_template():
370
- return dict_to_yaml('BACKBONE',
371
- __class__.__name__,
372
- DiTACE.para_dict,
373
- set_name=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/backbone/layers.py DELETED
@@ -1,386 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import math
4
- import warnings
5
- import torch
6
- import torch.nn as nn
7
- from .pos_embed import rope_apply_multires as rope_apply
8
-
9
- try:
10
- from flash_attn import (flash_attn_varlen_func)
11
- FLASHATTN_IS_AVAILABLE = True
12
- except ImportError as e:
13
- FLASHATTN_IS_AVAILABLE = False
14
- flash_attn_varlen_func = None
15
- warnings.warn(f'{e}')
16
-
17
- __all__ = [
18
- "drop_path",
19
- "modulate",
20
- "PatchEmbed",
21
- "DropPath",
22
- "RMSNorm",
23
- "Mlp",
24
- "TimestepEmbedder",
25
- "DiTEditBlock",
26
- "MultiHeadAttentionDiTEdit",
27
- "T2IFinalLayer",
28
- ]
29
-
30
- def drop_path(x, drop_prob: float = 0., training: bool = False):
31
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
33
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
34
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
35
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
36
- 'survival rate' as the argument.
37
- """
38
- if drop_prob == 0. or not training:
39
- return x
40
- keep_prob = 1 - drop_prob
41
- shape = (x.shape[0], ) + (1, ) * (
42
- x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
43
- random_tensor = keep_prob + torch.rand(
44
- shape, dtype=x.dtype, device=x.device)
45
- random_tensor.floor_() # binarize
46
- output = x.div(keep_prob) * random_tensor
47
- return output
48
-
49
-
50
- def modulate(x, shift, scale, unsqueeze=False):
51
- if unsqueeze:
52
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
53
- else:
54
- return x * (1 + scale) + shift
55
-
56
-
57
- class PatchEmbed(nn.Module):
58
- """ 2D Image to Patch Embedding
59
- """
60
- def __init__(
61
- self,
62
- patch_size=16,
63
- in_chans=3,
64
- embed_dim=768,
65
- norm_layer=None,
66
- flatten=True,
67
- bias=True,
68
- ):
69
- super().__init__()
70
- self.flatten = flatten
71
- self.proj = nn.Conv2d(in_chans,
72
- embed_dim,
73
- kernel_size=patch_size,
74
- stride=patch_size,
75
- bias=bias)
76
- self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
77
-
78
- def forward(self, x):
79
- x = self.proj(x)
80
- if self.flatten:
81
- x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
82
- x = self.norm(x)
83
- return x
84
-
85
-
86
- class DropPath(nn.Module):
87
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
88
- """
89
- def __init__(self, drop_prob=None):
90
- super(DropPath, self).__init__()
91
- self.drop_prob = drop_prob
92
-
93
- def forward(self, x):
94
- return drop_path(x, self.drop_prob, self.training)
95
-
96
-
97
- class RMSNorm(nn.Module):
98
- def __init__(self, dim, eps=1e-6):
99
- super().__init__()
100
- self.dim = dim
101
- self.eps = eps
102
- self.weight = nn.Parameter(torch.ones(dim))
103
-
104
- def forward(self, x):
105
- return self._norm(x.float()).type_as(x) * self.weight
106
-
107
- def _norm(self, x):
108
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
109
-
110
-
111
- class Mlp(nn.Module):
112
- """ MLP as used in Vision Transformer, MLP-Mixer and related networks
113
- """
114
- def __init__(self,
115
- in_features,
116
- hidden_features=None,
117
- out_features=None,
118
- act_layer=nn.GELU,
119
- drop=0.):
120
- super().__init__()
121
- out_features = out_features or in_features
122
- hidden_features = hidden_features or in_features
123
- self.fc1 = nn.Linear(in_features, hidden_features)
124
- self.act = act_layer()
125
- self.fc2 = nn.Linear(hidden_features, out_features)
126
- self.drop = nn.Dropout(drop)
127
-
128
- def forward(self, x):
129
- x = self.fc1(x)
130
- x = self.act(x)
131
- x = self.drop(x)
132
- x = self.fc2(x)
133
- x = self.drop(x)
134
- return x
135
-
136
-
137
- class TimestepEmbedder(nn.Module):
138
- """
139
- Embeds scalar timesteps into vector representations.
140
- """
141
- def __init__(self, hidden_size, frequency_embedding_size=256):
142
- super().__init__()
143
- self.mlp = nn.Sequential(
144
- nn.Linear(frequency_embedding_size, hidden_size, bias=True),
145
- nn.SiLU(),
146
- nn.Linear(hidden_size, hidden_size, bias=True),
147
- )
148
- self.frequency_embedding_size = frequency_embedding_size
149
-
150
- @staticmethod
151
- def timestep_embedding(t, dim, max_period=10000):
152
- """
153
- Create sinusoidal timestep embeddings.
154
- :param t: a 1-D Tensor of N indices, one per batch element.
155
- These may be fractional.
156
- :param dim: the dimension of the output.
157
- :param max_period: controls the minimum frequency of the embeddings.
158
- :return: an (N, D) Tensor of positional embeddings.
159
- """
160
- # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
161
- half = dim // 2
162
- freqs = torch.exp(
163
- -math.log(max_period) *
164
- torch.arange(start=0, end=half, dtype=torch.float32) /
165
- half).to(device=t.device)
166
- args = t[:, None].float() * freqs[None]
167
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
168
- if dim % 2:
169
- embedding = torch.cat(
170
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
171
- return embedding
172
-
173
- def forward(self, t):
174
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
175
- t_emb = self.mlp(t_freq)
176
- return t_emb
177
-
178
-
179
- class DiTACEBlock(nn.Module):
180
- def __init__(self,
181
- hidden_size,
182
- num_heads,
183
- mlp_ratio=4.0,
184
- drop_path=0.,
185
- window_size=0,
186
- backend=None,
187
- use_condition=True,
188
- qk_norm=False,
189
- **block_kwargs):
190
- super().__init__()
191
- self.hidden_size = hidden_size
192
- self.use_condition = use_condition
193
- self.norm1 = nn.LayerNorm(hidden_size,
194
- elementwise_affine=False,
195
- eps=1e-6)
196
- self.attn = MultiHeadAttention(hidden_size,
197
- num_heads=num_heads,
198
- qkv_bias=True,
199
- backend=backend,
200
- qk_norm=qk_norm,
201
- **block_kwargs)
202
- if self.use_condition:
203
- self.cross_attn = MultiHeadAttention(
204
- hidden_size,
205
- context_dim=hidden_size,
206
- num_heads=num_heads,
207
- qkv_bias=True,
208
- backend=backend,
209
- qk_norm=qk_norm,
210
- **block_kwargs)
211
- self.norm2 = nn.LayerNorm(hidden_size,
212
- elementwise_affine=False,
213
- eps=1e-6)
214
- # to be compatible with lower version pytorch
215
- approx_gelu = lambda: nn.GELU(approximate='tanh')
216
- self.mlp = Mlp(in_features=hidden_size,
217
- hidden_features=int(hidden_size * mlp_ratio),
218
- act_layer=approx_gelu,
219
- drop=0)
220
- self.drop_path = DropPath(
221
- drop_path) if drop_path > 0. else nn.Identity()
222
- self.window_size = window_size
223
- self.scale_shift_table = nn.Parameter(
224
- torch.randn(6, hidden_size) / hidden_size**0.5)
225
-
226
- def forward(self, x, y, t, **kwargs):
227
- B = x.size(0)
228
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
229
- self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
230
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
231
- shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1),
232
- shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1))
233
- x = x + self.drop_path(gate_msa * self.attn(
234
- modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), **
235
- kwargs))
236
- if self.use_condition:
237
- x = x + self.cross_attn(x, context=y, **kwargs)
238
-
239
- x = x + self.drop_path(gate_mlp * self.mlp(
240
- modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False)))
241
- return x
242
-
243
-
244
- class MultiHeadAttention(nn.Module):
245
- def __init__(self,
246
- dim,
247
- context_dim=None,
248
- num_heads=None,
249
- head_dim=None,
250
- attn_drop=0.0,
251
- qkv_bias=False,
252
- dropout=0.0,
253
- backend=None,
254
- qk_norm=False,
255
- eps=1e-6,
256
- **block_kwargs):
257
- super().__init__()
258
- # consider head_dim first, then num_heads
259
- num_heads = dim // head_dim if head_dim else num_heads
260
- head_dim = dim // num_heads
261
- assert num_heads * head_dim == dim
262
- context_dim = context_dim or dim
263
- self.dim = dim
264
- self.context_dim = context_dim
265
- self.num_heads = num_heads
266
- self.head_dim = head_dim
267
- self.scale = math.pow(head_dim, -0.25)
268
- # layers
269
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
270
- self.k = nn.Linear(context_dim, dim, bias=qkv_bias)
271
- self.v = nn.Linear(context_dim, dim, bias=qkv_bias)
272
- self.o = nn.Linear(dim, dim)
273
- self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
274
- self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
275
-
276
- self.dropout = nn.Dropout(dropout)
277
- self.attention_op = None
278
- self.attn_drop = nn.Dropout(attn_drop)
279
- self.backend = backend
280
- assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn',
281
- None)
282
- if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None):
283
- self.backend = 'flash_attn'
284
- self.softmax_scale = block_kwargs.get('softmax_scale', None)
285
- self.causal = block_kwargs.get('causal', False)
286
- self.window_size = block_kwargs.get('window_size', (-1, -1))
287
- self.deterministic = block_kwargs.get('deterministic', False)
288
- else:
289
- raise NotImplementedError
290
-
291
- def flash_attn(self, x, context=None, **kwargs):
292
- '''
293
- The implementation will be very slow when mask is not None,
294
- because we need rearange the x/context features according to mask.
295
- Args:
296
- x:
297
- context:
298
- mask:
299
- **kwargs:
300
- Returns: x
301
- '''
302
- dtype = kwargs.get('dtype', torch.float16)
303
-
304
- def half(x):
305
- return x if x.dtype in [torch.float16, torch.bfloat16
306
- ] else x.to(dtype)
307
-
308
- x_shapes = kwargs['x_shapes']
309
- freqs = kwargs['freqs']
310
- self_x_len = kwargs['self_x_len']
311
- cross_x_len = kwargs['cross_x_len']
312
- txt_lens = kwargs['txt_lens']
313
- n, d = self.num_heads, self.head_dim
314
-
315
- if context is None:
316
- # self-attn
317
- q = self.norm_q(self.q(x)).view(-1, n, d)
318
- k = self.norm_q(self.k(x)).view(-1, n, d)
319
- v = self.v(x).view(-1, n, d)
320
- q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False)
321
- k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False)
322
- q_lens = k_lens = self_x_len
323
- else:
324
- # cross-attn
325
- q = self.norm_q(self.q(x)).view(-1, n, d)
326
- k = self.norm_q(self.k(context)).view(-1, n, d)
327
- v = self.v(context).view(-1, n, d)
328
- q_lens = cross_x_len
329
- k_lens = txt_lens
330
-
331
- cu_seqlens_q = torch.cat([q_lens.new_zeros([1]),
332
- q_lens]).cumsum(0, dtype=torch.int32)
333
- cu_seqlens_k = torch.cat([k_lens.new_zeros([1]),
334
- k_lens]).cumsum(0, dtype=torch.int32)
335
- max_seqlen_q = q_lens.max()
336
- max_seqlen_k = k_lens.max()
337
-
338
- out_dtype = q.dtype
339
- q, k, v = half(q), half(k), half(v)
340
- x = flash_attn_varlen_func(q,
341
- k,
342
- v,
343
- cu_seqlens_q=cu_seqlens_q,
344
- cu_seqlens_k=cu_seqlens_k,
345
- max_seqlen_q=max_seqlen_q,
346
- max_seqlen_k=max_seqlen_k,
347
- dropout_p=self.attn_drop.p,
348
- softmax_scale=self.softmax_scale,
349
- causal=self.causal,
350
- window_size=self.window_size,
351
- deterministic=self.deterministic)
352
-
353
- x = x.type(out_dtype)
354
- x = x.reshape(-1, n * d)
355
- x = self.o(x)
356
- x = self.dropout(x)
357
- return x
358
-
359
- def forward(self, x, context=None, **kwargs):
360
- x = getattr(self, self.backend)(x, context=context, **kwargs)
361
- return x
362
-
363
-
364
- class T2IFinalLayer(nn.Module):
365
- """
366
- The final layer of PixArt.
367
- """
368
- def __init__(self, hidden_size, patch_size, out_channels):
369
- super().__init__()
370
- self.norm_final = nn.LayerNorm(hidden_size,
371
- elementwise_affine=False,
372
- eps=1e-6)
373
- self.linear = nn.Linear(hidden_size,
374
- patch_size * patch_size * out_channels,
375
- bias=True)
376
- self.scale_shift_table = nn.Parameter(
377
- torch.randn(2, hidden_size) / hidden_size**0.5)
378
- self.out_channels = out_channels
379
-
380
- def forward(self, x, t):
381
- shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2,
382
- dim=1)
383
- shift, scale = shift.squeeze(1), scale.squeeze(1)
384
- x = modulate(self.norm_final(x), shift, scale)
385
- x = self.linear(x)
386
- return x
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/backbone/pos_embed.py DELETED
@@ -1,85 +0,0 @@
1
- import numpy as np
2
- from einops import rearrange
3
-
4
- import torch
5
- import torch.cuda.amp as amp
6
- import torch.nn.functional as F
7
- from torch.nn.utils.rnn import pad_sequence
8
-
9
- def frame_pad(x, seq_len, shapes):
10
- max_h, max_w = np.max(shapes, 0)
11
- frames = []
12
- cur_len = 0
13
- for h, w in shapes:
14
- frame_len = h * w
15
- frames.append(
16
- F.pad(
17
- x[cur_len:cur_len + frame_len].view(h, w, -1),
18
- (0, 0, 0, max_w - w, 0, max_h - h)) # .view(max_h * max_w, -1)
19
- )
20
- cur_len += frame_len
21
- if cur_len >= seq_len:
22
- break
23
- return torch.stack(frames)
24
-
25
-
26
- def frame_unpad(x, shapes):
27
- max_h, max_w = np.max(shapes, 0)
28
- x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)
29
- frames = []
30
- for i, (h, w) in enumerate(shapes):
31
- if i >= len(x):
32
- break
33
- frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))
34
- return torch.concat(frames)
35
-
36
-
37
- @amp.autocast(enabled=False)
38
- def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
39
- """
40
- x: [B*L, N, C].
41
- x_lens: [B].
42
- x_shapes: [B, F, 2].
43
- freqs: [M, C // 2].
44
- """
45
- n, c = x.size(1), x.size(2) // 2
46
- # split freqs
47
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
48
- # loop over samples
49
- output = []
50
- st = 0
51
- for i, (seq_len,
52
- shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):
53
- x_i = frame_pad(x[st:st + seq_len], seq_len, shapes) # f, h, w, c
54
- f, h, w = x_i.shape[:3]
55
- pad_seq_len = f * h * w
56
- # precompute multipliers
57
- x_i = torch.view_as_complex(
58
- x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))
59
- freqs_i = torch.cat([
60
- freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
61
- freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
62
- freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
63
- ],
64
- dim=-1).reshape(pad_seq_len, 1, -1)
65
- # apply rotary embedding
66
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)
67
- x_i = frame_unpad(x_i, shapes)
68
- # append to collection
69
- output.append(x_i)
70
- st += seq_len
71
- return pad_sequence(output) if pad else torch.concat(output)
72
-
73
-
74
- @amp.autocast(enabled=False)
75
- def rope_params(max_seq_len, dim, theta=10000):
76
- """
77
- Precompute the frequency tensor for complex exponentials.
78
- """
79
- assert dim % 2 == 0
80
- freqs = torch.outer(
81
- torch.arange(max_seq_len),
82
- 1.0 / torch.pow(theta,
83
- torch.arange(0, dim, 2).to(torch.float64).div(dim)))
84
- freqs = torch.polar(torch.ones_like(freqs), freqs)
85
- return freqs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/diffusion/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
-
4
- from .diffusions import ACEDiffusion
5
- from .samplers import DDIMSampler
6
- from .schedules import LinearScheduler
 
 
 
 
 
 
 
modules/model/diffusion/diffusions.py DELETED
@@ -1,206 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import math
4
- import os
5
- from collections import OrderedDict
6
-
7
- import torch
8
- from tqdm import trange
9
-
10
- from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
11
- NOISE_SCHEDULERS)
12
- from scepter.modules.utils.config import Config, dict_to_yaml
13
- from scepter.modules.utils.distribute import we
14
- from scepter.modules.utils.file_system import FS
15
-
16
-
17
- @DIFFUSIONS.register_class()
18
- class ACEDiffusion(object):
19
- para_dict = {
20
- 'NOISE_SCHEDULER': {},
21
- 'SAMPLER_SCHEDULER': {},
22
- 'MIN_SNR_GAMMA': {
23
- 'value': None,
24
- 'description': 'The minimum SNR gamma value for the loss function.'
25
- },
26
- 'PREDICTION_TYPE': {
27
- 'value': 'eps',
28
- 'description':
29
- 'The type of prediction to use for the loss function.'
30
- }
31
- }
32
-
33
- def __init__(self, cfg, logger=None):
34
- super(ACEDiffusion, self).__init__()
35
- self.logger = logger
36
- self.cfg = cfg
37
- self.init_params()
38
-
39
- def init_params(self):
40
- self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
41
- self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
42
- self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
43
- logger=self.logger)
44
- self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
45
- 'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
46
- logger=self.logger)
47
- self.num_timesteps = self.noise_scheduler.num_timesteps
48
- if self.cfg.have('WORK_DIR') and we.rank == 0:
49
- schedule_visualization = os.path.join(self.cfg.WORK_DIR,
50
- 'noise_schedule.png')
51
- with FS.put_to(schedule_visualization) as local_path:
52
- self.noise_scheduler.plot_noise_sampling_map(local_path)
53
- schedule_visualization = os.path.join(self.cfg.WORK_DIR,
54
- 'sampler_schedule.png')
55
- with FS.put_to(schedule_visualization) as local_path:
56
- self.sampler_scheduler.plot_noise_sampling_map(local_path)
57
-
58
- def sample(self,
59
- noise,
60
- model,
61
- model_kwargs={},
62
- steps=20,
63
- sampler=None,
64
- use_dynamic_cfg=False,
65
- guide_scale=None,
66
- guide_rescale=None,
67
- show_progress=False,
68
- return_intermediate=None,
69
- intermediate_callback=None,
70
- **kwargs):
71
- assert isinstance(steps, (int, torch.LongTensor))
72
- assert return_intermediate in (None, 'x0', 'xt')
73
- assert isinstance(sampler, (str, dict, Config))
74
- intermediates = []
75
-
76
- def callback_fn(x_t, t, sigma=None, alpha=None):
77
- timestamp = t
78
- t = t.repeat(len(x_t)).round().long().to(x_t.device)
79
- sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
80
- alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
81
-
82
- if guide_scale is None or guide_scale == 1.0:
83
- out = model(x=x_t, t=t, **model_kwargs)
84
- else:
85
- if use_dynamic_cfg:
86
- guidance_scale = 1 + guide_scale * (
87
- (1 - math.cos(math.pi * (
88
- (steps - timestamp.item()) / steps)**5.0)) / 2)
89
- else:
90
- guidance_scale = guide_scale
91
- y_out = model(x=x_t, t=t, **model_kwargs[0])
92
- u_out = model(x=x_t, t=t, **model_kwargs[1])
93
- out = u_out + guidance_scale * (y_out - u_out)
94
- if guide_rescale is not None and guide_rescale > 0.0:
95
- ratio = (
96
- y_out.flatten(1).std(dim=1) /
97
- (out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
98
- (y_out.ndim - 1))
99
- out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
100
-
101
- if self.prediction_type == 'x0':
102
- x0 = out
103
- elif self.prediction_type == 'eps':
104
- x0 = (x_t - sigma * out) / alpha
105
- elif self.prediction_type == 'v':
106
- x0 = alpha * x_t - sigma * out
107
- else:
108
- raise NotImplementedError(
109
- f'prediction_type {self.prediction_type} not implemented')
110
-
111
- return x0
112
-
113
- sampler_ins = self.get_sampler(sampler)
114
-
115
- # this is ignored for schnell
116
- sampler_output = sampler_ins.preprare_sampler(
117
- noise,
118
- steps=steps,
119
- prediction_type=self.prediction_type,
120
- scheduler_ins=self.sampler_scheduler,
121
- callback_fn=callback_fn)
122
-
123
- for _ in trange(steps, disable=not show_progress):
124
- trange.desc = sampler_output.msg
125
- sampler_output = sampler_ins.step(sampler_output)
126
- if return_intermediate == 'x_0':
127
- intermediates.append(sampler_output.x_0)
128
- elif return_intermediate == 'x_t':
129
- intermediates.append(sampler_output.x_t)
130
- if intermediate_callback is not None:
131
- intermediate_callback(intermediates[-1])
132
- return (sampler_output.x_0, intermediates
133
- ) if return_intermediate is not None else sampler_output.x_0
134
-
135
- def loss(self,
136
- x_0,
137
- model,
138
- model_kwargs={},
139
- reduction='mean',
140
- noise=None,
141
- **kwargs):
142
- # use noise scheduler to add noise
143
- if noise is None:
144
- noise = torch.randn_like(x_0)
145
- schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
146
- x_t, t, sigma, alpha = schedule_output.x_t, schedule_output.t, schedule_output.sigma, schedule_output.alpha
147
- out = model(x=x_t, t=t, **model_kwargs)
148
-
149
- # mse loss
150
- target = {
151
- 'eps': noise,
152
- 'x0': x_0,
153
- 'v': alpha * noise - sigma * x_0
154
- }[self.prediction_type]
155
-
156
- loss = (out - target).pow(2)
157
- if reduction == 'mean':
158
- loss = loss.flatten(1).mean(dim=1)
159
-
160
- if self.min_snr_gamma is not None:
161
- alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
162
- sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
163
- snrs = (alphas / sigmas).clamp(min=1e-20)
164
- min_snrs = snrs.clamp(max=self.min_snr_gamma)
165
- weights = min_snrs / snrs
166
- else:
167
- weights = 1
168
-
169
- loss = loss * weights
170
- return loss
171
-
172
- def get_sampler(self, sampler):
173
- if isinstance(sampler, str):
174
- if sampler not in DIFFUSION_SAMPLERS.class_map:
175
- if self.logger is not None:
176
- self.logger.info(
177
- f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
178
- )
179
- else:
180
- print(
181
- f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
182
- )
183
- return None
184
- sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
185
- sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
186
- logger=self.logger)
187
- elif isinstance(sampler, (Config, dict, OrderedDict)):
188
- if isinstance(sampler, (dict, OrderedDict)):
189
- sampler = Config(
190
- cfg_dict={k.upper(): v
191
- for k, v in dict(sampler).items()},
192
- load=False)
193
- sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
194
- else:
195
- raise NotImplementedError
196
- return sampler_ins
197
-
198
- def __repr__(self) -> str:
199
- return f'{self.__class__.__name__}' + ' ' + super().__repr__()
200
-
201
- @staticmethod
202
- def get_config_template():
203
- return dict_to_yaml('DIFFUSIONS',
204
- __class__.__name__,
205
- ACEDiffusion.para_dict,
206
- set_name=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/diffusion/samplers.py DELETED
@@ -1,69 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import torch
4
-
5
- from scepter.modules.model.registry import DIFFUSION_SAMPLERS
6
- from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler
7
- from scepter.modules.model.diffusion.util import _i
8
-
9
- def _i(tensor, t, x):
10
- """
11
- Index tensor using t and format the output according to x.
12
- """
13
- shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
14
- if isinstance(t, torch.Tensor):
15
- t = t.to(tensor.device)
16
- return tensor[t].view(shape).to(x.device)
17
-
18
-
19
- @DIFFUSION_SAMPLERS.register_class('ddim')
20
- class DDIMSampler(BaseDiffusionSampler):
21
- def init_params(self):
22
- super().init_params()
23
- self.eta = self.cfg.get('ETA', 0.)
24
- self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',
25
- 'trailing')
26
-
27
- def preprare_sampler(self,
28
- noise,
29
- steps=20,
30
- scheduler_ins=None,
31
- prediction_type='',
32
- sigmas=None,
33
- betas=None,
34
- alphas=None,
35
- callback_fn=None,
36
- **kwargs):
37
- output = super().preprare_sampler(noise, steps, scheduler_ins,
38
- prediction_type, sigmas, betas,
39
- alphas, callback_fn, **kwargs)
40
- sigmas = output.sigmas
41
- sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
42
- sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
43
- sigmas_vp[sigmas == float('inf')] = 1.
44
- output.add_custom_field('sigmas_vp', sigmas_vp)
45
- return output
46
-
47
- def step(self, sampler_output):
48
- x_t = sampler_output.x_t
49
- step = sampler_output.step
50
- t = sampler_output.ts[step]
51
- sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)
52
- alpha_init = _i(sampler_output.alphas_init, step, x_t[:1])
53
- sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])
54
-
55
- x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_init)
56
- noise_factor = self.eta * (sigmas_vp[step + 1]**2 /
57
- sigmas_vp[step]**2 *
58
- (1 - (1 - sigmas_vp[step]**2) /
59
- (1 - sigmas_vp[step + 1]**2)))
60
- d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]
61
- x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \
62
- (sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d
63
- sampler_output.x_0 = x
64
- if sigmas_vp[step + 1] > 0:
65
- x += noise_factor * torch.randn_like(x)
66
- sampler_output.x_t = x
67
- sampler_output.step += 1
68
- sampler_output.msg = f'step {step}'
69
- return sampler_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/diffusion/schedules.py DELETED
@@ -1,30 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import torch
4
-
5
- from scepter.modules.model.registry import NOISE_SCHEDULERS
6
- from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler
7
-
8
-
9
- @NOISE_SCHEDULERS.register_class()
10
- class LinearScheduler(BaseNoiseScheduler):
11
- para_dict = {}
12
-
13
- def init_params(self):
14
- super().init_params()
15
- self.beta_min = self.cfg.get('BETA_MIN', 0.00085)
16
- self.beta_max = self.cfg.get('BETA_MAX', 0.012)
17
-
18
- def betas_to_sigmas(self, betas):
19
- return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
20
-
21
- def get_schedule(self):
22
- betas = torch.linspace(self.beta_min,
23
- self.beta_max,
24
- self.num_timesteps,
25
- dtype=torch.float32)
26
- sigmas = self.betas_to_sigmas(betas)
27
- self._sigmas = sigmas
28
- self._betas = betas
29
- self._alphas = torch.sqrt(1 - sigmas**2)
30
- self._timesteps = torch.arange(len(sigmas), dtype=torch.float32)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
modules/model/embedder/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .embedder import ACETextEmbedder
 
 
modules/model/embedder/embedder.py DELETED
@@ -1,184 +0,0 @@
- # -*- coding: utf-8 -*-
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import warnings
- from contextlib import nullcontext
-
- import torch
- import torch.nn.functional as F
- import torch.utils.dlpack
- from scepter.modules.model.embedder.base_embedder import BaseEmbedder
- from scepter.modules.model.registry import EMBEDDERS
- from scepter.modules.model.tokenizer.tokenizer_component import (
-     basic_clean, canonicalize, heavy_clean, whitespace_clean)
- from scepter.modules.utils.config import dict_to_yaml
- from scepter.modules.utils.distribute import we
- from scepter.modules.utils.file_system import FS
-
- try:
-     from transformers import AutoTokenizer, T5EncoderModel
- except Exception as e:
-     warnings.warn(
-         f'Import transformers error, please deal with this problem: {e}')
-
-
- @EMBEDDERS.register_class()
- class ACETextEmbedder(BaseEmbedder):
-     """
-     Uses the OpenCLIP transformer encoder for text
-     """
-     """
-     Uses the OpenCLIP transformer encoder for text
-     """
-     para_dict = {
-         'PRETRAINED_MODEL': {
-             'value':
-             'google/umt5-small',
-             'description':
-             'Pretrained Model for umt5, modelcard path or local path.'
-         },
-         'TOKENIZER_PATH': {
-             'value': 'google/umt5-small',
-             'description':
-             'Tokenizer Path for umt5, modelcard path or local path.'
-         },
-         'FREEZE': {
-             'value': True,
-             'description': ''
-         },
-         'USE_GRAD': {
-             'value': False,
-             'description': 'Compute grad or not.'
-         },
-         'CLEAN': {
-             'value':
-             'whitespace',
-             'description':
-             'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
-         },
-         'LAYER': {
-             'value': 'last',
-             'description': ''
-         },
-         'LEGACY': {
-             'value':
-             True,
-             'description':
-             'Whether to use the legacy returned feature or not, default True.'
-         }
-     }
-
-     def __init__(self, cfg, logger=None):
-         super().__init__(cfg, logger=logger)
-         pretrained_path = cfg.get('PRETRAINED_MODEL', None)
-         self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
-         assert pretrained_path
-         with FS.get_dir_to_local_dir(pretrained_path,
-                                      wait_finish=True) as local_path:
-             self.model = T5EncoderModel.from_pretrained(
-                 local_path,
-                 torch_dtype=getattr(
-                     torch,
-                     'float' if self.t5_dtype == 'float32' else self.t5_dtype))
-         tokenizer_path = cfg.get('TOKENIZER_PATH', None)
-         self.length = cfg.get('LENGTH', 77)
-
-         self.use_grad = cfg.get('USE_GRAD', False)
-         self.clean = cfg.get('CLEAN', 'whitespace')
-         self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
-         if tokenizer_path:
-             self.tokenize_kargs = {'return_tensors': 'pt'}
-             with FS.get_dir_to_local_dir(tokenizer_path,
-                                          wait_finish=True) as local_path:
-                 if self.added_identifier is not None and isinstance(
-                         self.added_identifier, list):
-                     self.tokenizer = AutoTokenizer.from_pretrained(local_path)
-                 else:
-                     self.tokenizer = AutoTokenizer.from_pretrained(local_path)
-             if self.length is not None:
-                 self.tokenize_kargs.update({
-                     'padding': 'max_length',
-                     'truncation': True,
-                     'max_length': self.length
-                 })
-             self.eos_token = self.tokenizer(
-                 self.tokenizer.eos_token)['input_ids'][0]
-         else:
-             self.tokenizer = None
-             self.tokenize_kargs = {}
-
-         self.use_grad = cfg.get('USE_GRAD', False)
-         self.clean = cfg.get('CLEAN', 'whitespace')
-
-     def freeze(self):
-         self.model = self.model.eval()
-         for param in self.parameters():
-             param.requires_grad = False
-
-     # encode && encode_text
-     def forward(self, tokens, return_mask=False, use_mask=True):
-         # tokenization
-         embedding_context = nullcontext if self.use_grad else torch.no_grad
-         with embedding_context():
-             if use_mask:
-                 x = self.model(tokens.input_ids.to(we.device_id),
-                                tokens.attention_mask.to(we.device_id))
-             else:
-                 x = self.model(tokens.input_ids.to(we.device_id))
-             x = x.last_hidden_state
-
-             if return_mask:
-                 return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
-             else:
-                 return x.detach() + 0.0, None
-
-     def _clean(self, text):
-         if self.clean == 'whitespace':
-             text = whitespace_clean(basic_clean(text))
-         elif self.clean == 'lower':
-             text = whitespace_clean(basic_clean(text)).lower()
-         elif self.clean == 'canonicalize':
-             text = canonicalize(basic_clean(text))
-         elif self.clean == 'heavy':
-             text = heavy_clean(basic_clean(text))
-         return text
-
-     def encode(self, text, return_mask=False, use_mask=True):
-         if isinstance(text, str):
-             text = [text]
-         if self.clean:
-             text = [self._clean(u) for u in text]
-         assert self.tokenizer is not None
-         cont, mask = [], []
-         with torch.autocast(device_type='cuda',
-                             enabled=self.t5_dtype in ('float16', 'bfloat16'),
-                             dtype=getattr(torch, self.t5_dtype)):
-             for tt in text:
-                 tokens = self.tokenizer([tt], **self.tokenize_kargs)
-                 one_cont, one_mask = self(tokens,
-                                           return_mask=return_mask,
-                                           use_mask=use_mask)
-                 cont.append(one_cont)
-                 mask.append(one_mask)
-         if return_mask:
-             return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
-         else:
-             return torch.cat(cont, dim=0)
-
-     def encode_list(self, text_list, return_mask=True):
-         cont_list = []
-         mask_list = []
-         for pp in text_list:
-             cont, cont_mask = self.encode(pp, return_mask=return_mask)
-             cont_list.append(cont)
-             mask_list.append(cont_mask)
-         if return_mask:
-             return cont_list, mask_list
-         else:
-             return cont_list
-
-     @staticmethod
-     def get_config_template():
-         return dict_to_yaml('MODELS',
-                             __class__.__name__,
-                             ACETextEmbedder.para_dict,
-                             set_name=True)
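
A rough usage sketch of the text-encoding path above, written directly against transformers instead of the scepter config system. The model name and max length mirror the para_dict defaults; the prompt string and variable names are illustrative:

# T5 encoding with attention mask, matching encode(..., return_mask=True) above.
import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained('google/umt5-small')
model = T5EncoderModel.from_pretrained('google/umt5-small').eval()

tokens = tokenizer(['make the sky a sunset orange'],
                   padding='max_length', truncation=True,
                   max_length=77, return_tensors='pt')
with torch.no_grad():
    out = model(tokens.input_ids, tokens.attention_mask)
# (1, 77, hidden) hidden states plus the token mask
cont, cont_mask = out.last_hidden_state, tokens.attention_mask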
 
 
modules/model/network/__init__.py DELETED
@@ -1 +0,0 @@
- from .ldm_ace import LdmACE
 
 
modules/model/network/ldm_ace.py DELETED
@@ -1,353 +0,0 @@
- # -*- coding: utf-8 -*-
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import copy
- import random
- from contextlib import nullcontext
-
- import torch
- import torch.nn.functional as F
- from torch import nn
-
- from scepter.modules.model.network.ldm import LatentDiffusion
- from scepter.modules.model.registry import MODELS
- from scepter.modules.utils.config import dict_to_yaml
- from scepter.modules.utils.distribute import we
-
- from ..utils.basic_utils import (
-     check_list_of_list,
-     pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
-     to_device,
-     unpack_tensor_into_imagelist
- )
-
-
- class TextEmbedding(nn.Module):
-     def __init__(self, embedding_shape):
-         super().__init__()
-         self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
-
-
- @MODELS.register_class()
- class LdmACE(LatentDiffusion):
-     para_dict = LatentDiffusion.para_dict
-     para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}
-
-     def __init__(self, cfg, logger=None):
-         super().__init__(cfg, logger=logger)
-         self.interpolate_func = lambda x: (F.interpolate(
-             x.unsqueeze(0),
-             scale_factor=1 / self.size_factor,
-             mode='nearest-exact') if x is not None else None)
-
-         self.text_indentifers = cfg.get('TEXT_IDENTIFIER', [])
-         self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
-                                                False)
-         if self.use_text_pos_embeddings:
-             self.text_position_embeddings = TextEmbedding(
-                 (10, 4096)).eval().requires_grad_(False)
-         else:
-             self.text_position_embeddings = None
-
-         self.logger.info(self.model)
-
-     @torch.no_grad()
-     def encode_first_stage(self, x, **kwargs):
-         return [
-             self.scale_factor *
-             self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
-             for i in x
-         ]
-
-     @torch.no_grad()
-     def decode_first_stage(self, z):
-         return [
-             self.first_stage_model._decode(1. / self.scale_factor *
-                                            i.to(torch.float16)) for i in z
-         ]
-
-     def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
-         if self.use_text_pos_embeddings and not torch.sum(
-                 self.text_position_embeddings.pos) > 0:
-             identifier_cont, identifier_cont_mask = getattr(
-                 self.cond_stage_model, 'encode')(self.text_indentifers,
-                                                  return_mask=True)
-             self.text_position_embeddings.load_state_dict(
-                 {'pos': identifier_cont[:, 0, :]})
-         cont_, cont_mask_ = [], []
-         for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
-             if isinstance(pp, list):
-                 cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
-                 cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
-             else:
-                 raise NotImplementedError
-
-         return cont_, cont_mask_
-
-     def limit_batch_data(self, batch_data_list, log_num):
-         if log_num and log_num > 0:
-             batch_data_list_limited = []
-             for sub_data in batch_data_list:
-                 if sub_data is not None:
-                     sub_data = sub_data[:log_num]
-                 batch_data_list_limited.append(sub_data)
-             return batch_data_list_limited
-         else:
-             return batch_data_list
-
-     def forward_train(self,
-                       edit_image=[],
-                       edit_image_mask=[],
-                       image=None,
-                       image_mask=None,
-                       noise=None,
-                       prompt=[],
-                       **kwargs):
-         '''
-         Args:
-             edit_image: list of list of edit_image
-             edit_image_mask: list of list of edit_image_mask
-             image: target image
-             image_mask: target image mask
-             noise: default is None, generated automatically
-             prompt: list of list of text
-             **kwargs:
-         Returns:
-         '''
-         assert check_list_of_list(prompt) and check_list_of_list(
-             edit_image) and check_list_of_list(edit_image_mask)
-         assert len(edit_image) == len(edit_image_mask) == len(prompt)
-         assert self.cond_stage_model is not None
-         gc_seg = kwargs.pop('gc_seg', [])
-         gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
-         context = {}
-
-         # process image
-         image = to_device(image)
-         x_start = self.encode_first_stage(image, **kwargs)
-         x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L
-         n, _, _ = x_start.shape
-         t = torch.randint(0, self.num_timesteps, (n, ),
-                           device=x_start.device).long()
-         context['x_shapes'] = x_shapes
-
-         # process image mask
-         image_mask = to_device(image_mask, strict=False)
-         context['x_mask'] = [self.interpolate_func(i) for i in image_mask
-                              ] if image_mask is not None else [None] * n
-
-         # process text
-         # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
-         prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
-         try:
-             cont, cont_mask = getattr(self.cond_stage_model,
-                                       'encode_list')(prompt_, return_mask=True)
-         except Exception as e:
-             print(e, prompt_)
-         cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
-                                                      cont_mask)
-         context['crossattn'] = cont
-
-         # process edit image & edit image mask
-         edit_image = [to_device(i, strict=False) for i in edit_image]
-         edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
-         e_img, e_mask = [], []
-         for u, m in zip(edit_image, edit_image_mask):
-             if m is None:
-                 m = [None] * len(u) if u is not None else [None]
-             e_img.append(
-                 self.encode_first_stage(u, **kwargs) if u is not None else u)
-             e_mask.append([
-                 self.interpolate_func(i) if i is not None else None for i in m
-             ])
-         context['edit'], context['edit_mask'] = e_img, e_mask
-
-         # process loss
-         loss = self.diffusion.loss(
-             x_0=x_start,
-             t=t,
-             noise=noise,
-             model=self.model,
-             model_kwargs={
-                 'cond':
-                 context,
-                 'mask':
-                 cont_mask,
-                 'gc_seg':
-                 gc_seg,
-                 'text_position_embeddings':
-                 self.text_position_embeddings.pos if hasattr(
-                     self.text_position_embeddings, 'pos') else None
-             },
-             **kwargs)
-         loss = loss.mean()
-         ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
-         return ret
-
-     @torch.no_grad()
-     def forward_test(self,
-                      edit_image=[],
-                      edit_image_mask=[],
-                      image=None,
-                      image_mask=None,
-                      prompt=[],
-                      n_prompt=[],
-                      sampler='ddim',
-                      sample_steps=20,
-                      guide_scale=4.5,
-                      guide_rescale=0.5,
-                      log_num=-1,
-                      seed=2024,
-                      **kwargs):
-
-         assert check_list_of_list(prompt) and check_list_of_list(
-             edit_image) and check_list_of_list(edit_image_mask)
-         assert len(edit_image) == len(edit_image_mask) == len(prompt)
-         assert self.cond_stage_model is not None
-         # gc_seg is unused
-         kwargs.pop('gc_seg', -1)
-         # prepare data
-         context, null_context = {}, {}
-
-         prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
-             [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
-             log_num)
-         g = torch.Generator(device=we.device_id)
-         seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
-         g.manual_seed(seed)
-         n_prompt = copy.deepcopy(prompt)
-         # only modify the last prompt to be zero
-         for nn_p_id, nn_p in enumerate(n_prompt):
-             if isinstance(nn_p, str):
-                 n_prompt[nn_p_id] = ['']
-             elif isinstance(nn_p, list):
-                 n_prompt[nn_p_id][-1] = ''
-             else:
-                 raise NotImplementedError
-         # process image
-         image = to_device(image)
-         x = self.encode_first_stage(image, **kwargs)
-         noise = [
-             torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
-             for i in x
-         ]
-         noise, x_shapes = pack_imagelist_into_tensor(noise)
-         context['x_shapes'] = null_context['x_shapes'] = x_shapes
-
-         # process image mask
-         image_mask = to_device(image_mask, strict=False)
-         cond_mask = [self.interpolate_func(i) for i in image_mask
-                      ] if image_mask is not None else [None] * len(image)
-         context['x_mask'] = null_context['x_mask'] = cond_mask
-         # process text
-         # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
-         prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
-         cont, cont_mask = getattr(self.cond_stage_model,
-                                   'encode_list')(prompt_, return_mask=True)
-         cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
-                                                      cont_mask)
-         null_cont, null_cont_mask = getattr(self.cond_stage_model,
-                                             'encode_list')(n_prompt,
-                                                            return_mask=True)
-         null_cont, null_cont_mask = self.cond_stage_embeddings(
-             prompt, edit_image, null_cont, null_cont_mask)
-         context['crossattn'] = cont
-         null_context['crossattn'] = null_cont
-
-         # process edit image & edit image mask
-         edit_image = [to_device(i, strict=False) for i in edit_image]
-         edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
-         e_img, e_mask = [], []
-         for u, m in zip(edit_image, edit_image_mask):
-             if u is None:
-                 continue
-             if m is None:
-                 m = [None] * len(u)
-             e_img.append(self.encode_first_stage(u, **kwargs))
-             e_mask.append([self.interpolate_func(i) for i in m])
-         null_context['edit'] = context['edit'] = e_img
-         null_context['edit_mask'] = context['edit_mask'] = e_mask
-
-         # process sample
-         model = self.model_ema if self.use_ema and self.eval_ema else self.model
-         embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \
-             else nullcontext
-         with embedding_context():
-             samples = self.diffusion.sample(
-                 sampler=sampler,
-                 noise=noise,
-                 model=model,
-                 model_kwargs=[{
-                     'cond':
-                     context,
-                     'mask':
-                     cont_mask,
-                     'text_position_embeddings':
-                     self.text_position_embeddings.pos if hasattr(
-                         self.text_position_embeddings, 'pos') else None
-                 }, {
-                     'cond':
-                     null_context,
-                     'mask':
-                     null_cont_mask,
-                     'text_position_embeddings':
-                     self.text_position_embeddings.pos if hasattr(
-                         self.text_position_embeddings, 'pos') else None
-                 }] if guide_scale is not None and guide_scale > 1 else {
-                     'cond':
-                     context,
-                     'mask':
-                     cont_mask,
-                     'text_position_embeddings':
-                     self.text_position_embeddings.pos if hasattr(
-                         self.text_position_embeddings, 'pos') else None
-                 },
-                 steps=sample_steps,
-                 guide_scale=guide_scale,
-                 guide_rescale=guide_rescale,
-                 show_progress=True,
-                 **kwargs)
-
-         samples = unpack_tensor_into_imagelist(samples, x_shapes)
-         x_samples = self.decode_first_stage(samples)
-         outputs = list()
-         for i in range(len(prompt)):
-             rec_img = torch.clamp(
-                 (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
-                 min=0.0,
-                 max=1.0)
-             rec_img = rec_img.squeeze(0)
-             edit_imgs, edit_img_masks = [], []
-             if edit_image is not None and edit_image[i] is not None:
-                 if edit_image_mask[i] is None:
-                     edit_image_mask[i] = [None] * len(edit_image[i])
-                 for edit_img, edit_mask in zip(edit_image[i],
-                                                edit_image_mask[i]):
-                     edit_img = torch.clamp((edit_img + 1.0) / 2.0,
-                                            min=0.0,
-                                            max=1.0)
-                     edit_imgs.append(edit_img.squeeze(0))
-                     if edit_mask is None:
-                         edit_mask = torch.ones_like(edit_img[[0], :, :])
-                     edit_img_masks.append(edit_mask)
-             one_tup = {
-                 'reconstruct_image': rec_img,
-                 'instruction': prompt[i],
-                 'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
-                 'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
-             }
-             if image is not None:
-                 if image_mask is None:
-                     image_mask = [None] * len(image)
-                 ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
-                 one_tup['target_image'] = ori_img.squeeze(0)
-                 one_tup['target_mask'] = image_mask[i] if image_mask[
-                     i] is not None else torch.ones_like(ori_img[[0], :, :])
-             outputs.append(one_tup)
-         return outputs
-
-     @staticmethod
-     def get_config_template():
-         return dict_to_yaml('MODEL',
-                             __class__.__name__,
-                             LdmACE.para_dict,
-                             set_name=True)
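
A sketch of the two ideas forward_test relies on, separated from the class for clarity (illustrative helpers, not part of the deleted file): the "negative" prompt is the same prompt list with only the last instruction blanked, and the conditional/unconditional predictions are then mixed with a guide_scale, i.e. classifier-free guidance.

# Negative-prompt construction and the standard guidance combination.
import copy


def make_null_prompts(prompts):
    n_prompt = copy.deepcopy(prompts)
    for i, p in enumerate(n_prompt):
        if isinstance(p, str):
            n_prompt[i] = ['']
        else:                  # list of instructions: blank only the last one
            n_prompt[i][-1] = ''
    return n_prompt


def guided_pred(pred_cond, pred_uncond, guide_scale=4.5):
    # classifier-free guidance mix of the two model outputs
    return pred_uncond + guide_scale * (pred_cond - pred_uncond)


print(make_null_prompts([['crop the image', 'add a red hat']]))
# [['crop the image', '']]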
 
 
modules/model/utils/basic_utils.py DELETED
@@ -1,104 +0,0 @@
- # -*- coding: utf-8 -*-
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from inspect import isfunction
-
- import torch
- from torch.nn.utils.rnn import pad_sequence
-
- from scepter.modules.utils.distribute import we
-
-
- def exists(x):
-     return x is not None
-
-
- def default(val, d):
-     if exists(val):
-         return val
-     return d() if isfunction(d) else d
-
-
- def disabled_train(self, mode=True):
-     """Overwrite model.train with this function to make sure train/eval mode
-     does not change anymore."""
-     return self
-
-
- def transfer_size(para_num):
-     if para_num > 1000 * 1000 * 1000 * 1000:
-         bill = para_num / (1000 * 1000 * 1000 * 1000)
-         return '{:.2f}T'.format(bill)
-     elif para_num > 1000 * 1000 * 1000:
-         gyte = para_num / (1000 * 1000 * 1000)
-         return '{:.2f}B'.format(gyte)
-     elif para_num > (1000 * 1000):
-         meta = para_num / (1000 * 1000)
-         return '{:.2f}M'.format(meta)
-     elif para_num > 1000:
-         kelo = para_num / 1000
-         return '{:.2f}K'.format(kelo)
-     else:
-         return para_num
-
-
- def count_params(model):
-     total_params = sum(p.numel() for p in model.parameters())
-     return transfer_size(total_params)
-
-
- def expand_dims_like(x, y):
-     while x.dim() != y.dim():
-         x = x.unsqueeze(-1)
-     return x
-
-
- def unpack_tensor_into_imagelist(image_tensor, shapes):
-     image_list = []
-     for img, shape in zip(image_tensor, shapes):
-         h, w = shape[0], shape[1]
-         image_list.append(img[:, :h * w].view(1, -1, h, w))
-
-     return image_list
-
-
- def find_example(tensor_list, image_list):
-     for i in tensor_list:
-         if isinstance(i, torch.Tensor):
-             return torch.zeros_like(i)
-     for i in image_list:
-         if isinstance(i, torch.Tensor):
-             _, c, h, w = i.size()
-             return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
-     return None
-
-
- def pack_imagelist_into_tensor_v2(image_list):
-     # allow None
-     example = None
-     image_tensor, shapes = [], []
-     for img in image_list:
-         if img is None:
-             example = find_example(image_tensor,
-                                    image_list) if example is None else example
-             image_tensor.append(example)
-             shapes.append(None)
-             continue
-         _, c, h, w = img.size()
-         image_tensor.append(img.view(c, h * w).transpose(1, 0))  # h*w, c
-         shapes.append((h, w))
-
-     image_tensor = pad_sequence(image_tensor,
-                                 batch_first=True).permute(0, 2, 1)  # b, c, l
-     return image_tensor, shapes
-
-
- def to_device(inputs, strict=True):
-     if inputs is None:
-         return None
-     if strict:
-         assert all(isinstance(i, torch.Tensor) for i in inputs)
-     return [i.to(we.device_id) if i is not None else None for i in inputs]
-
-
- def check_list_of_list(ll):
-     return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
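
A quick round-trip illustration of the two packing helpers above, assuming they are in scope (for example, imported from this module before it was removed); the sample latent shapes are arbitrary:

# Variable-sized latents are flattened, padded into one (B, C, L) tensor, then restored.
import torch

latents = [torch.randn(1, 4, 8, 6), torch.randn(1, 4, 4, 4)]   # (1, c, h, w) each
packed, shapes = pack_imagelist_into_tensor_v2(latents)        # padded to the longest h*w
print(packed.shape, shapes)                                    # torch.Size([2, 4, 48]) [(8, 6), (4, 4)]
restored = unpack_tensor_into_imagelist(packed, shapes)
assert all(torch.allclose(a, b) for a, b in zip(latents, restored))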