chaojiemao committed
Commit e667736 · verified · 1 Parent(s): dc3e53f

Update ace_inference.py

Files changed (1)
  1. ace_inference.py +162 -355
ace_inference.py CHANGED
@@ -79,153 +79,19 @@ def process_edit_image(images,
         mask_tensors.append(mask_tensor)
     return img_tensors, mask_tensors
 
-
 class TextEmbedding(nn.Module):
     def __init__(self, embedding_shape):
         super().__init__()
         self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
 
-class RefinerInference(DiffusionInference):
-    def init_from_cfg(self, cfg):
-        super().init_from_cfg(cfg)
-        self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \
-            if cfg.MODEL.have('DIFFUSION') else None
-        self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
-        assert self.diffusion is not None
-        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
-        self.dynamic_load(self.diffusion_model, 'diffusion_model')
-        self.dynamic_load(self.first_stage_model, 'first_stage_model')
-
-    @torch.no_grad()
-    def encode_first_stage(self, x, **kwargs):
-        _, dtype = self.get_function_info(self.first_stage_model, 'encode')
-        with torch.autocast('cuda',
-                            enabled=dtype in ('float16', 'bfloat16'),
-                            dtype=getattr(torch, dtype)):
-            def run_one_image(u):
-                zu = get_model(self.first_stage_model).encode(u)
-                if isinstance(zu, (tuple, list)):
-                    zu = zu[0]
-                return zu
-            z = [run_one_image(u.unsqueeze(0) if u.dim == 3 else u) for u in x]
-            return z
-    def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
-        c, H, W = image.shape
-        scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
-        rH = int(H * scale) // 16 * 16  # ensure divisible by self.d
-        rW = int(W * scale) // 16 * 16
-        image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
-        return image
-    @torch.no_grad()
-    def decode_first_stage(self, z):
-        _, dtype = self.get_function_info(self.first_stage_model, 'decode')
-        with torch.autocast('cuda',
-                            enabled=dtype in ('float16', 'bfloat16'),
-                            dtype=getattr(torch, dtype)):
-            return [get_model(self.first_stage_model).decode(zu) for zu in z]
-
-    def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
-        noise = torch.randn(
-            num_samples,
-            16,
-            # allow for packing
-            2 * math.ceil(h / 16),
-            2 * math.ceil(w / 16),
-            device=device,
-            dtype=dtype,
-            generator=torch.Generator(device=device).manual_seed(seed),
-        )
-        return noise
-    def refine(self,
-               x_samples=None,
-               prompt=None,
-               reverse_scale=-1.,
-               seed = 2024,
-               use_dynamic_model = False,
-               **kwargs
-               ):
-        print(prompt)
-        value_input = copy.deepcopy(self.input)
-        x_samples = [self.upscale_resize(x) for x in x_samples]
-
-        noise = []
-        for i, x in enumerate(x_samples):
-            noise_ = self.noise_sample(1, x.shape[1],
-                                       x.shape[2], seed,
-                                       device = x.device)
-            noise.append(noise_)
-        noise, x_shapes = pack_imagelist_into_tensor(noise)
-        if reverse_scale > 0:
-            if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
-            x_samples = [x.unsqueeze(0) for x in x_samples]
-            x_start = self.encode_first_stage(x_samples, **kwargs)
-            if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
-                                                      'first_stage_model',
-                                                      skip_loaded=True)
-            x_start, _ = pack_imagelist_into_tensor(x_start)
-        else:
-            x_start = None
-        # cond stage
-        if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
-        function_name, dtype = self.get_function_info(self.cond_stage_model)
-        with torch.autocast('cuda',
-                            enabled=dtype == 'float16',
-                            dtype=getattr(torch, dtype)):
-            ctx = getattr(get_model(self.cond_stage_model),
-                          function_name)(prompt)
-        ctx["x_shapes"] = x_shapes
-        if use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
-                                                  'cond_stage_model',
-                                                  skip_loaded=True)
-
-
-        if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
-        # UNet use input n_prompt
-        function_name, dtype = self.get_function_info(
-            self.diffusion_model)
-        with torch.autocast('cuda',
-                            enabled=dtype in ('float16', 'bfloat16'),
-                            dtype=getattr(torch, dtype)):
-            solver_sample = value_input.get('sample', 'flow_euler')
-            sample_steps = value_input.get('sample_steps', 20)
-            guide_scale = value_input.get('guide_scale', 3.5)
-            if guide_scale is not None:
-                guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device,
-                                         dtype=noise.dtype)
-            else:
-                guide_scale = None
-            latent = self.diffusion.sample(
-                noise=noise,
-                sampler=solver_sample,
-                model=get_model(self.diffusion_model),
-                model_kwargs={"cond": ctx, "guidance": guide_scale},
-                steps=sample_steps,
-                show_progress=True,
-                guide_scale=guide_scale,
-                return_intermediate=None,
-                reverse_scale=reverse_scale,
-                x=x_start,
-                **kwargs).float()
-        latent = unpack_tensor_into_imagelist(latent, x_shapes)
-        if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
-                                                  'diffusion_model',
-                                                  skip_loaded=True)
-        if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
-        x_samples = self.decode_first_stage(latent)
-        if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
-                                                  'first_stage_model',
-                                                  skip_loaded=True)
-        return x_samples
-
-
-class ACEInference(DiffusionInference):
+class ACEFluxLCInference(DiffusionInference):
     def __init__(self, logger=None):
         if logger is None:
             logger = get_logger(name='scepter')
         self.logger = logger
         self.loaded_model = {}
         self.loaded_model_name = [
-            'diffusion_model', 'first_stage_model', 'cond_stage_model'
+            'diffusion_model', 'first_stage_model', 'cond_stage_model', 'ref_cond_stage_model'
         ]
 
     def init_from_cfg(self, cfg):
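Note on the resize rule: both the deleted RefinerInference.upscale_resize and the identical method re-added to ACEFluxLCInference below derive the working resolution from a transformer token budget rather than a fixed output size. A minimal standalone sketch of that arithmetic (the 4096-token default and 16-pixel patch are read off the diff; the function name is ours):

import math

def upscale_resize_shape(h, w, max_seq_length=4096, patch=16):
    # Scale up (never down) so the number of patch x patch tiles approaches
    # max_seq_length, then round each side down to a multiple of patch.
    scale = max(1.0, math.sqrt(max_seq_length / ((h / patch) * (w / patch))))
    return int(h * scale) // patch * patch, int(w * scale) // patch * patch

# A 512x512 input holds 32 * 32 = 1024 tiles, so scale = sqrt(4096 / 1024) = 2.
assert upscale_resize_shape(512, 512) == (1024, 1024)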
@@ -234,7 +100,7 @@ class ACEInference(DiffusionInference):
         self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
         module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
         assert cfg.have('MODEL')
-
+        self.size_factor = cfg.get('SIZE_FACTOR', 8)
         self.diffusion_model = self.infer_model(
             cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
                 'DIFFUSION_MODEL',
@@ -250,24 +116,23 @@ class ACEInference(DiffusionInference):
                 'COND_STAGE_MODEL',
                 None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
 
-        self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)
-        # self.refiner_scale = cfg.get('REFINER_SCALE', 0.)
-        # self.refiner_prompt = cfg.get('REFINER_PROMPT', "")
-        self.ace_prompt = cfg.get("ACE_PROMPT", [])
-        if self.refiner_model_cfg:
-            self.refiner_module = RefinerInference(self.logger)
-            self.refiner_module.init_from_cfg(self.refiner_model_cfg)
-        else:
-            self.refiner_module = None
+        self.ref_cond_stage_model = self.infer_model(
+            cfg.MODEL.REF_COND_STAGE_MODEL,
+            module_paras.get(
+                'REF_COND_STAGE_MODEL',
+                None)) if cfg.MODEL.have('REF_COND_STAGE_MODEL') else None
 
         self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
                                           logger=self.logger)
-
-
         self.interpolate_func = lambda x: (F.interpolate(
             x.unsqueeze(0),
             scale_factor=1 / self.size_factor,
             mode='nearest-exact') if x is not None else None)
+
+        self.max_seq_length = cfg.get("MAX_SEQ_LENGTH", 4096)
+        self.src_max_seq_length = cfg.get("SRC_MAX_SEQ_LENGTH", 1024)
+        self.image_token = cfg.MODEL.get("IMAGE_TOKEN", "<img>")
+
         self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
         self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
                                                      False)
@@ -277,41 +142,66 @@ class ACEInference(DiffusionInference):
         else:
             self.text_position_embeddings = None
 
-        self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
-        self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
-        self.size_factor = cfg.get('SIZE_FACTOR', 8)
-        self.decoder_bias = cfg.get('DECODER_BIAS', 0)
-        self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
-        #self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
-        #self.dynamic_load(self.diffusion_model, 'diffusion_model')
-        #self.dynamic_load(self.first_stage_model, 'first_stage_model')
+        if not self.use_dynamic_model:
+            self.dynamic_load(self.first_stage_model, 'first_stage_model')
+            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+            if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
+            self.dynamic_load(self.diffusion_model, 'diffusion_model')
+
+    def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
+        c, H, W = image.shape
+        scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
+        rH = int(H * scale) // 16 * 16  # ensure divisible by self.d
+        rW = int(W * scale) // 16 * 16
+        image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
+        return image
+
 
     @torch.no_grad()
     def encode_first_stage(self, x, **kwargs):
         _, dtype = self.get_function_info(self.first_stage_model, 'encode')
         with torch.autocast('cuda',
-                            enabled=(dtype != 'float32'),
+                            enabled=dtype in ('float16', 'bfloat16'),
                             dtype=getattr(torch, dtype)):
-            z = [
-                self.scale_factor * get_model(self.first_stage_model)._encode(
-                    i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
-            ]
-            return z
+            def run_one_image(u):
+                zu = get_model(self.first_stage_model).encode(u)
+                if isinstance(zu, (tuple, list)):
+                    zu = zu[0]
+                return zu
+
+            z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
+            return z
+
 
     @torch.no_grad()
     def decode_first_stage(self, z):
         _, dtype = self.get_function_info(self.first_stage_model, 'decode')
         with torch.autocast('cuda',
-                            enabled=(dtype != 'float32'),
+                            enabled=dtype in ('float16', 'bfloat16'),
                             dtype=getattr(torch, dtype)):
-            x = [
-                get_model(self.first_stage_model)._decode(
-                    1. / self.scale_factor * i.to(getattr(torch, dtype)))
-                for i in z
-            ]
-            return x
+            return [get_model(self.first_stage_model).decode(zu) for zu in z]
 
+    def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
+        noise = torch.randn(
+            num_samples,
+            16,
+            # allow for packing
+            2 * math.ceil(h / 16),
+            2 * math.ceil(w / 16),
+            device=device,
+            dtype=dtype,
+            generator=torch.Generator(device=device).manual_seed(seed),
+        )
+        return noise
 
+    # def preprocess_prompt(self, prompt):
+    #     prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
+    #     for pp_id, pp in enumerate(prompt_):
+    #         prompt_[pp_id] = [""] + pp
+    #         for p_id, p in enumerate(prompt_[pp_id]):
+    #             prompt_[pp_id][p_id] = self.image_token + self.text_indentifers[p_id] + " " + p
+    #         prompt_[pp_id] = [f";".join(prompt_[pp_id])]
+    #     return prompt_
 
     @torch.no_grad()
     def __call__(self,
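Note on noise_sample: it allocates latents with 16 channels and 2 * ceil(dim / 16) points per spatial side, the Flux-style convention where the sampler later folds 2x2 latent patches into tokens. A small sketch of the shape bookkeeping (the factor-8 VAE and 16 latent channels are inferred from the code, not documented in this diff):

import math

def noise_shape(h, w):
    # Pixel (h, w) -> latent side of 2 * ceil(dim / 16); with an 8x
    # downsampling VAE this equals dim / 8, and every 2x2 latent patch
    # becomes one transformer token.
    return (1, 16, 2 * math.ceil(h / 16), 2 * math.ceil(w / 16))

# 1024 x 1024 pixels -> (1, 16, 128, 128) latent -> 64 * 64 = 4096 tokens,
# matching the MAX_SEQ_LENGTH default above.
assert noise_shape(1024, 1024) == (1, 16, 128, 128)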
@@ -320,48 +210,35 @@ class ACEInference(DiffusionInference):
                 prompt='',
                 task=None,
                 negative_prompt='',
-                output_height=512,
-                output_width=512,
-                sampler='ddim',
+                output_height=1024,
+                output_width=1024,
+                sampler='flow_euler',
                 sample_steps=20,
-                guide_scale=4.5,
-                guide_rescale=0.5,
+                guide_scale=3.5,
                 seed=-1,
                 history_io=None,
                 tar_index=0,
+                align=0,
                 **kwargs):
         input_image, input_mask = image, mask
-        g = torch.Generator(device=we.device_id)
         seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
-        g.manual_seed(int(seed))
         if input_image is not None:
             # assert isinstance(input_image, list) and isinstance(input_mask, list)
             if task is None:
                 task = [''] * len(input_image)
             if not isinstance(prompt, list):
                 prompt = [prompt] * len(input_image)
-            if history_io is not None and len(history_io) > 0:
-                his_image, his_maks, his_prompt, his_task = history_io[
-                    'image'], history_io['mask'], history_io[
-                        'prompt'], history_io['task']
-                assert len(his_image) == len(his_maks) == len(
-                    his_prompt) == len(his_task)
-                input_image = his_image + input_image
-                input_mask = his_maks + input_mask
-                task = his_task + task
-                prompt = his_prompt + [prompt[-1]]
-            prompt = [
-                pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
-                for i, pp in enumerate(prompt)
-            ]
-
+            prompt = [
+                pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
+                for i, pp in enumerate(prompt)
+            ]
             edit_image, edit_image_mask = process_edit_image(
-                input_image, input_mask, task, max_seq_len=self.max_seq_len)
-
-            image, image_mask = edit_image[tar_index], edit_image_mask[
-                tar_index]
+                input_image, input_mask, task, max_seq_len=self.src_max_seq_length)
+            image, image_mask = self.upscale_resize(edit_image[tar_index]), self.upscale_resize(edit_image_mask[
+                tar_index])
+            # edit_image, edit_image_mask = [[self.upscale_resize(i) for i in edit_image]], [[self.upscale_resize(i) for i in edit_image_mask]]
+            # image, image_mask = edit_image[tar_index], edit_image_mask[tar_index]
             edit_image, edit_image_mask = [edit_image], [edit_image_mask]
-
         else:
             edit_image = edit_image_mask = [[]]
             image = torch.zeros(
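Note on the placeholder rewrite kept in the new __call__: the comprehension gives every edit source after the first its own index, so a multi-image instruction can reference {image}, {image1}, {image2}, and so on. With hypothetical prompts:

prompt = ['{image} add a hat', '{image} keep the style', '{image} keep the light']
prompt = [
    pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
    for i, pp in enumerate(prompt)
]
assert prompt == ['{image} add a hat', '{image1} keep the style', '{image2} keep the light']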
@@ -373,177 +250,107 @@ class ACEInference(DiffusionInference):
         if not isinstance(prompt, list):
             prompt = [prompt]
 
-        image, image_mask, prompt = [image], [image_mask], [prompt]
+        image, image_mask, prompt = [image], [image_mask], [prompt],
+        align = [align for p in prompt] if isinstance(align, int) else align
+
         assert check_list_of_list(prompt) and check_list_of_list(
             edit_image) and check_list_of_list(edit_image_mask)
-        # Assign Negative Prompt
-        if isinstance(negative_prompt, list):
-            negative_prompt = negative_prompt[0]
-        assert isinstance(negative_prompt, str)
-
-        n_prompt = copy.deepcopy(prompt)
-        for nn_p_id, nn_p in enumerate(n_prompt):
-            assert isinstance(nn_p, list)
-            n_prompt[nn_p_id][-1] = negative_prompt
-
-        is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1
+        # negative prompt is not used
         image = to_device(image)
 
-        refiner_scale = kwargs.pop("refiner_scale", 0.0)
-        refiner_prompt = kwargs.pop("refiner_prompt", "")
-        use_ace = kwargs.pop("use_ace", True)
-        # <= 0 use ace as the txt2img generator.
-        if use_ace and (not is_txt_image or refiner_scale <= 0):
-            ctx, null_ctx = {}, {}
-            # Get Noise Shape
-            self.dynamic_load(self.first_stage_model, 'first_stage_model')
-            x = self.encode_first_stage(image)
-            self.dynamic_unload(self.first_stage_model,
-                                'first_stage_model',
-                                skip_loaded=True)
-            noise = [
-                torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
-                for i in x
-            ]
-            noise, x_shapes = pack_imagelist_into_tensor(noise)
-            ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
-
-            image_mask = to_device(image_mask, strict=False)
-            cond_mask = [self.interpolate_func(i) for i in image_mask
-                         ] if image_mask is not None else [None] * len(image)
-            ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
-
-            # Encode Prompt
-            self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
-            function_name, dtype = self.get_function_info(self.cond_stage_model)
-            cont, cont_mask = getattr(get_model(self.cond_stage_model),
-                                      function_name)(prompt)
-            cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image, cont,
-                                                         cont_mask)
-            null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
-                                                function_name)(n_prompt)
-            null_cont, null_cont_mask = self.cond_stage_embeddings(
-                prompt, edit_image, null_cont, null_cont_mask)
-            self.dynamic_unload(self.cond_stage_model,
-                                'cond_stage_model',
-                                skip_loaded=False)
-            ctx['crossattn'] = cont
-            null_ctx['crossattn'] = null_cont
-
-            # Encode Edit Images
-            self.dynamic_load(self.first_stage_model, 'first_stage_model')
-            edit_image = [to_device(i, strict=False) for i in edit_image]
-            edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
-            e_img, e_mask = [], []
-            for u, m in zip(edit_image, edit_image_mask):
-                if u is None:
-                    continue
-                if m is None:
-                    m = [None] * len(u)
-                e_img.append(self.encode_first_stage(u, **kwargs))
-                e_mask.append([self.interpolate_func(i) for i in m])
-            self.dynamic_unload(self.first_stage_model,
-                                'first_stage_model',
-                                skip_loaded=True)
-            null_ctx['edit'] = ctx['edit'] = e_img
-            null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
-
-            # Diffusion Process
-            self.dynamic_load(self.diffusion_model, 'diffusion_model')
-            function_name, dtype = self.get_function_info(self.diffusion_model)
-            with torch.autocast('cuda',
-                                enabled=dtype in ('float16', 'bfloat16'),
-                                dtype=getattr(torch, dtype)):
-                latent = self.diffusion.sample(
-                    noise=noise,
-                    sampler=sampler,
-                    model=get_model(self.diffusion_model),
-                    model_kwargs=[{
-                        'cond':
-                        ctx,
-                        'mask':
-                        cont_mask,
-                        'text_position_embeddings':
-                        self.text_position_embeddings.pos if hasattr(
-                            self.text_position_embeddings, 'pos') else None
-                    }, {
-                        'cond':
-                        null_ctx,
-                        'mask':
-                        null_cont_mask,
-                        'text_position_embeddings':
-                        self.text_position_embeddings.pos if hasattr(
-                            self.text_position_embeddings, 'pos') else None
-                    }] if guide_scale is not None and guide_scale > 1 else {
-                        'cond':
-                        null_ctx,
-                        'mask':
-                        cont_mask,
-                        'text_position_embeddings':
-                        self.text_position_embeddings.pos if hasattr(
-                            self.text_position_embeddings, 'pos') else None
-                    },
-                    steps=sample_steps,
-                    show_progress=True,
-                    seed=seed,
-                    guide_scale=guide_scale,
-                    guide_rescale=guide_rescale,
-                    return_intermediate=None,
-                    **kwargs)
-            self.dynamic_unload(self.diffusion_model,
-                                'diffusion_model',
-                                skip_loaded=False)
-
-            # Decode to Pixel Space
-            self.dynamic_load(self.first_stage_model, 'first_stage_model')
-            samples = unpack_tensor_into_imagelist(latent, x_shapes)
-            x_samples = self.decode_first_stage(samples)
-            self.dynamic_unload(self.first_stage_model,
-                                'first_stage_model',
-                                skip_loaded=False)
-            x_samples = [x.squeeze(0) for x in x_samples]
+        ctx = {}
+        # Get Noise Shape
+        self.dynamic_load(self.first_stage_model, 'first_stage_model')
+        x = self.encode_first_stage(image)
+        self.dynamic_unload(self.first_stage_model,
+                            'first_stage_model',
+                            skip_loaded=not self.use_dynamic_model)
+
+        g = torch.Generator(device=we.device_id).manual_seed(seed)
+
+        noise = [
+            torch.randn((1, 16, i.shape[2], i.shape[3]), device=we.device_id, dtype=torch.bfloat16).normal_(generator=g)
+            for i in x
+        ]
+        noise, x_shapes = pack_imagelist_into_tensor(noise)
+        ctx['x_shapes'] = x_shapes
+        ctx['align'] = align
+
+        image_mask = to_device(image_mask, strict=False)
+        cond_mask = [self.interpolate_func(i) for i in image_mask
+                     ] if image_mask is not None else [None] * len(image)
+        ctx['x_mask'] = cond_mask
+        # Encode Prompt
+        instruction_prompt = [[pp[-1]] if "{image}" in pp[-1] else ["{image} " + pp[-1]] for pp in prompt]
+        self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+        function_name, dtype = self.get_function_info(self.cond_stage_model)
+        cont = getattr(get_model(self.cond_stage_model), function_name)(instruction_prompt)
+        cont["context"] = [ct[-1] for ct in cont["context"]]
+        cont["y"] = [ct[-1] for ct in cont["y"]]
+        self.dynamic_unload(self.cond_stage_model,
+                            'cond_stage_model',
+                            skip_loaded=not self.use_dynamic_model)
+        ctx.update(cont)
+
+        # Encode Edit Images
+        self.dynamic_load(self.first_stage_model, 'first_stage_model')
+        edit_image = [to_device(i, strict=False) for i in edit_image]
+        edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
+        e_img, e_mask = [], []
+        for u, m in zip(edit_image, edit_image_mask):
+            if u is None:
+                continue
+            if m is None:
+                m = [None] * len(u)
+            e_img.append(self.encode_first_stage(u, **kwargs))
+            e_mask.append([self.interpolate_func(i) for i in m])
+        self.dynamic_unload(self.first_stage_model,
+                            'first_stage_model',
+                            skip_loaded=not self.use_dynamic_model)
+        ctx['edit_x'] = e_img
+        ctx['edit_mask'] = e_mask
+        # Encode Ref Images
+        if guide_scale is not None:
+            guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype)
         else:
-            x_samples = image
-        if self.refiner_module and refiner_scale > 0:
-            if is_txt_image:
-                random.shuffle(self.ace_prompt)
-                input_refine_prompt = [self.ace_prompt[0] + refiner_prompt if p[0] == "" else p[0] for p in prompt]
-                input_refine_scale = -1.
-            else:
-                input_refine_prompt = [p[0].replace("{image}", "") + " " + refiner_prompt for p in prompt]
-                input_refine_scale = refiner_scale
-            print(input_refine_prompt)
-
-            x_samples = self.refiner_module.refine(x_samples,
-                                                   reverse_scale = input_refine_scale,
-                                                   prompt= input_refine_prompt,
-                                                   seed=seed,
-                                                   use_dynamic_model=self.use_dynamic_model)
+            guide_scale = None
+
+        # Diffusion Process
+        self.dynamic_load(self.diffusion_model, 'diffusion_model')
+        function_name, dtype = self.get_function_info(self.diffusion_model)
+        with torch.autocast('cuda',
+                            enabled=dtype in ('float16', 'bfloat16'),
+                            dtype=getattr(torch, dtype)):
+            latent = self.diffusion.sample(
+                noise=noise,
+                sampler=sampler,
+                model=get_model(self.diffusion_model),
+                model_kwargs={
+                    "cond": ctx, "guidance": guide_scale, "gc_seg": -1
+                },
+                steps=sample_steps,
+                show_progress=True,
+                guide_scale=guide_scale,
+                return_intermediate=None,
+                reverse_scale=-1,
+                **kwargs).float()
+        if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
+                                                       'diffusion_model',
+                                                       skip_loaded=not self.use_dynamic_model)
+
+        # Decode to Pixel Space
+        self.dynamic_load(self.first_stage_model, 'first_stage_model')
+        samples = unpack_tensor_into_imagelist(latent, x_shapes)
+        x_samples = self.decode_first_stage(samples)
+        self.dynamic_unload(self.first_stage_model,
+                            'first_stage_model',
+                            skip_loaded=not self.use_dynamic_model)
+        x_samples = [x.squeeze(0) for x in x_samples]
 
         imgs = [
-            torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,
+            torch.clamp((x_i.float() + 1.0) / 2.0,
                         min=0.0,
                         max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
             for x_i in x_samples
         ]
         imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
         return imgs
-
-    def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
-        if self.use_text_pos_embeddings and not torch.sum(
-                self.text_position_embeddings.pos) > 0:
-            identifier_cont, _ = getattr(get_model(self.cond_stage_model),
-                                         'encode')(self.text_indentifers,
-                                                   return_mask=True)
-            self.text_position_embeddings.load_state_dict(
-                {'pos': identifier_cont[:, 0, :]})
-
-        cont_, cont_mask_ = [], []
-        for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
-            if isinstance(pp, list):
-                cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
-                cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
-            else:
-                raise NotImplementedError
-
-        return cont_, cont_mask_
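Note on packing: in both the old and new paths, variable-sized latents travel through pack_imagelist_into_tensor / unpack_tensor_into_imagelist, which are imported from scepter and not shown in this diff. A toy reimplementation of the contract the samplers rely on, assuming the 2x2 patchification implied by noise_sample (the real helpers may pad or batch differently):

import torch

def pack_imagelist_into_tensor_toy(latents, patch=2):
    # Each (1, C, H, W) latent becomes (H//patch) * (W//patch) tokens of
    # C * patch * patch channels; x_shapes records the grid for unpacking.
    tokens, x_shapes = [], []
    for z in latents:
        c, h, w = z.shape[1:]
        t = z[0].reshape(c, h // patch, patch, w // patch, patch)
        tokens.append(t.permute(1, 3, 0, 2, 4).reshape(-1, c * patch * patch))
        x_shapes.append((h // patch, w // patch))
    return torch.cat(tokens, dim=0).unsqueeze(0), x_shapes

def unpack_tensor_into_imagelist_toy(x, x_shapes, c=16, patch=2):
    out, i = [], 0
    for gh, gw in x_shapes:
        t = x[0, i:i + gh * gw].reshape(gh, gw, c, patch, patch)
        out.append(t.permute(2, 0, 3, 1, 4).reshape(1, c, gh * patch, gw * patch))
        i += gh * gw
    return out

z = torch.randn(1, 16, 8, 6)
packed, shapes = pack_imagelist_into_tensor_toy([z])
assert torch.equal(unpack_tensor_into_imagelist_toy(packed, shapes)[0], z)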
 
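A minimal usage sketch for the new entry point. The config file name, its keys, and the input tensors below are assumptions inferred from init_from_cfg and __call__, not a documented API, and the import assumes this file is importable as ace_inference:

import torch
from scepter.modules.utils.config import Config
from ace_inference import ACEFluxLCInference

cfg = Config(load=True, cfg_file='ace_flux_lc.yaml')  # hypothetical config
ace = ACEFluxLCInference()
ace.init_from_cfg(cfg)
src = torch.zeros(3, 512, 512)   # one edit source in [-1, 1]
msk = torch.ones(1, 512, 512)    # full-image mask
imgs = ace(image=[src],
           mask=[msk],
           prompt='{image} replace the sky with a sunset',
           sampler='flow_euler',
           sample_steps=20,
           guide_scale=3.5,
           seed=2024)
imgs[0].save('output.png')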