chaojiemao committed on
Commit
9391532
·
verified ·
1 Parent(s): 06b3202

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +13 -13
ace_inference.py CHANGED
@@ -282,9 +282,9 @@ class ACEInference(DiffusionInference):
282
  self.size_factor = cfg.get('SIZE_FACTOR', 8)
283
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
284
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
285
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
286
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
287
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
288
 
289
  @torch.no_grad()
290
  def encode_first_stage(self, x, **kwargs):
@@ -396,9 +396,9 @@ class ACEInference(DiffusionInference):
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
- if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
403
  skip_loaded=True)
404
  noise = [
@@ -414,7 +414,7 @@ class ACEInference(DiffusionInference):
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
- if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
@@ -424,14 +424,14 @@ class ACEInference(DiffusionInference):
424
  function_name)(n_prompt)
425
  null_cont, null_cont_mask = self.cond_stage_embeddings(
426
  prompt, edit_image, null_cont, null_cont_mask)
427
- if self.use_dynamic_model: self.dynamic_unload(self.cond_stage_model,
428
  'cond_stage_model',
429
  skip_loaded=False)
430
  ctx['crossattn'] = cont
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
@@ -442,14 +442,14 @@ class ACEInference(DiffusionInference):
442
  m = [None] * len(u)
443
  e_img.append(self.encode_first_stage(u, **kwargs))
444
  e_mask.append([self.interpolate_func(i) for i in m])
445
- if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
446
  'first_stage_model',
447
  skip_loaded=True)
448
  null_ctx['edit'] = ctx['edit'] = e_img
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
- if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
@@ -490,15 +490,15 @@ class ACEInference(DiffusionInference):
490
  guide_rescale=guide_rescale,
491
  return_intermediate=None,
492
  **kwargs)
493
- if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
494
  'diffusion_model',
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
- if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
- if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
502
  'first_stage_model',
503
  skip_loaded=False)
504
  x_samples = [x.squeeze(0) for x in x_samples]
 
282
  self.size_factor = cfg.get('SIZE_FACTOR', 8)
283
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
284
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
285
+ #self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
286
+ #self.dynamic_load(self.diffusion_model, 'diffusion_model')
287
+ #self.dynamic_load(self.first_stage_model, 'first_stage_model')
288
 
289
  @torch.no_grad()
290
  def encode_first_stage(self, x, **kwargs):
 
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
+ self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
403
  skip_loaded=True)
404
  noise = [
 
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
 
424
  function_name)(n_prompt)
425
  null_cont, null_cont_mask = self.cond_stage_embeddings(
426
  prompt, edit_image, null_cont, null_cont_mask)
427
+ self.dynamic_unload(self.cond_stage_model,
428
  'cond_stage_model',
429
  skip_loaded=False)
430
  ctx['crossattn'] = cont
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
 
442
  m = [None] * len(u)
443
  e_img.append(self.encode_first_stage(u, **kwargs))
444
  e_mask.append([self.interpolate_func(i) for i in m])
445
+ self.dynamic_unload(self.first_stage_model,
446
  'first_stage_model',
447
  skip_loaded=True)
448
  null_ctx['edit'] = ctx['edit'] = e_img
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
 
490
  guide_rescale=guide_rescale,
491
  return_intermediate=None,
492
  **kwargs)
493
+ self.dynamic_unload(self.diffusion_model,
494
  'diffusion_model',
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
+ self.dynamic_unload(self.first_stage_model,
502
  'first_stage_model',
503
  skip_loaded=False)
504
  x_samples = [x.squeeze(0) for x in x_samples]