modify somefiles

- infer.py +0 -364
- modules/__init__.py +0 -1
- modules/model/__init__.py +0 -1
- modules/model/backbone/__init__.py +0 -3
- modules/model/backbone/ace.py +0 -373
- modules/model/backbone/layers.py +0 -386
- modules/model/backbone/pos_embed.py +0 -85
- modules/model/diffusion/__init__.py +0 -6
- modules/model/diffusion/diffusions.py +0 -206
- modules/model/diffusion/samplers.py +0 -69
- modules/model/diffusion/schedules.py +0 -30
- modules/model/embedder/__init__.py +0 -1
- modules/model/embedder/embedder.py +0 -184
- modules/model/network/__init__.py +0 -1
- modules/model/network/ldm_ace.py +0 -353
- modules/model/utils/basic_utils.py +0 -104
infer.py
DELETED
@@ -1,364 +0,0 @@
```python
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import math
import random
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

from scepter.modules.model.registry import DIFFUSIONS
from scepter.modules.utils.distribute import we
from scepter.modules.utils.logger import get_logger
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model

from modules.model.utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)


def process_edit_image(images, masks, tasks,
                       max_seq_len=1024, max_aspect_ratio=4, d=16, **kwargs):
    if not isinstance(images, list):
        images = [images]
    if not isinstance(masks, list):
        masks = [masks]
    if not isinstance(tasks, list):
        tasks = [tasks]

    img_tensors = []
    mask_tensors = []
    for img, mask, task in zip(images, masks, tasks):
        if mask is None or mask == '':
            mask = Image.new('L', img.size, 0)
        W, H = img.size
        if H / W > max_aspect_ratio:
            img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
            mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
        elif W / H > max_aspect_ratio:
            img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
            mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])

        H, W = img.height, img.width
        scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
        rH = int(H * scale) // d * d  # ensure divisible by d
        rW = int(W * scale) // d * d

        img = TF.resize(img, (rH, rW),
                        interpolation=TF.InterpolationMode.BICUBIC)
        mask = TF.resize(mask, (rH, rW),
                         interpolation=TF.InterpolationMode.NEAREST_EXACT)

        mask = np.asarray(mask)
        mask = np.where(mask > 128, 1, 0)
        mask = mask.astype(np.float32) if np.any(mask) else np.ones_like(mask).astype(np.float32)

        img_tensor = TF.to_tensor(img).to(we.device_id)
        img_tensor = TF.normalize(img_tensor,
                                  mean=[0.5, 0.5, 0.5],
                                  std=[0.5, 0.5, 0.5])
        mask_tensor = TF.to_tensor(mask).to(we.device_id)
        if task in ['inpainting', 'Try On', 'Inpainting']:
            mask_indicator = mask_tensor.repeat(3, 1, 1)
            img_tensor[mask_indicator == 1] = -1.0
        img_tensors.append(img_tensor)
        mask_tensors.append(mask_tensor)
    return img_tensors, mask_tensors


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


class ACEInference(DiffusionInference):
    def __init__(self, logger=None):
        if logger is None:
            logger = get_logger(name='scepter')
        self.logger = logger
        self.loaded_model = {}
        self.loaded_model_name = [
            'diffusion_model', 'first_stage_model', 'cond_stage_model'
        ]

    def init_from_cfg(self, cfg):
        self.name = cfg.NAME
        self.is_default = cfg.get('IS_DEFAULT', False)
        module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
        assert cfg.have('MODEL')

        self.diffusion_model = self.infer_model(
            cfg.MODEL.DIFFUSION_MODEL,
            module_paras.get('DIFFUSION_MODEL', None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
        self.first_stage_model = self.infer_model(
            cfg.MODEL.FIRST_STAGE_MODEL,
            module_paras.get('FIRST_STAGE_MODEL', None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
        self.cond_stage_model = self.infer_model(
            cfg.MODEL.COND_STAGE_MODEL,
            module_paras.get('COND_STAGE_MODEL', None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger)

        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)
        self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS', False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False).to(we.device_id)
        else:
            self.text_position_embeddings = None

        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
        self.size_factor = cfg.get('SIZE_FACTOR', 8)
        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')

        self.dynamic_load(self.first_stage_model, 'first_stage_model')
        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
        self.dynamic_load(self.diffusion_model, 'diffusion_model')

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            z = [
                self.scale_factor * get_model(self.first_stage_model)._encode(
                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
            ]
        return z

    @torch.no_grad()
    def decode_first_stage(self, z):
        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
        with torch.autocast('cuda',
                            enabled=(dtype != 'float32'),
                            dtype=getattr(torch, dtype)):
            x = [
                get_model(self.first_stage_model)._decode(
                    1. / self.scale_factor * i.to(getattr(torch, dtype)))
                for i in z
            ]
        return x

    @torch.no_grad()
    def __call__(self,
                 image=None,
                 mask=None,
                 prompt='',
                 task=None,
                 negative_prompt='',
                 output_height=512,
                 output_width=512,
                 sampler='ddim',
                 sample_steps=20,
                 guide_scale=4.5,
                 guide_rescale=0.5,
                 seed=-1,
                 history_io=None,
                 tar_index=0,
                 **kwargs):
        input_image, input_mask = image, mask
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(int(seed))

        if input_image is not None:
            assert isinstance(input_image, list) and isinstance(input_mask, list)
            if task is None:
                task = [''] * len(input_image)
            if not isinstance(prompt, list):
                prompt = [prompt] * len(input_image)
            if history_io is not None and len(history_io) > 0:
                his_image, his_maks, his_prompt, his_task = history_io[
                    'image'], history_io['mask'], history_io['prompt'], history_io['task']
                assert len(his_image) == len(his_maks) == len(his_prompt) == len(his_task)
                input_image = his_image + input_image
                input_mask = his_maks + input_mask
                task = his_task + task
                prompt = his_prompt + [prompt[-1]]
            prompt = [
                pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
                for i, pp in enumerate(prompt)
            ]

            edit_image, edit_image_mask = process_edit_image(
                input_image, input_mask, task, max_seq_len=self.max_seq_len)

            image, image_mask = edit_image[tar_index], edit_image_mask[tar_index]
            edit_image, edit_image_mask = [edit_image], [edit_image_mask]

        else:
            edit_image = edit_image_mask = [[]]
            image = torch.zeros(size=[3, int(output_height), int(output_width)])
            image_mask = torch.ones(size=[1, int(output_height), int(output_width)])
            if not isinstance(prompt, list):
                prompt = [prompt]

        image, image_mask, prompt = [image], [image_mask], [prompt]
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        # Assign Negative Prompt
        if isinstance(negative_prompt, list):
            negative_prompt = negative_prompt[0]
        assert isinstance(negative_prompt, str)

        n_prompt = copy.deepcopy(prompt)
        for nn_p_id, nn_p in enumerate(n_prompt):
            assert isinstance(nn_p, list)
            n_prompt[nn_p_id][-1] = negative_prompt

        ctx, null_ctx = {}, {}

        # Get Noise Shape
        image = to_device(image)
        x = self.encode_first_stage(image)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes

        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        ctx['x_mask'] = null_ctx['x_mask'] = cond_mask

        # Encode Prompt
        function_name, dtype = self.get_function_info(self.cond_stage_model)
        cont, cont_mask = getattr(get_model(self.cond_stage_model),
                                  function_name)(prompt)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
                                            function_name)(n_prompt)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        ctx['crossattn'] = cont
        null_ctx['crossattn'] = null_cont

        # Encode Edit Images
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])

        null_ctx['edit'] = ctx['edit'] = e_img
        null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask

        # Diffusion Process
        function_name, dtype = self.get_function_info(self.diffusion_model)
        with torch.autocast('cuda',
                            enabled=dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, dtype)):
            latent = self.diffusion.sample(
                noise=noise,
                sampler=sampler,
                model=get_model(self.diffusion_model),
                model_kwargs=[{
                    'cond': ctx,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }, {
                    'cond': null_ctx,
                    'mask': null_cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }] if guide_scale is not None and guide_scale > 1 else {
                    'cond': null_ctx,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                },
                steps=sample_steps,
                show_progress=True,
                seed=seed,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                return_intermediate=None,
                **kwargs)

        # Decode to Pixel Space
        samples = unpack_tensor_into_imagelist(latent, x_shapes)
        x_samples = self.decode_first_stage(samples)

        imgs = [
            torch.clamp((x_i + 1.0) / 2.0 + self.decoder_bias / 255,
                        min=0.0,
                        max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
            for x_i in x_samples
        ]
        imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
        return imgs

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, _ = getattr(get_model(self.cond_stage_model),
                                         'encode')(self.text_indentifers,
                                                   return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})

        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError

        return cont_, cont_mask_
```
modules/__init__.py
DELETED
@@ -1 +0,0 @@
```python
from . import model
```
modules/model/__init__.py
DELETED
@@ -1 +0,0 @@
```python
from . import backbone, embedder, diffusion, network
```
modules/model/backbone/__init__.py
DELETED
@@ -1,3 +0,0 @@
```python
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from .ace import DiTACE
```
modules/model/backbone/ace.py
DELETED
@@ -1,373 +0,0 @@
```python
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from collections import OrderedDict
from functools import partial

import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.utils.rnn import pad_sequence
from torch.utils.checkpoint import checkpoint_sequential

from scepter.modules.model.base_model import BaseModel
from scepter.modules.model.registry import BACKBONES
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.file_system import FS

from .layers import (
    Mlp,
    TimestepEmbedder,
    PatchEmbed,
    DiTACEBlock,
    T2IFinalLayer
)
from .pos_embed import rope_params


@BACKBONES.register_class()
class DiTACE(BaseModel):

    para_dict = {
        'PATCH_SIZE': {'value': 2, 'description': ''},
        'IN_CHANNELS': {'value': 4, 'description': ''},
        'HIDDEN_SIZE': {'value': 1152, 'description': ''},
        'DEPTH': {'value': 28, 'description': ''},
        'NUM_HEADS': {'value': 16, 'description': ''},
        'MLP_RATIO': {'value': 4.0, 'description': ''},
        'PRED_SIGMA': {'value': True, 'description': ''},
        'DROP_PATH': {'value': 0., 'description': ''},
        'WINDOW_SIZE': {'value': 0, 'description': ''},
        'WINDOW_BLOCK_INDEXES': {'value': None, 'description': ''},
        'Y_CHANNELS': {'value': 4096, 'description': ''},
        'ATTENTION_BACKEND': {'value': None, 'description': ''},
        'QK_NORM': {
            'value': True,
            'description': 'Whether to use RMSNorm for query and key.',
        },
    }
    para_dict.update(BaseModel.para_dict)

    def __init__(self, cfg, logger):
        super().__init__(cfg, logger=logger)
        self.window_block_indexes = cfg.get('WINDOW_BLOCK_INDEXES', None)
        if self.window_block_indexes is None:
            self.window_block_indexes = []
        self.pred_sigma = cfg.get('PRED_SIGMA', True)
        self.in_channels = cfg.get('IN_CHANNELS', 4)
        self.out_channels = self.in_channels * 2 if self.pred_sigma else self.in_channels
        self.patch_size = cfg.get('PATCH_SIZE', 2)
        self.num_heads = cfg.get('NUM_HEADS', 16)
        self.hidden_size = cfg.get('HIDDEN_SIZE', 1152)
        self.y_channels = cfg.get('Y_CHANNELS', 4096)
        self.drop_path = cfg.get('DROP_PATH', 0.)
        self.depth = cfg.get('DEPTH', 28)
        self.mlp_ratio = cfg.get('MLP_RATIO', 4.0)
        self.use_grad_checkpoint = cfg.get('USE_GRAD_CHECKPOINT', False)
        self.attention_backend = cfg.get('ATTENTION_BACKEND', None)
        self.max_seq_len = cfg.get('MAX_SEQ_LEN', 1024)
        self.qk_norm = cfg.get('QK_NORM', False)
        self.ignore_keys = cfg.get('IGNORE_KEYS', [])
        assert (self.hidden_size % self.num_heads
                ) == 0 and (self.hidden_size // self.num_heads) % 2 == 0
        d = self.hidden_size // self.num_heads
        self.freqs = torch.cat(
            [
                rope_params(self.max_seq_len, d - 4 * (d // 6)),  # T (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6)),  # H (~1/3)
                rope_params(self.max_seq_len, 2 * (d // 6))  # W (~1/3)
            ],
            dim=1)

        # init embedder
        self.x_embedder = PatchEmbed(self.patch_size,
                                     self.in_channels + 1,
                                     self.hidden_size,
                                     bias=True,
                                     flatten=False)
        self.t_embedder = TimestepEmbedder(self.hidden_size)
        self.y_embedder = Mlp(in_features=self.y_channels,
                              hidden_features=self.hidden_size,
                              out_features=self.hidden_size,
                              act_layer=lambda: nn.GELU(approximate='tanh'),
                              drop=0)
        self.t_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(self.hidden_size, 6 * self.hidden_size, bias=True))
        # init blocks
        drop_path = [
            x.item() for x in torch.linspace(0, self.drop_path, self.depth)
        ]
        self.blocks = nn.ModuleList([
            DiTACEBlock(self.hidden_size,
                        self.num_heads,
                        mlp_ratio=self.mlp_ratio,
                        drop_path=drop_path[i],
                        window_size=self.window_size
                        if i in self.window_block_indexes else 0,
                        backend=self.attention_backend,
                        use_condition=True,
                        qk_norm=self.qk_norm) for i in range(self.depth)
        ])
        self.final_layer = T2IFinalLayer(self.hidden_size, self.patch_size,
                                         self.out_channels)
        self.initialize_weights()

    def load_pretrained_model(self, pretrained_model):
        if pretrained_model:
            with FS.get_from(pretrained_model, wait_finish=True) as local_path:
                model = torch.load(local_path, map_location='cpu')
                if 'state_dict' in model:
                    model = model['state_dict']
                new_ckpt = OrderedDict()
                for k, v in model.items():
                    if self.ignore_keys is not None:
                        if (isinstance(self.ignore_keys, str) and re.match(self.ignore_keys, k)) or \
                                (isinstance(self.ignore_keys, list) and k in self.ignore_keys):
                            continue
                    k = k.replace('.cross_attn.q_linear.', '.cross_attn.q.')
                    k = k.replace('.cross_attn.proj.', '.cross_attn.o.').replace(
                        '.attn.proj.', '.attn.o.')
                    if '.cross_attn.kv_linear.' in k:
                        k_p, v_p = torch.split(v, v.shape[0] // 2)
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.k.')] = k_p
                        new_ckpt[k.replace('.cross_attn.kv_linear.',
                                           '.cross_attn.v.')] = v_p
                    elif '.attn.qkv.' in k:
                        q_p, k_p, v_p = torch.split(v, v.shape[0] // 3)
                        new_ckpt[k.replace('.attn.qkv.', '.attn.q.')] = q_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.k.')] = k_p
                        new_ckpt[k.replace('.attn.qkv.', '.attn.v.')] = v_p
                    elif 'y_embedder.y_proj.' in k:
                        new_ckpt[k.replace('y_embedder.y_proj.',
                                           'y_embedder.')] = v
                    elif k in ('x_embedder.proj.weight'):
                        model_p = self.state_dict()[k]
                        if v.shape != model_p.shape:
                            model_p.zero_()
                            model_p[:, :4, :, :].copy_(v)
                            new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
                        else:
                            new_ckpt[k] = v
                    elif k in ('x_embedder.proj.bias'):
                        new_ckpt[k] = v
                    else:
                        new_ckpt[k] = v
                missing, unexpected = self.load_state_dict(new_ckpt,
                                                           strict=False)
                print(
                    f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
                )
                if len(missing) > 0:
                    print(f'Missing Keys:\n {missing}')
                if len(unexpected) > 0:
                    print(f'\nUnexpected Keys:\n {unexpected}')

    def forward(self,
                x,
                t=None,
                cond=dict(),
                mask=None,
                text_position_embeddings=None,
                gc_seg=-1,
                **kwargs):
        if self.freqs.device != x.device:
            self.freqs = self.freqs.to(x.device)
        if isinstance(cond, dict):
            context = cond.get('crossattn', None)
        else:
            context = cond
        if text_position_embeddings is not None:
            # default: use the text_position_embeddings in state_dict;
            # if state_dict doesn't include this key, use the arg text_position_embeddings
            proj_position_embeddings = self.y_embedder(text_position_embeddings)
        else:
            proj_position_embeddings = None

        ctx_batch, txt_lens = [], []
        if mask is not None and isinstance(mask, list):
            for ctx, ctx_mask in zip(context, mask):
                for frame_id, one_ctx in enumerate(zip(ctx, ctx_mask)):
                    u, m = one_ctx
                    t_len = m.flatten().sum()  # l
                    u = u[:t_len]
                    u = self.y_embedder(u)
                    if frame_id == 0:
                        u = u + proj_position_embeddings[
                            len(ctx) - 1] if proj_position_embeddings is not None else u
                    else:
                        u = u + proj_position_embeddings[
                            frame_id - 1] if proj_position_embeddings is not None else u
                    ctx_batch.append(u)
                    txt_lens.append(t_len)
        else:
            raise TypeError
        y = torch.cat(ctx_batch, dim=0)
        txt_lens = torch.LongTensor(txt_lens).to(x.device, non_blocking=True)

        batch_frames = []
        for u, shape, m in zip(x, cond['x_shapes'], cond['x_mask']):
            u = u[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
            m = torch.ones_like(u[[0], :, :]) if m is None else m.squeeze(0)
            batch_frames.append([torch.cat([u, m], dim=0).unsqueeze(0)])
        if 'edit' in cond:
            for i, (edit, edit_mask) in enumerate(
                    zip(cond['edit'], cond['edit_mask'])):
                if edit is None:
                    continue
                for u, m in zip(edit, edit_mask):
                    u = u.squeeze(0)
                    m = torch.ones_like(
                        u[[0], :, :]) if m is None else m.squeeze(0)
                    batch_frames[i].append(
                        torch.cat([u, m], dim=0).unsqueeze(0))

        patch_batch, shape_batch, self_x_len, cross_x_len = [], [], [], []
        for frames in batch_frames:
            patches, patch_shapes = [], []
            self_x_len.append(0)
            for frame_id, u in enumerate(frames):
                u = self.x_embedder(u)
                h, w = u.size(2), u.size(3)
                u = rearrange(u, '1 c h w -> (h w) c')
                if frame_id == 0:
                    u = u + proj_position_embeddings[
                        len(frames) - 1] if proj_position_embeddings is not None else u
                else:
                    u = u + proj_position_embeddings[
                        frame_id - 1] if proj_position_embeddings is not None else u
                patches.append(u)
                patch_shapes.append([h, w])
                cross_x_len.append(h * w)  # b*s, 1
                self_x_len[-1] += h * w  # b, 1
            # u = torch.cat(patches, dim=0)
            patch_batch.extend(patches)
            shape_batch.append(
                torch.LongTensor(patch_shapes).to(x.device, non_blocking=True))
        # repeat t to align with x
        t = torch.cat([t[i].repeat(l) for i, l in enumerate(self_x_len)])
        self_x_len, cross_x_len = (torch.LongTensor(self_x_len).to(
            x.device, non_blocking=True), torch.LongTensor(cross_x_len).to(
                x.device, non_blocking=True))
        # x = pad_sequence(tuple(patch_batch), batch_first=True)  # b, s*max(cl), c
        x = torch.cat(patch_batch, dim=0)
        x_shapes = pad_sequence(tuple(shape_batch),
                                batch_first=True)  # b, max(len(frames)), 2
        t = self.t_embedder(t)  # (N, D)
        t0 = self.t_block(t)
        # y = self.y_embedder(context)

        kwargs = dict(y=y,
                      t=t0,
                      x_shapes=x_shapes,
                      self_x_len=self_x_len,
                      cross_x_len=cross_x_len,
                      freqs=self.freqs,
                      txt_lens=txt_lens)
        if self.use_grad_checkpoint and gc_seg >= 0:
            x = checkpoint_sequential(
                functions=[partial(block, **kwargs) for block in self.blocks],
                segments=gc_seg if gc_seg > 0 else len(self.blocks),
                input=x,
                use_reentrant=False)
        else:
            for block in self.blocks:
                x = block(x, **kwargs)
        x = self.final_layer(x, t)  # b*s*n, d
        outs, cur_length = [], 0
        p = self.patch_size
        for seq_length, shape in zip(self_x_len, shape_batch):
            x_i = x[cur_length:cur_length + seq_length]
            h, w = shape[0].tolist()
            u = x_i[:h * w].view(h, w, p, p, -1)
            u = rearrange(u, 'h w p q c -> (h p w q) c'
                          )  # dump into sequence for following tensor ops
            cur_length = cur_length + seq_length
            outs.append(u)
        x = pad_sequence(tuple(outs), batch_first=True).permute(0, 2, 1)
        if self.pred_sigma:
            return x.chunk(2, dim=1)[0]
        else:
            return x

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)
        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        w = self.x_embedder.proj.weight.data
        nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
        nn.init.normal_(self.t_block[1].weight, std=0.02)
        # Initialize caption embedding MLP:
        if hasattr(self, 'y_embedder'):
            nn.init.normal_(self.y_embedder.fc1.weight, std=0.02)
            nn.init.normal_(self.y_embedder.fc2.weight, std=0.02)
        # Zero-out adaLN modulation layers
        for block in self.blocks:
            nn.init.constant_(block.cross_attn.o.weight, 0)
            nn.init.constant_(block.cross_attn.o.bias, 0)
        # Zero-out output layers:
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @staticmethod
    def get_config_template():
        return dict_to_yaml('BACKBONE',
                            __class__.__name__,
                            DiTACE.para_dict,
                            set_name=True)
```
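One detail worth making explicit: `DiTACE.__init__` splits the per-head channel dimension `d` into a temporal slice of width `d - 4*(d//6)` and two spatial slices of width `2*(d//6)` each. Since `rope_params` emits `dim // 2` complex frequencies per axis, the concatenated table has exactly `d // 2` columns, which is the split `rope_apply_multires` later reverses. A standalone arithmetic check (sketch only, nothing beyond what the two functions imply):

```python
def rope_split(head_dim: int):
    # Mirrors the widths passed to rope_params in DiTACE.__init__.
    d = head_dim
    t, h, w = d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)
    assert t + h + w == d and t % 2 == h % 2 == w % 2 == 0
    # rope_params returns dim // 2 complex columns per axis, so the
    # concatenated table has d // 2 columns; rope_apply_multires re-splits
    # c = d // 2 as [c - 2*(c//3), c//3, c//3], which must line up.
    c = d // 2
    assert (t // 2, h // 2, w // 2) == (c - 2 * (c // 3), c // 3, c // 3)
    return t, h, w

print(rope_split(1152 // 16))  # (24, 24, 24) for the default config
```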
modules/model/backbone/layers.py
DELETED
@@ -1,386 +0,0 @@
```python
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import warnings
import torch
import torch.nn as nn
from .pos_embed import rope_apply_multires as rope_apply

try:
    from flash_attn import (flash_attn_varlen_func)
    FLASHATTN_IS_AVAILABLE = True
except ImportError as e:
    FLASHATTN_IS_AVAILABLE = False
    flash_attn_varlen_func = None
    warnings.warn(f'{e}')

__all__ = [
    "drop_path",
    "modulate",
    "PatchEmbed",
    "DropPath",
    "RMSNorm",
    "Mlp",
    "TimestepEmbedder",
    "DiTEditBlock",
    "MultiHeadAttentionDiTEdit",
    "T2IFinalLayer",
]


def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0], ) + (1, ) * (
        x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


def modulate(x, shift, scale, unsqueeze=False):
    if unsqueeze:
        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
    else:
        return x * (1 + scale) + shift


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        self.flatten = flatten
        self.proj = nn.Conv2d(in_chans,
                              embed_dim,
                              kernel_size=patch_size,
                              stride=patch_size,
                              bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    """
    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) *
            torch.arange(start=0, end=half, dtype=torch.float32) /
            half).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class DiTACEBlock(nn.Module):
    def __init__(self,
                 hidden_size,
                 num_heads,
                 mlp_ratio=4.0,
                 drop_path=0.,
                 window_size=0,
                 backend=None,
                 use_condition=True,
                 qk_norm=False,
                 **block_kwargs):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_condition = use_condition
        self.norm1 = nn.LayerNorm(hidden_size,
                                  elementwise_affine=False,
                                  eps=1e-6)
        self.attn = MultiHeadAttention(hidden_size,
                                       num_heads=num_heads,
                                       qkv_bias=True,
                                       backend=backend,
                                       qk_norm=qk_norm,
                                       **block_kwargs)
        if self.use_condition:
            self.cross_attn = MultiHeadAttention(
                hidden_size,
                context_dim=hidden_size,
                num_heads=num_heads,
                qkv_bias=True,
                backend=backend,
                qk_norm=qk_norm,
                **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size,
                                  elementwise_affine=False,
                                  eps=1e-6)
        # to be compatible with lower version pytorch
        approx_gelu = lambda: nn.GELU(approximate='tanh')
        self.mlp = Mlp(in_features=hidden_size,
                       hidden_features=int(hidden_size * mlp_ratio),
                       act_layer=approx_gelu,
                       drop=0)
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.window_size = window_size
        self.scale_shift_table = nn.Parameter(
            torch.randn(6, hidden_size) / hidden_size**0.5)

    def forward(self, x, y, t, **kwargs):
        B = x.size(0)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
            shift_msa.squeeze(1), scale_msa.squeeze(1), gate_msa.squeeze(1),
            shift_mlp.squeeze(1), scale_mlp.squeeze(1), gate_mlp.squeeze(1))
        x = x + self.drop_path(gate_msa * self.attn(
            modulate(self.norm1(x), shift_msa, scale_msa, unsqueeze=False), **
            kwargs))
        if self.use_condition:
            x = x + self.cross_attn(x, context=y, **kwargs)

        x = x + self.drop_path(gate_mlp * self.mlp(
            modulate(self.norm2(x), shift_mlp, scale_mlp, unsqueeze=False)))
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self,
                 dim,
                 context_dim=None,
                 num_heads=None,
                 head_dim=None,
                 attn_drop=0.0,
                 qkv_bias=False,
                 dropout=0.0,
                 backend=None,
                 qk_norm=False,
                 eps=1e-6,
                 **block_kwargs):
        super().__init__()
        # consider head_dim first, then num_heads
        num_heads = dim // head_dim if head_dim else num_heads
        head_dim = dim // num_heads
        assert num_heads * head_dim == dim
        context_dim = context_dim or dim
        self.dim = dim
        self.context_dim = context_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.scale = math.pow(head_dim, -0.25)
        # layers
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.v = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.o = nn.Linear(dim, dim)
        self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()

        self.dropout = nn.Dropout(dropout)
        self.attention_op = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.backend = backend
        assert self.backend in ('flash_attn', 'xformer_attn', 'pytorch_attn',
                                None)
        if FLASHATTN_IS_AVAILABLE and self.backend in ('flash_attn', None):
            self.backend = 'flash_attn'
            self.softmax_scale = block_kwargs.get('softmax_scale', None)
            self.causal = block_kwargs.get('causal', False)
            self.window_size = block_kwargs.get('window_size', (-1, -1))
            self.deterministic = block_kwargs.get('deterministic', False)
        else:
            raise NotImplementedError

    def flash_attn(self, x, context=None, **kwargs):
        '''
        The implementation will be very slow when mask is not None,
        because we need rearange the x/context features according to mask.
        Args:
            x:
            context:
            mask:
            **kwargs:
        Returns: x
        '''
        dtype = kwargs.get('dtype', torch.float16)

        def half(x):
            return x if x.dtype in [torch.float16, torch.bfloat16
                                    ] else x.to(dtype)

        x_shapes = kwargs['x_shapes']
        freqs = kwargs['freqs']
        self_x_len = kwargs['self_x_len']
        cross_x_len = kwargs['cross_x_len']
        txt_lens = kwargs['txt_lens']
        n, d = self.num_heads, self.head_dim

        if context is None:
            # self-attn
            q = self.norm_q(self.q(x)).view(-1, n, d)
            k = self.norm_q(self.k(x)).view(-1, n, d)
            v = self.v(x).view(-1, n, d)
            q = rope_apply(q, self_x_len, x_shapes, freqs, pad=False)
            k = rope_apply(k, self_x_len, x_shapes, freqs, pad=False)
            q_lens = k_lens = self_x_len
        else:
            # cross-attn
            q = self.norm_q(self.q(x)).view(-1, n, d)
            k = self.norm_q(self.k(context)).view(-1, n, d)
            v = self.v(context).view(-1, n, d)
            q_lens = cross_x_len
            k_lens = txt_lens

        cu_seqlens_q = torch.cat([q_lens.new_zeros([1]),
                                  q_lens]).cumsum(0, dtype=torch.int32)
        cu_seqlens_k = torch.cat([k_lens.new_zeros([1]),
                                  k_lens]).cumsum(0, dtype=torch.int32)
        max_seqlen_q = q_lens.max()
        max_seqlen_k = k_lens.max()

        out_dtype = q.dtype
        q, k, v = half(q), half(k), half(v)
        x = flash_attn_varlen_func(q,
                                   k,
                                   v,
                                   cu_seqlens_q=cu_seqlens_q,
                                   cu_seqlens_k=cu_seqlens_k,
                                   max_seqlen_q=max_seqlen_q,
                                   max_seqlen_k=max_seqlen_k,
                                   dropout_p=self.attn_drop.p,
                                   softmax_scale=self.softmax_scale,
                                   causal=self.causal,
                                   window_size=self.window_size,
                                   deterministic=self.deterministic)

        x = x.type(out_dtype)
        x = x.reshape(-1, n * d)
        x = self.o(x)
        x = self.dropout(x)
        return x

    def forward(self, x, context=None, **kwargs):
        x = getattr(self, self.backend)(x, context=context, **kwargs)
        return x


class T2IFinalLayer(nn.Module):
    """
    The final layer of PixArt.
    """
    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size,
                                       elementwise_affine=False,
                                       eps=1e-6)
        self.linear = nn.Linear(hidden_size,
                                patch_size * patch_size * out_channels,
                                bias=True)
        self.scale_shift_table = nn.Parameter(
            torch.randn(2, hidden_size) / hidden_size**0.5)
        self.out_channels = out_channels

    def forward(self, x, t):
        shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2,
                                                                         dim=1)
        shift, scale = shift.squeeze(1), scale.squeeze(1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x
```
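The deleted `flash_attn` path packs every sample into one `(total_tokens, heads, head_dim)` tensor and describes sequence boundaries through cumulative lengths. For readers without flash-attn installed, here is a plain-PyTorch sketch of the same contract (assumptions: no dropout, no sliding window, non-causal; `varlen_sdpa` is a name invented here, not part of the repo):

```python
import torch
import torch.nn.functional as F

def varlen_sdpa(q, k, v, q_lens, k_lens):
    """Per-sample SDPA over packed sequences; mirrors what the cu_seqlens
    arguments to flash_attn_varlen_func encode (minus dropout/windowing)."""
    cu_q = [0] + q_lens.cumsum(0).tolist()
    cu_k = [0] + k_lens.cumsum(0).tolist()
    outs = []
    for i in range(len(q_lens)):
        qi = q[cu_q[i]:cu_q[i + 1]].transpose(0, 1)  # (heads, len_q_i, d)
        ki = k[cu_k[i]:cu_k[i + 1]].transpose(0, 1)
        vi = v[cu_k[i]:cu_k[i + 1]].transpose(0, 1)
        outs.append(F.scaled_dot_product_attention(qi, ki, vi).transpose(0, 1))
    return torch.cat(outs)  # packed again: (sum(q_lens), heads, d)

q_lens, k_lens = torch.tensor([3, 5]), torch.tensor([4, 2])
q = torch.randn(8, 2, 16)
k, v = torch.randn(6, 2, 16), torch.randn(6, 2, 16)
print(varlen_sdpa(q, k, v, q_lens, k_lens).shape)  # torch.Size([8, 2, 16])
```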
modules/model/backbone/pos_embed.py
DELETED
@@ -1,85 +0,0 @@
```python
import numpy as np
from einops import rearrange

import torch
import torch.cuda.amp as amp
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

def frame_pad(x, seq_len, shapes):
    max_h, max_w = np.max(shapes, 0)
    frames = []
    cur_len = 0
    for h, w in shapes:
        frame_len = h * w
        frames.append(
            F.pad(
                x[cur_len:cur_len + frame_len].view(h, w, -1),
                (0, 0, 0, max_w - w, 0, max_h - h))  # .view(max_h * max_w, -1)
        )
        cur_len += frame_len
        if cur_len >= seq_len:
            break
    return torch.stack(frames)


def frame_unpad(x, shapes):
    max_h, max_w = np.max(shapes, 0)
    x = rearrange(x, '(b h w) n c -> b h w n c', h=max_h, w=max_w)
    frames = []
    for i, (h, w) in enumerate(shapes):
        if i >= len(x):
            break
        frames.append(rearrange(x[i, :h, :w], 'h w n c -> (h w) n c'))
    return torch.concat(frames)


@amp.autocast(enabled=False)
def rope_apply_multires(x, x_lens, x_shapes, freqs, pad=True):
    """
    x: [B*L, N, C].
    x_lens: [B].
    x_shapes: [B, F, 2].
    freqs: [M, C // 2].
    """
    n, c = x.size(1), x.size(2) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
    # loop over samples
    output = []
    st = 0
    for i, (seq_len,
            shapes) in enumerate(zip(x_lens.tolist(), x_shapes.tolist())):
        x_i = frame_pad(x[st:st + seq_len], seq_len, shapes)  # f, h, w, c
        f, h, w = x_i.shape[:3]
        pad_seq_len = f * h * w
        # precompute multipliers
        x_i = torch.view_as_complex(
            x_i.to(torch.float64).reshape(pad_seq_len, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ], dim=-1).reshape(pad_seq_len, 1, -1)
        # apply rotary embedding
        x_i = torch.view_as_real(x_i * freqs_i).flatten(2).type_as(x)
        x_i = frame_unpad(x_i, shapes)
        # append to collection
        output.append(x_i)
        st += seq_len
    return pad_sequence(output) if pad else torch.concat(output)


@amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """
    Precompute the frequency tensor for complex exponentials.
    """
    assert dim % 2 == 0
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta,
                        torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs
```
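A quick sketch of the `rope_params` contract as reconstructed above, nothing here beyond what the function itself implies (the import path assumes the module layout of this repo before the deletion):

```python
import torch
from modules.model.backbone.pos_embed import rope_params  # pre-deletion path

freqs = rope_params(max_seq_len=8, dim=4)  # complex tensor, shape (8, dim // 2)
assert freqs.shape == (8, 2)
# unit magnitude everywhere: the entries are pure rotations
assert torch.allclose(freqs.abs(), torch.ones(8, 2, dtype=freqs.abs().dtype))
# position 0 carries the identity rotation 1 + 0j, so origin tokens are untouched
assert torch.allclose(freqs[0], torch.ones(2, dtype=freqs.dtype))
```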
modules/model/diffusion/__init__.py
DELETED
@@ -1,6 +0,0 @@
```python
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.

from .diffusions import ACEDiffusion
from .samplers import DDIMSampler
from .schedules import LinearScheduler
```
modules/model/diffusion/diffusions.py
DELETED
@@ -1,206 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import math
|
4 |
-
import os
|
5 |
-
from collections import OrderedDict
|
6 |
-
|
7 |
-
import torch
|
8 |
-
from tqdm import trange
|
9 |
-
|
10 |
-
from scepter.modules.model.registry import (DIFFUSION_SAMPLERS, DIFFUSIONS,
|
11 |
-
NOISE_SCHEDULERS)
|
12 |
-
from scepter.modules.utils.config import Config, dict_to_yaml
|
13 |
-
from scepter.modules.utils.distribute import we
|
14 |
-
from scepter.modules.utils.file_system import FS
|
15 |
-
|
16 |
-
|
17 |
-
@DIFFUSIONS.register_class()
|
18 |
-
class ACEDiffusion(object):
|
19 |
-
para_dict = {
|
20 |
-
'NOISE_SCHEDULER': {},
|
21 |
-
'SAMPLER_SCHEDULER': {},
|
22 |
-
'MIN_SNR_GAMMA': {
|
23 |
-
'value': None,
|
24 |
-
'description': 'The minimum SNR gamma value for the loss function.'
|
25 |
-
},
|
26 |
-
'PREDICTION_TYPE': {
|
27 |
-
'value': 'eps',
|
28 |
-
'description':
|
29 |
-
'The type of prediction to use for the loss function.'
|
30 |
-
}
|
31 |
-
}
|
32 |
-
|
33 |
-
def __init__(self, cfg, logger=None):
|
34 |
-
super(ACEDiffusion, self).__init__()
|
35 |
-
self.logger = logger
|
36 |
-
self.cfg = cfg
|
37 |
-
self.init_params()
|
38 |
-
|
39 |
-
def init_params(self):
|
40 |
-
self.min_snr_gamma = self.cfg.get('MIN_SNR_GAMMA', None)
|
41 |
-
self.prediction_type = self.cfg.get('PREDICTION_TYPE', 'eps')
|
42 |
-
self.noise_scheduler = NOISE_SCHEDULERS.build(self.cfg.NOISE_SCHEDULER,
|
43 |
-
logger=self.logger)
|
44 |
-
self.sampler_scheduler = NOISE_SCHEDULERS.build(self.cfg.get(
|
45 |
-
'SAMPLER_SCHEDULER', self.cfg.NOISE_SCHEDULER),
|
46 |
-
logger=self.logger)
|
47 |
-
self.num_timesteps = self.noise_scheduler.num_timesteps
|
48 |
-
if self.cfg.have('WORK_DIR') and we.rank == 0:
|
49 |
-
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
|
50 |
-
'noise_schedule.png')
|
51 |
-
with FS.put_to(schedule_visualization) as local_path:
|
52 |
-
self.noise_scheduler.plot_noise_sampling_map(local_path)
|
53 |
-
schedule_visualization = os.path.join(self.cfg.WORK_DIR,
|
54 |
-
'sampler_schedule.png')
|
55 |
-
with FS.put_to(schedule_visualization) as local_path:
|
56 |
-
self.sampler_scheduler.plot_noise_sampling_map(local_path)
|
57 |
-
|
58 |
-
def sample(self,
|
59 |
-
noise,
|
60 |
-
model,
|
61 |
-
model_kwargs={},
|
62 |
-
steps=20,
|
63 |
-
sampler=None,
|
64 |
-
use_dynamic_cfg=False,
|
65 |
-
guide_scale=None,
|
66 |
-
guide_rescale=None,
|
67 |
-
show_progress=False,
|
68 |
-
return_intermediate=None,
|
69 |
-
intermediate_callback=None,
|
70 |
-
**kwargs):
|
71 |
-
assert isinstance(steps, (int, torch.LongTensor))
|
72 |
-
assert return_intermediate in (None, 'x0', 'xt')
|
73 |
-
assert isinstance(sampler, (str, dict, Config))
|
74 |
-
intermediates = []
|
75 |
-
|
76 |
-
def callback_fn(x_t, t, sigma=None, alpha=None):
|
77 |
-
timestamp = t
|
78 |
-
t = t.repeat(len(x_t)).round().long().to(x_t.device)
|
79 |
-
sigma = sigma.repeat(len(x_t), *([1] * (len(sigma.shape) - 1)))
|
80 |
-
alpha = alpha.repeat(len(x_t), *([1] * (len(alpha.shape) - 1)))
|
81 |
-
|
82 |
-
if guide_scale is None or guide_scale == 1.0:
|
83 |
-
out = model(x=x_t, t=t, **model_kwargs)
|
84 |
-
else:
|
85 |
-
if use_dynamic_cfg:
|
86 |
-
guidance_scale = 1 + guide_scale * (
|
87 |
-
(1 - math.cos(math.pi * (
|
88 |
-
(steps - timestamp.item()) / steps)**5.0)) / 2)
|
89 |
-
else:
|
90 |
-
guidance_scale = guide_scale
|
91 |
-
y_out = model(x=x_t, t=t, **model_kwargs[0])
|
92 |
-
u_out = model(x=x_t, t=t, **model_kwargs[1])
|
93 |
-
out = u_out + guidance_scale * (y_out - u_out)
|
94 |
-
if guide_rescale is not None and guide_rescale > 0.0:
|
95 |
-
ratio = (
|
96 |
-
y_out.flatten(1).std(dim=1) /
|
97 |
-
(out.flatten(1).std(dim=1) + 1e-12)).view((-1, ) + (1, ) *
|
98 |
-
(y_out.ndim - 1))
|
99 |
-
out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0
|
100 |
-
|
101 |
-
if self.prediction_type == 'x0':
|
102 |
-
x0 = out
|
103 |
-
elif self.prediction_type == 'eps':
|
104 |
-
x0 = (x_t - sigma * out) / alpha
|
105 |
-
elif self.prediction_type == 'v':
|
106 |
-
x0 = alpha * x_t - sigma * out
|
107 |
-
else:
|
108 |
-
raise NotImplementedError(
|
109 |
-
f'prediction_type {self.prediction_type} not implemented')
|
110 |
-
|
111 |
-
return x0
|
112 |
-
|
113 |
-
sampler_ins = self.get_sampler(sampler)
|
114 |
-
|
115 |
-
# this is ignored for schnell
|
116 |
-
sampler_output = sampler_ins.preprare_sampler(
|
117 |
-
noise,
|
118 |
-
steps=steps,
|
119 |
-
prediction_type=self.prediction_type,
|
120 |
-
scheduler_ins=self.sampler_scheduler,
|
121 |
-
callback_fn=callback_fn)
|
122 |
-
        pbar = trange(steps, disable=not show_progress)
        for _ in pbar:
            # show the sampler's status message on the progress-bar instance
            pbar.set_description(sampler_output.msg)
            sampler_output = sampler_ins.step(sampler_output)
            # keys match the return_intermediate assertion above ('x0' / 'xt')
            if return_intermediate == 'x0':
                intermediates.append(sampler_output.x_0)
            elif return_intermediate == 'xt':
                intermediates.append(sampler_output.x_t)
            if intermediate_callback is not None:
                intermediate_callback(intermediates[-1])
        return (sampler_output.x_0, intermediates
                ) if return_intermediate is not None else sampler_output.x_0

    def loss(self,
             x_0,
             model,
             model_kwargs={},
             reduction='mean',
             noise=None,
             **kwargs):
        # use noise scheduler to add noise
        if noise is None:
            noise = torch.randn_like(x_0)
        schedule_output = self.noise_scheduler.add_noise(x_0, noise, **kwargs)
        x_t, t, sigma, alpha = (schedule_output.x_t, schedule_output.t,
                                schedule_output.sigma, schedule_output.alpha)
        out = model(x=x_t, t=t, **model_kwargs)

        # mse loss
        target = {
            'eps': noise,
            'x0': x_0,
            'v': alpha * noise - sigma * x_0
        }[self.prediction_type]

        loss = (out - target).pow(2)
        if reduction == 'mean':
            loss = loss.flatten(1).mean(dim=1)

        if self.min_snr_gamma is not None:
            alphas = self.noise_scheduler.alphas.to(x_0.device)[t]
            sigmas = self.noise_scheduler.sigmas.pow(2).to(x_0.device)[t]
            snrs = (alphas / sigmas).clamp(min=1e-20)
            min_snrs = snrs.clamp(max=self.min_snr_gamma)
            weights = min_snrs / snrs
        else:
            weights = 1

        loss = loss * weights
        return loss

    def get_sampler(self, sampler):
        if isinstance(sampler, str):
            if sampler not in DIFFUSION_SAMPLERS.class_map:
                if self.logger is not None:
                    self.logger.info(
                        f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
                    )
                else:
                    print(
                        f'{sampler} not in the defined samplers list {DIFFUSION_SAMPLERS.class_map.keys()}'
                    )
                return None
            sampler_cfg = Config(cfg_dict={'NAME': sampler}, load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler_cfg,
                                                   logger=self.logger)
        elif isinstance(sampler, (Config, dict, OrderedDict)):
            if isinstance(sampler, (dict, OrderedDict)):
                sampler = Config(
                    cfg_dict={k.upper(): v
                              for k, v in dict(sampler).items()},
                    load=False)
            sampler_ins = DIFFUSION_SAMPLERS.build(sampler, logger=self.logger)
        else:
            raise NotImplementedError
        return sampler_ins

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}' + ' ' + super().__repr__()

    @staticmethod
    def get_config_template():
        return dict_to_yaml('DIFFUSIONS',
                            __class__.__name__,
                            ACEDiffusion.para_dict,
                            set_name=True)
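
Note on the guidance math in `callback_fn`: after the usual classifier-free combination, `guide_rescale` pulls the per-sample standard deviation of the guided output back toward that of the conditional branch. A minimal standalone sketch with toy tensors (shapes and values are illustrative only):

import torch

guide_scale, guide_rescale = 4.5, 0.5
y_out = torch.randn(2, 4, 8, 8)  # conditional model output (toy)
u_out = torch.randn(2, 4, 8, 8)  # unconditional model output (toy)

# classifier-free guidance
out = u_out + guide_scale * (y_out - u_out)

# rescale: blend a std-matched version of the output with the raw guided one
ratio = (y_out.flatten(1).std(dim=1) /
         (out.flatten(1).std(dim=1) + 1e-12)).view(-1, 1, 1, 1)
out = out * (guide_rescale * ratio + (1 - guide_rescale) * 1.0)
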
modules/model/diffusion/samplers.py
DELETED
@@ -1,69 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from scepter.modules.model.registry import DIFFUSION_SAMPLERS
from scepter.modules.model.diffusion.samplers import BaseDiffusionSampler


def _i(tensor, t, x):
    """
    Index tensor using t and format the output according to x.
    """
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    if isinstance(t, torch.Tensor):
        t = t.to(tensor.device)
    return tensor[t].view(shape).to(x.device)


@DIFFUSION_SAMPLERS.register_class('ddim')
class DDIMSampler(BaseDiffusionSampler):
    def init_params(self):
        super().init_params()
        self.eta = self.cfg.get('ETA', 0.)
        self.discretization_type = self.cfg.get('DISCRETIZATION_TYPE',
                                                'trailing')

    def preprare_sampler(self,
                         noise,
                         steps=20,
                         scheduler_ins=None,
                         prediction_type='',
                         sigmas=None,
                         betas=None,
                         alphas=None,
                         callback_fn=None,
                         **kwargs):
        output = super().preprare_sampler(noise, steps, scheduler_ins,
                                          prediction_type, sigmas, betas,
                                          alphas, callback_fn, **kwargs)
        sigmas = output.sigmas
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        # convert to variance-preserving sigmas: sigma / sqrt(1 + sigma^2)
        sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5
        sigmas_vp[sigmas == float('inf')] = 1.
        output.add_custom_field('sigmas_vp', sigmas_vp)
        return output

    def step(self, sampler_output):
        x_t = sampler_output.x_t
        step = sampler_output.step
        t = sampler_output.ts[step]
        sigmas_vp = sampler_output.sigmas_vp.to(x_t.device)
        alpha_init = _i(sampler_output.alphas_init, step, x_t[:1])
        sigma_init = _i(sampler_output.sigmas_init, step, x_t[:1])

        x = sampler_output.callback_fn(x_t, t, sigma_init, alpha_init)
        # std of the stochastic DDIM term; zero when eta == 0
        noise_factor = self.eta * (sigmas_vp[step + 1]**2 /
                                   sigmas_vp[step]**2 *
                                   (1 - (1 - sigmas_vp[step]**2) /
                                    (1 - sigmas_vp[step + 1]**2)))**0.5
        d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x) / sigmas_vp[step]
        x = (1 - sigmas_vp[step + 1] ** 2) ** 0.5 * x + \
            (sigmas_vp[step + 1] ** 2 - noise_factor ** 2) ** 0.5 * d
        sampler_output.x_0 = x
        if sigmas_vp[step + 1] > 0:
            x += noise_factor * torch.randn_like(x)
        sampler_output.x_t = x
        sampler_output.step += 1
        sampler_output.msg = f'step {step}'
        return sampler_output
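
For intuition, a toy, self-contained sketch of the deterministic branch (eta = 0) of `step` above: EDM-style sigmas are mapped to their variance-preserving form, and each step re-mixes the predicted clean latent with the current direction `d`. Values are illustrative, not the scheduler's real schedule:

import torch

sigmas = torch.tensor([14.6, 5.0, 1.0, 0.0])     # toy noise levels
sigmas_vp = (sigmas**2 / (1 + sigmas**2))**0.5   # variance-preserving form

step = 0
x_t = torch.randn(1, 4, 8, 8) * sigmas_vp[step]  # toy current latent
x0_pred = torch.zeros_like(x_t)                  # stand-in for callback_fn's x0

# direction from the predicted clean latent toward x_t
d = (x_t - (1 - sigmas_vp[step]**2)**0.5 * x0_pred) / sigmas_vp[step]
# deterministic DDIM update (noise_factor == 0 when eta == 0)
x_next = (1 - sigmas_vp[step + 1]**2)**0.5 * x0_pred + sigmas_vp[step + 1] * d
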
modules/model/diffusion/schedules.py
DELETED
@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch

from scepter.modules.model.registry import NOISE_SCHEDULERS
from scepter.modules.model.diffusion.schedules import BaseNoiseScheduler


@NOISE_SCHEDULERS.register_class()
class LinearScheduler(BaseNoiseScheduler):
    para_dict = {}

    def init_params(self):
        super().init_params()
        self.beta_min = self.cfg.get('BETA_MIN', 0.00085)
        self.beta_max = self.cfg.get('BETA_MAX', 0.012)

    def betas_to_sigmas(self, betas):
        return torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))

    def get_schedule(self):
        betas = torch.linspace(self.beta_min,
                               self.beta_max,
                               self.num_timesteps,
                               dtype=torch.float32)
        sigmas = self.betas_to_sigmas(betas)
        self._sigmas = sigmas
        self._betas = betas
        self._alphas = torch.sqrt(1 - sigmas**2)
        self._timesteps = torch.arange(len(sigmas), dtype=torch.float32)
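
Two quick checks on this schedule (a sketch: `gamma = 5.0` is an assumed example value, the standard definition snr = alpha^2 / sigma^2 is used here, and `ACEDiffusion.loss` above reads the scheduler's precomputed tensors rather than recomputing them): the schedule is variance-preserving by construction, and min-SNR loss weights follow directly from it.

import torch

betas = torch.linspace(0.00085, 0.012, 1000, dtype=torch.float32)
sigmas = torch.sqrt(1 - torch.cumprod(1 - betas, dim=0))
alphas = torch.sqrt(1 - sigmas**2)
# alpha_t^2 + sigma_t^2 == 1 at every timestep
assert torch.allclose(alphas**2 + sigmas**2, torch.ones_like(sigmas))

# min-SNR weighting: min(snr, gamma) / snr per timestep
gamma = 5.0
snrs = (alphas**2 / sigmas**2).clamp(min=1e-20)
weights = snrs.clamp(max=gamma) / snrs
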
modules/model/embedder/__init__.py
DELETED
@@ -1 +0,0 @@
from .embedder import ACETextEmbedder
modules/model/embedder/embedder.py
DELETED
@@ -1,184 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import warnings
from contextlib import nullcontext

import torch
import torch.nn.functional as F
import torch.utils.dlpack
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
from scepter.modules.model.registry import EMBEDDERS
from scepter.modules.model.tokenizer.tokenizer_component import (
    basic_clean, canonicalize, heavy_clean, whitespace_clean)
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we
from scepter.modules.utils.file_system import FS

try:
    from transformers import AutoTokenizer, T5EncoderModel
except Exception as e:
    warnings.warn(
        f'Import transformers error, please deal with this problem: {e}')


@EMBEDDERS.register_class()
class ACETextEmbedder(BaseEmbedder):
    """
    Uses the T5 transformer encoder for text.
    """
    para_dict = {
        'PRETRAINED_MODEL': {
            'value':
            'google/umt5-small',
            'description':
            'Pretrained Model for umt5, modelcard path or local path.'
        },
        'TOKENIZER_PATH': {
            'value': 'google/umt5-small',
            'description':
            'Tokenizer Path for umt5, modelcard path or local path.'
        },
        'FREEZE': {
            'value': True,
            'description': ''
        },
        'USE_GRAD': {
            'value': False,
            'description': 'Compute grad or not.'
        },
        'CLEAN': {
            'value':
            'whitespace',
            'description':
            'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
        },
        'LAYER': {
            'value': 'last',
            'description': ''
        },
        'LEGACY': {
            'value':
            True,
            'description':
            'Whether to use the legacy returned feature or not, default True.'
        }
    }

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        pretrained_path = cfg.get('PRETRAINED_MODEL', None)
        self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
        assert pretrained_path
        with FS.get_dir_to_local_dir(pretrained_path,
                                     wait_finish=True) as local_path:
            self.model = T5EncoderModel.from_pretrained(
                local_path,
                torch_dtype=getattr(
                    torch,
                    'float' if self.t5_dtype == 'float32' else self.t5_dtype))
        tokenizer_path = cfg.get('TOKENIZER_PATH', None)
        self.length = cfg.get('LENGTH', 77)

        self.use_grad = cfg.get('USE_GRAD', False)
        self.clean = cfg.get('CLEAN', 'whitespace')
        self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
        if tokenizer_path:
            self.tokenize_kargs = {'return_tensors': 'pt'}
            with FS.get_dir_to_local_dir(tokenizer_path,
                                         wait_finish=True) as local_path:
                self.tokenizer = AutoTokenizer.from_pretrained(local_path)
            if self.length is not None:
                self.tokenize_kargs.update({
                    'padding': 'max_length',
                    'truncation': True,
                    'max_length': self.length
                })
            self.eos_token = self.tokenizer(
                self.tokenizer.eos_token)['input_ids'][0]
        else:
            self.tokenizer = None
            self.tokenize_kargs = {}

    def freeze(self):
        self.model = self.model.eval()
        for param in self.parameters():
            param.requires_grad = False

    # encode && encode_text
    def forward(self, tokens, return_mask=False, use_mask=True):
        # tokenization
        embedding_context = nullcontext if self.use_grad else torch.no_grad
        with embedding_context():
            if use_mask:
                x = self.model(tokens.input_ids.to(we.device_id),
                               tokens.attention_mask.to(we.device_id))
            else:
                x = self.model(tokens.input_ids.to(we.device_id))
            x = x.last_hidden_state

        if return_mask:
            return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
        else:
            return x.detach() + 0.0, None

    def _clean(self, text):
        if self.clean == 'whitespace':
            text = whitespace_clean(basic_clean(text))
        elif self.clean == 'lower':
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == 'canonicalize':
            text = canonicalize(basic_clean(text))
        elif self.clean == 'heavy':
            text = heavy_clean(basic_clean(text))
        return text

    def encode(self, text, return_mask=False, use_mask=True):
        if isinstance(text, str):
            text = [text]
        if self.clean:
            text = [self._clean(u) for u in text]
        assert self.tokenizer is not None
        cont, mask = [], []
        with torch.autocast(device_type='cuda',
                            enabled=self.t5_dtype in ('float16', 'bfloat16'),
                            dtype=getattr(torch, self.t5_dtype)):
            for tt in text:
                tokens = self.tokenizer([tt], **self.tokenize_kargs)
                one_cont, one_mask = self(tokens,
                                          return_mask=return_mask,
                                          use_mask=use_mask)
                cont.append(one_cont)
                mask.append(one_mask)
        if return_mask:
            return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
        else:
            return torch.cat(cont, dim=0)

    def encode_list(self, text_list, return_mask=True):
        cont_list = []
        mask_list = []
        for pp in text_list:
            cont, cont_mask = self.encode(pp, return_mask=return_mask)
            cont_list.append(cont)
            mask_list.append(cont_mask)
        if return_mask:
            return cont_list, mask_list
        else:
            return cont_list

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODELS',
                            __class__.__name__,
                            ACETextEmbedder.para_dict,
                            set_name=True)
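
Stripped of the scepter registry/FS machinery, the encode path above reduces to roughly the following sketch (using the `google/umt5-small` default from `para_dict`; the cleaning helpers and autocast are omitted):

import torch
from transformers import AutoTokenizer, T5EncoderModel

tokenizer = AutoTokenizer.from_pretrained('google/umt5-small')
model = T5EncoderModel.from_pretrained('google/umt5-small').eval()

tokens = tokenizer(['make the sky blue'],
                   padding='max_length',
                   truncation=True,
                   max_length=77,
                   return_tensors='pt')
with torch.no_grad():
    out = model(tokens.input_ids, tokens.attention_mask)
cont = out.last_hidden_state   # (1, 77, hidden_size) text features
mask = tokens.attention_mask   # (1, 77) padding mask
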
modules/model/network/__init__.py
DELETED
@@ -1 +0,0 @@
from .ldm_ace import LdmACE
modules/model/network/ldm_ace.py
DELETED
@@ -1,353 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
import random
from contextlib import nullcontext

import torch
import torch.nn.functional as F
from torch import nn

from scepter.modules.model.network.ldm import LatentDiffusion
from scepter.modules.model.registry import MODELS
from scepter.modules.utils.config import dict_to_yaml
from scepter.modules.utils.distribute import we

from ..utils.basic_utils import (
    check_list_of_list,
    pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor,
    to_device,
    unpack_tensor_into_imagelist
)


class TextEmbedding(nn.Module):
    def __init__(self, embedding_shape):
        super().__init__()
        self.pos = nn.Parameter(data=torch.zeros(embedding_shape))


@MODELS.register_class()
class LdmACE(LatentDiffusion):
    para_dict = LatentDiffusion.para_dict
    para_dict['DECODER_BIAS'] = {'value': 0, 'description': ''}

    def __init__(self, cfg, logger=None):
        super().__init__(cfg, logger=logger)
        self.interpolate_func = lambda x: (F.interpolate(
            x.unsqueeze(0),
            scale_factor=1 / self.size_factor,
            mode='nearest-exact') if x is not None else None)

        self.text_identifiers = cfg.get('TEXT_IDENTIFIER', [])
        self.use_text_pos_embeddings = cfg.get('USE_TEXT_POS_EMBEDDINGS',
                                               False)
        if self.use_text_pos_embeddings:
            self.text_position_embeddings = TextEmbedding(
                (10, 4096)).eval().requires_grad_(False)
        else:
            self.text_position_embeddings = None

        self.logger.info(self.model)

    @torch.no_grad()
    def encode_first_stage(self, x, **kwargs):
        return [
            self.scale_factor *
            self.first_stage_model._encode(i.unsqueeze(0).to(torch.float16))
            for i in x
        ]

    @torch.no_grad()
    def decode_first_stage(self, z):
        return [
            self.first_stage_model._decode(1. / self.scale_factor *
                                           i.to(torch.float16)) for i in z
        ]

    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
        if self.use_text_pos_embeddings and not torch.sum(
                self.text_position_embeddings.pos) > 0:
            identifier_cont, identifier_cont_mask = getattr(
                self.cond_stage_model, 'encode')(self.text_identifiers,
                                                 return_mask=True)
            self.text_position_embeddings.load_state_dict(
                {'pos': identifier_cont[:, 0, :]})
        cont_, cont_mask_ = [], []
        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
            if isinstance(pp, list):
                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
            else:
                raise NotImplementedError

        return cont_, cont_mask_

    def limit_batch_data(self, batch_data_list, log_num):
        if log_num and log_num > 0:
            batch_data_list_limited = []
            for sub_data in batch_data_list:
                if sub_data is not None:
                    sub_data = sub_data[:log_num]
                batch_data_list_limited.append(sub_data)
            return batch_data_list_limited
        else:
            return batch_data_list

    def forward_train(self,
                      edit_image=[],
                      edit_image_mask=[],
                      image=None,
                      image_mask=None,
                      noise=None,
                      prompt=[],
                      **kwargs):
        '''
        Args:
            edit_image: list of list of edit_image
            edit_image_mask: list of list of edit_image_mask
            image: target image
            image_mask: target image mask
            noise: default is None, generated automatically
            prompt: list of list of text
            **kwargs:
        Returns:
        '''
        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        gc_seg = kwargs.pop('gc_seg', [])
        gc_seg = int(gc_seg[0]) if len(gc_seg) > 0 else 0
        context = {}

        # process image
        image = to_device(image)
        x_start = self.encode_first_stage(image, **kwargs)
        x_start, x_shapes = pack_imagelist_into_tensor(x_start)  # B, C, L
        n, _, _ = x_start.shape
        t = torch.randint(0, self.num_timesteps, (n, ),
                          device=x_start.device).long()
        context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        context['x_mask'] = [self.interpolate_func(i) for i in image_mask
                             ] if image_mask is not None else [None] * n

        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        try:
            cont, cont_mask = getattr(self.cond_stage_model,
                                      'encode_list')(prompt_, return_mask=True)
        except Exception as e:
            print(e, prompt_)
            raise  # cont/cont_mask are undefined if encoding failed
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        context['crossattn'] = cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if m is None:
                m = [None] * len(u) if u is not None else [None]
            e_img.append(
                self.encode_first_stage(u, **kwargs) if u is not None else u)
            e_mask.append([
                self.interpolate_func(i) if i is not None else None for i in m
            ])
        context['edit'], context['edit_mask'] = e_img, e_mask

        # process loss
        loss = self.diffusion.loss(
            x_0=x_start,
            t=t,
            noise=noise,
            model=self.model,
            model_kwargs={
                'cond': context,
                'mask': cont_mask,
                'gc_seg': gc_seg,
                'text_position_embeddings':
                self.text_position_embeddings.pos if hasattr(
                    self.text_position_embeddings, 'pos') else None
            },
            **kwargs)
        loss = loss.mean()
        ret = {'loss': loss, 'probe_data': {'prompt': prompt}}
        return ret

    @torch.no_grad()
    def forward_test(self,
                     edit_image=[],
                     edit_image_mask=[],
                     image=None,
                     image_mask=None,
                     prompt=[],
                     n_prompt=[],
                     sampler='ddim',
                     sample_steps=20,
                     guide_scale=4.5,
                     guide_rescale=0.5,
                     log_num=-1,
                     seed=2024,
                     **kwargs):

        assert check_list_of_list(prompt) and check_list_of_list(
            edit_image) and check_list_of_list(edit_image_mask)
        assert len(edit_image) == len(edit_image_mask) == len(prompt)
        assert self.cond_stage_model is not None
        # gc_seg is unused
        kwargs.pop('gc_seg', -1)
        # prepare data
        context, null_context = {}, {}

        prompt, n_prompt, image, image_mask, edit_image, edit_image_mask = self.limit_batch_data(
            [prompt, n_prompt, image, image_mask, edit_image, edit_image_mask],
            log_num)
        g = torch.Generator(device=we.device_id)
        seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
        g.manual_seed(seed)
        n_prompt = copy.deepcopy(prompt)
        # only modify the last prompt to be zero
        for nn_p_id, nn_p in enumerate(n_prompt):
            if isinstance(nn_p, str):
                n_prompt[nn_p_id] = ['']
            elif isinstance(nn_p, list):
                n_prompt[nn_p_id][-1] = ''
            else:
                raise NotImplementedError
        # process image
        image = to_device(image)
        x = self.encode_first_stage(image, **kwargs)
        noise = [
            torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
            for i in x
        ]
        noise, x_shapes = pack_imagelist_into_tensor(noise)
        context['x_shapes'] = null_context['x_shapes'] = x_shapes

        # process image mask
        image_mask = to_device(image_mask, strict=False)
        cond_mask = [self.interpolate_func(i) for i in image_mask
                     ] if image_mask is not None else [None] * len(image)
        context['x_mask'] = null_context['x_mask'] = cond_mask
        # process text
        # with torch.autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
        cont, cont_mask = getattr(self.cond_stage_model,
                                  'encode_list')(prompt_, return_mask=True)
        cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
                                                     cont_mask)
        null_cont, null_cont_mask = getattr(self.cond_stage_model,
                                            'encode_list')(n_prompt,
                                                           return_mask=True)
        null_cont, null_cont_mask = self.cond_stage_embeddings(
            prompt, edit_image, null_cont, null_cont_mask)
        context['crossattn'] = cont
        null_context['crossattn'] = null_cont

        # process edit image & edit image mask
        edit_image = [to_device(i, strict=False) for i in edit_image]
        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
        e_img, e_mask = [], []
        for u, m in zip(edit_image, edit_image_mask):
            if u is None:
                continue
            if m is None:
                m = [None] * len(u)
            e_img.append(self.encode_first_stage(u, **kwargs))
            e_mask.append([self.interpolate_func(i) for i in m])
        null_context['edit'] = context['edit'] = e_img
        null_context['edit_mask'] = context['edit_mask'] = e_mask

        # process sample
        model = self.model_ema if self.use_ema and self.eval_ema else self.model
        embedding_context = model.no_sync if isinstance(model, torch.distributed.fsdp.FullyShardedDataParallel) \
            else nullcontext
        with embedding_context():
            samples = self.diffusion.sample(
                sampler=sampler,
                noise=noise,
                model=model,
                model_kwargs=[{
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }, {
                    'cond': null_context,
                    'mask': null_cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                }] if guide_scale is not None and guide_scale > 1 else {
                    'cond': context,
                    'mask': cont_mask,
                    'text_position_embeddings':
                    self.text_position_embeddings.pos if hasattr(
                        self.text_position_embeddings, 'pos') else None
                },
                steps=sample_steps,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                show_progress=True,
                **kwargs)

        samples = unpack_tensor_into_imagelist(samples, x_shapes)
        x_samples = self.decode_first_stage(samples)
        outputs = list()
        for i in range(len(prompt)):
            rec_img = torch.clamp(
                (x_samples[i] + 1.0) / 2.0 + self.decoder_bias / 255,
                min=0.0,
                max=1.0)
            rec_img = rec_img.squeeze(0)
            edit_imgs, edit_img_masks = [], []
            if edit_image is not None and edit_image[i] is not None:
                if edit_image_mask[i] is None:
                    edit_image_mask[i] = [None] * len(edit_image[i])
                for edit_img, edit_mask in zip(edit_image[i],
                                               edit_image_mask[i]):
                    edit_img = torch.clamp((edit_img + 1.0) / 2.0,
                                           min=0.0,
                                           max=1.0)
                    edit_imgs.append(edit_img.squeeze(0))
                    if edit_mask is None:
                        edit_mask = torch.ones_like(edit_img[[0], :, :])
                    edit_img_masks.append(edit_mask)
            one_tup = {
                'reconstruct_image': rec_img,
                'instruction': prompt[i],
                'edit_image': edit_imgs if len(edit_imgs) > 0 else None,
                'edit_mask': edit_img_masks if len(edit_imgs) > 0 else None
            }
            if image is not None:
                if image_mask is None:
                    image_mask = [None] * len(image)
                ori_img = torch.clamp((image[i] + 1.0) / 2.0, min=0.0, max=1.0)
                one_tup['target_image'] = ori_img.squeeze(0)
                one_tup['target_mask'] = image_mask[i] if image_mask[
                    i] is not None else torch.ones_like(ori_img[[0], :, :])
            outputs.append(one_tup)
        return outputs

    @staticmethod
    def get_config_template():
        return dict_to_yaml('MODEL',
                            __class__.__name__,
                            LdmACE.para_dict,
                            set_name=True)
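
A detail worth flagging in `forward_test`: the unconditional branch for classifier-free guidance is rebuilt from `prompt` itself, blanking only the final instruction of each multi-turn list so earlier turns still condition both branches. In miniature:

import copy

prompt = [['turn the cat around', 'make the sky blue']]
n_prompt = copy.deepcopy(prompt)
for turn in n_prompt:
    turn[-1] = ''  # keep the history, drop the final instruction

assert n_prompt == [['turn the cat around', '']]
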
modules/model/utils/basic_utils.py
DELETED
@@ -1,104 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from inspect import isfunction

import torch
from torch.nn.utils.rnn import pad_sequence

from scepter.modules.utils.distribute import we


def exists(x):
    return x is not None


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def transfer_size(para_num):
    if para_num > 1000 * 1000 * 1000 * 1000:
        bill = para_num / (1000 * 1000 * 1000 * 1000)
        return '{:.2f}T'.format(bill)
    elif para_num > 1000 * 1000 * 1000:
        gyte = para_num / (1000 * 1000 * 1000)
        return '{:.2f}B'.format(gyte)
    elif para_num > (1000 * 1000):
        meta = para_num / (1000 * 1000)
        return '{:.2f}M'.format(meta)
    elif para_num > 1000:
        kelo = para_num / 1000
        return '{:.2f}K'.format(kelo)
    else:
        return para_num


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return transfer_size(total_params)


def expand_dims_like(x, y):
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x


def unpack_tensor_into_imagelist(image_tensor, shapes):
    image_list = []
    for img, shape in zip(image_tensor, shapes):
        h, w = shape[0], shape[1]
        image_list.append(img[:, :h * w].view(1, -1, h, w))

    return image_list


def find_example(tensor_list, image_list):
    for i in tensor_list:
        if isinstance(i, torch.Tensor):
            return torch.zeros_like(i)
    for i in image_list:
        if isinstance(i, torch.Tensor):
            _, c, h, w = i.size()
            return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
    return None


def pack_imagelist_into_tensor_v2(image_list):
    # allow None
    example = None
    image_tensor, shapes = [], []
    for img in image_list:
        if img is None:
            example = find_example(image_tensor,
                                   image_list) if example is None else example
            image_tensor.append(example)
            shapes.append(None)
            continue
        _, c, h, w = img.size()
        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # h*w, c
        shapes.append((h, w))

    image_tensor = pad_sequence(image_tensor,
                                batch_first=True).permute(0, 2, 1)  # b, c, l
    return image_tensor, shapes


def to_device(inputs, strict=True):
    if inputs is None:
        return None
    if strict:
        assert all(isinstance(i, torch.Tensor) for i in inputs)
    return [i.to(we.device_id) if i is not None else None for i in inputs]


def check_list_of_list(ll):
    return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
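
A round-trip sketch for the two packing helpers above (run alongside them; shapes are illustrative): variable-resolution latents are flattened to (h*w, c) sequences, padded into one (b, c, L) tensor, and recovered exactly from the stored shapes.

import torch

imgs = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 4, 12)]
packed, shapes = pack_imagelist_into_tensor_v2(imgs)
assert packed.shape == (2, 4, 64) and shapes == [(8, 8), (4, 12)]

unpacked = unpack_tensor_into_imagelist(packed, shapes)
for rec, ori in zip(unpacked, imgs):
    assert torch.equal(rec, ori)  # padding is sliced away via the stored shape
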