shgao commited on
Commit
713a7f5
β€’
1 Parent(s): 8e53b73

update fix

Browse files
utils/stable_diffusion_controlnet_inpaint.py CHANGED
@@ -1046,7 +1046,7 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
1046
  do_classifier_free_guidance,
1047
  )
1048
  if self.unet.config.in_channels==4:
1049
- init_masked_image_latents, _ = self.prepare_masked_image_latents(
1050
  image,
1051
  batch_size * num_images_per_prompt,
1052
  height,
@@ -1055,8 +1055,10 @@ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMixi
1055
  device,
1056
  generator,
1057
  do_classifier_free_guidance,
1058
- ).chunk(2)
1059
- print(type(mask_image), mask_image.shape)
 
 
1060
  _, _, w, h = mask_image.shape
1061
  mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
1062
  mask_image = mask_image.to(latents.device).type_as(latents)
 
1046
  do_classifier_free_guidance,
1047
  )
1048
  if self.unet.config.in_channels==4:
1049
+ init_masked_image_latents = self.prepare_masked_image_latents(
1050
  image,
1051
  batch_size * num_images_per_prompt,
1052
  height,
 
1055
  device,
1056
  generator,
1057
  do_classifier_free_guidance,
1058
+ )
1059
+ if do_classifier_free_guidance:
1060
+ init_masked_image_latents, _ = init_masked_image_latents.chunk(2)
1061
+ # print(type(mask_image), mask_image.shape)
1062
  _, _, w, h = mask_image.shape
1063
  mask_image = torch.nn.functional.interpolate(mask_image, ((w // 8, h // 8)), mode='nearest')
1064
  mask_image = mask_image.to(latents.device).type_as(latents)