poipiii committed on
Commit: dabc83d
1 Parent(s): e36655e

test refactor

Files changed (1)
  1. pipeline.py +107 -74
pipeline.py CHANGED
@@ -673,6 +673,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
+        return_latents: bool = False,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
@@ -796,7 +797,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-        # 6. Prepare latent variables
+
         latents, init_latents_orig, noise = self.prepare_latents(
             image,
             latent_timestep,
@@ -812,14 +813,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         # print("before denoise latents")
-        print(latents.shape)
+        # print(latents.shape)
         # 8. Denoising loop
         for i, t in enumerate(self.progress_bar(timesteps)):
             # expand the latents if we are doing classifier free guidance
             latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            print("latent_model_input 1st step")
+            # print("latent_model_input 1st step")
             # print(latent_model_input)
-            print(latent_model_input.shape)
+            # print(latent_model_input.shape)

             latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -845,86 +846,91 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                     callback(i, t, latents)
                 if is_cancelled_callback is not None and is_cancelled_callback():
                     return None
-        print("after first step denoise latents")
+        # print("after first step denoise latents")
         # print(latents)
-        print(latents.shape)
-        upscale_latents = torch.nn.functional.interpolate(
-            latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
+        # print(latents.shape)
+        # upscale_latents = torch.nn.functional.interpolate(
+        #     latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))

-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat(
-                [upscale_latents] * 2) if do_classifier_free_guidance else upscale_latents
-            print("latent_model_input 2nd step")
-            # print(latent_model_input)
-            print(latent_model_input.shape)
-
-            print("2nd step timestep")
-            print(t)
-
-            latent_model_input = self.scheduler.scale_model_input(
-                latent_model_input, t)
+        # for i, t in enumerate(self.progress_bar(timesteps)):
+        #     # expand the latents if we are doing classifier free guidance
+        #     latent_model_input = torch.cat(
+        #         [upscale_latents] * 2) if do_classifier_free_guidance else upscale_latents
+        #     # print("latent_model_input 2nd step")
+        #     # print(latent_model_input)
+        #     # print(latent_model_input.shape)
+
+        #     # print("2nd step timestep")
+        #     # print(t)
+
+        #     latent_model_input = self.scheduler.scale_model_input(
+        #         latent_model_input, t)




-            print("latent_model_input after scheduler")
-            # print(latent_model_input)
-            print(latent_model_input.shape)
-            # predict the noise residual
-            upscale_noise_pred = self.unet(latent_model_input, t,
-                                           encoder_hidden_states=text_embeddings).sample
-            print("noise_pred")
-            # print(noise_pred)
-            print(upscale_noise_pred.shape)
-
-            print("perform guidance")
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = upscale_noise_pred.chunk(2)
-                upscale_noise_pred = noise_pred_uncond + guidance_scale * \
-                    (noise_pred_text - noise_pred_uncond)
-            print("noise_pred after guidance")
-            print(upscale_noise_pred.shape)
-
-            print("compute the previous noisy sample")
-            # compute the previous noisy sample x_t -> x_t-1
-            upscale_latents = self.scheduler.step(
-                upscale_noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
-            print(upscale_latents.shape)
-
-            print("compute mask")
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(
-                    init_latents_orig, noise, torch.tensor([t]))
-                upscale_latents = (init_latents_proper *
-                                   mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, upscale_latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-        #do latent upscale here
-
-        # 9. Post-processing
-        image = self.decode_latents(upscale_latents)
+        #     # print("latent_model_input after scheduler")
+        #     # print(latent_model_input)
+        #     # print(latent_model_input.shape)
+        #     # predict the noise residual
+        #     upscale_noise_pred = self.unet(latent_model_input, t,
+        #                                    encoder_hidden_states=text_embeddings).sample
+        #     # print("noise_pred")
+        #     # print(noise_pred)
+        #     # print(upscale_noise_pred.shape)
+
+        #     # print("perform guidance")
+        #     # perform guidance
+        #     if do_classifier_free_guidance:
+        #         noise_pred_uncond, noise_pred_text = upscale_noise_pred.chunk(2)
+        #         upscale_noise_pred = noise_pred_uncond + guidance_scale * \
+        #             (noise_pred_text - noise_pred_uncond)
+        #     # print("noise_pred after guidance")
+        #     # print(upscale_noise_pred.shape)
+
+        #     # print("compute the previous noisy sample")
+        #     # compute the previous noisy sample x_t -> x_t-1
+        #     upscale_latents = self.scheduler.step(
+        #         noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
+        #     # print(upscale_latents.shape)
+
+        #     # print("compute mask")
+        #     if mask is not None:
+        #         # masking
+        #         init_latents_proper = self.scheduler.add_noise(
+        #             init_latents_orig, noise, torch.tensor([t]))
+        #         upscale_latents = (init_latents_proper *
+        #                            mask) + (latents * (1 - mask))
+
+        #     # call the callback, if provided
+        #     if i % callback_steps == 0:
+        #         if callback is not None:
+        #             callback(i, t, upscale_latents)
+        #         if is_cancelled_callback is not None and is_cancelled_callback():
+        #             return None
+        # #do latent upscale here
+
+        # # 9. Post-processing
+        # image = self.decode_latents(upscale_latents)
+
+        if return_latents:
+            return latents
+        else:
+            image = self.decode_latents(latents)


-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        # 11. Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)

-        if not return_dict:
-            return image, has_nsfw_concept
+        if not return_dict:
+            return image, has_nsfw_concept

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

     def text2img(
         self,
@@ -934,6 +940,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        strength: float = 0.6,
+        resize_scale: float = 1.2,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
@@ -1002,7 +1010,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
-        return self.__call__(
+
+        latents = self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
@@ -1013,6 +1022,30 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
            eta=eta,
            generator=generator,
            latents=latents,
+            return_latents=True,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+        latents = torch.nn.functional.interpolate(latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
+
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height*resize_scale,
+            width=width*resize_scale,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            latents=latents,
+            return_latents=False,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
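
Net effect of the commit: the experimental second denoising loop inside __call__ is commented out, __call__ gains a return_latents flag that short-circuits decoding and the safety checker, and text2img becomes a two-pass "highres fix" style generation: a base-resolution pass that returns latents, a latent upscale via torch.nn.functional.interpolate, then an img2img-style refinement pass. A minimal sketch of that flow from the caller's side, mirroring what the new text2img body does; the pipe object, the prompt, and the 512x512 base size are illustrative placeholders, not values from the commit:

import torch

height, width, resize_scale = 512, 512, 1.2
prompt = "a photo of an astronaut riding a horse"  # placeholder prompt

# Pass 1: denoise at the base resolution; return_latents=True makes
# __call__ return raw latents instead of decoded images.
latents = pipe(
    prompt=prompt,
    height=height,
    width=width,
    num_inference_steps=50,
    return_latents=True,
)

# Latents live at 1/8 of pixel resolution (the VAE downsampling factor),
# hence the //8 when choosing the upscaled latent size.
latents = torch.nn.functional.interpolate(
    latents,
    size=(int(height * resize_scale) // 8, int(width * resize_scale) // 8),
)

# Pass 2: re-run the schedule on the upscaled latents, img2img-style;
# strength controls how much of the schedule is repeated.
result = pipe(
    prompt=prompt,
    height=int(height * resize_scale),
    width=int(width * resize_scale),
    strength=0.6,
    latents=latents,
    return_latents=False,
)

The modest default resize_scale of 1.2 keeps the latent upscale small enough for the second pass to clean up interpolation artifacts. Note the sketch casts the scaled height and width to int; the commit passes height*resize_scale and width*resize_scale through as floats.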