poipiii committed dabc83d (parent: e36655e)
Commit message: test refeactor

pipeline.py CHANGED (+107 -74)
@@ -673,6 +673,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
+        return_latents: bool = False,
         max_embeddings_multiples: Optional[int] = 3,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
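The new `return_latents` keyword changes the contract of `__call__`: when set, the pipeline stops after the denoising loop and hands back the raw latent tensor instead of decoded images (see the post-processing hunk below). A minimal sketch of the intended calling convention; `pipe` and the prompt are placeholders, not part of the commit:

    # Hypothetical use of the new flag against the patched __call__:
    latents = pipe("a placeholder prompt", return_latents=True)
    # latents: torch.FloatTensor of shape (batch, 4, height // 8, width // 8),
    # still undecoded; feed it back in via `latents=` for a second pass.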
@@ -796,7 +797,7 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

-
+
         latents, init_latents_orig, noise = self.prepare_latents(
             image,
             latent_timestep,
@@ -812,14 +813,14 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         # print("before denoise latents")
-        print(latents.shape)
+        # print(latents.shape)
         # 8. Denoising loop
         for i, t in enumerate(self.progress_bar(timesteps)):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            print("latent_model_input 1st step")
+            # print("latent_model_input 1st step")
            # print(latent_model_input)
-            print(latent_model_input.shape)
+            # print(latent_model_input.shape)

            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -845,86 +846,91 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
                    callback(i, t, latents)
                if is_cancelled_callback is not None and is_cancelled_callback():
                    return None
-        print("after first step denoise latents")
+        # print("after first step denoise latents")
         # print(latents)
-        print(latents.shape)
-        upscale_latents = torch.nn.functional.interpolate(
-            latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
+        # print(latents.shape)
+        # upscale_latents = torch.nn.functional.interpolate(
+        #     latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))

-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat(
-                [upscale_latents] * 2) if do_classifier_free_guidance else upscale_latents
-            # print("latent_model_input 2nd step")
-            # print(latent_model_input)
-            # print(latent_model_input.shape)
-
-            # print("2nd step timestep")
-            # print(t)
-
-            latent_model_input = self.scheduler.scale_model_input(
-                latent_model_input, t)
+        # for i, t in enumerate(self.progress_bar(timesteps)):
+        #     # expand the latents if we are doing classifier free guidance
+        #     latent_model_input = torch.cat(
+        #         [upscale_latents] * 2) if do_classifier_free_guidance else upscale_latents
+        #     # print("latent_model_input 2nd step")
+        #     # print(latent_model_input)
+        #     # print(latent_model_input.shape)
+
+        #     # print("2nd step timestep")
+        #     # print(t)
+
+        #     latent_model_input = self.scheduler.scale_model_input(
+        #         latent_model_input, t)




-            # print("latent_model_input after scheduler")
-            # print(latent_model_input)
-            # print(latent_model_input.shape)
-            # predict the noise residual
-            upscale_noise_pred = self.unet(latent_model_input, t,
-                                           encoder_hidden_states=text_embeddings).sample
-            # print("noise_pred")
-            # print(noise_pred)
-            # print(upscale_noise_pred.shape)
-
-            # print("perform guidance")
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = upscale_noise_pred.chunk(
-                    2)
-                upscale_noise_pred = noise_pred_uncond + guidance_scale * \
-                    (noise_pred_text - noise_pred_uncond)
-            # print("noise_pred after guidance")
-            # print(upscale_noise_pred.shape)
-
-            # print("compute the previous noisy sample")
-            # compute the previous noisy sample x_t -> x_t-1
-            upscale_latents = self.scheduler.step(
-                noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
-            # print(upscale_latents.shape)
-
-            # print("compute mask")
-            if mask is not None:
-                # masking
-                init_latents_proper = self.scheduler.add_noise(
-                    init_latents_orig, noise, torch.tensor([t]))
-                upscale_latents = (init_latents_proper *
-                                   mask) + (latents * (1 - mask))
-
-            # call the callback, if provided
-            if i % callback_steps == 0:
-                if callback is not None:
-                    callback(i, t, upscale_latents)
-                if is_cancelled_callback is not None and is_cancelled_callback():
-                    return None
-        #do latent upscale here
-
-        # 9. Post-processing
-        image = self.decode_latents(upscale_latents)
+        #     # print("latent_model_input after scheduler")
+        #     # print(latent_model_input)
+        #     # print(latent_model_input.shape)
+        #     # predict the noise residual
+        #     upscale_noise_pred = self.unet(latent_model_input, t,
+        #                                    encoder_hidden_states=text_embeddings).sample
+        #     # print("noise_pred")
+        #     # print(noise_pred)
+        #     # print(upscale_noise_pred.shape)
+
+        #     # print("perform guidance")
+        #     # perform guidance
+        #     if do_classifier_free_guidance:
+        #         noise_pred_uncond, noise_pred_text = upscale_noise_pred.chunk(
+        #             2)
+        #         upscale_noise_pred = noise_pred_uncond + guidance_scale * \
+        #             (noise_pred_text - noise_pred_uncond)
+        #     # print("noise_pred after guidance")
+        #     # print(upscale_noise_pred.shape)
+
+        #     # print("compute the previous noisy sample")
+        #     # compute the previous noisy sample x_t -> x_t-1
+        #     upscale_latents = self.scheduler.step(
+        #         noise_pred, t, upscale_latents, **extra_step_kwargs).prev_sample
+        #     # print(upscale_latents.shape)
+
+        #     # print("compute mask")
+        #     if mask is not None:
+        #         # masking
+        #         init_latents_proper = self.scheduler.add_noise(
+        #             init_latents_orig, noise, torch.tensor([t]))
+        #         upscale_latents = (init_latents_proper *
+        #                            mask) + (latents * (1 - mask))
+
+        #     # call the callback, if provided
+        #     if i % callback_steps == 0:
+        #         if callback is not None:
+        #             callback(i, t, upscale_latents)
+        #         if is_cancelled_callback is not None and is_cancelled_callback():
+        #             return None
+        # #do latent upscale here
+
+        # # 9. Post-processing
+        # image = self.decode_latents(upscale_latents)
+
+        if return_latents:
+            return latents
+        else:
+            image = self.decode_latents(latents)


-        # 10. Run safety checker
-        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+        # 10. Run safety checker
+        image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

-        # 11. Convert to PIL
-        if output_type == "pil":
-            image = self.numpy_to_pil(image)
+        # 11. Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)

-        if not return_dict:
-            return image, has_nsfw_concept
+        if not return_dict:
+            return image, has_nsfw_concept

-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

     def text2img(
         self,
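Net effect of this hunk: the experimental second denoising pass over `upscale_latents` is commented out of `__call__`, which now either returns the latents early (`return_latents`) or decodes them as before; the upscale step itself moves into `text2img` (later hunks). The commented code hinges on one piece of arithmetic, the 8x-downsampled latent grid. A standalone sketch of that size calculation, using placeholder values rather than the pipeline's own state:

    import torch

    # Stable Diffusion latents are (B, 4, H // 8, W // 8) for an H x W image.
    height = width = 512
    resize_scale = 1.2
    latents = torch.randn(1, 4, height // 8, width // 8)  # 64 x 64 grid

    # Same expression as the diff: int(512 * 1.2) // 8 == 614 // 8 == 76,
    # so the upscaled latents decode to 608 x 608 px, not 614 x 614.
    upscaled = torch.nn.functional.interpolate(
        latents,
        size=(int(height * resize_scale) // 8, int(width * resize_scale) // 8),
    )
    print(upscaled.shape)  # torch.Size([1, 4, 76, 76])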
@@ -934,6 +940,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         width: int = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
+        strength: float = 0.6,
+        resize_scale: float = 1.2,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[torch.Generator] = None,
@@ -1002,7 +1010,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
-        return self.__call__(
+
+        latents = self.__call__(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
@@ -1013,6 +1022,30 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
            eta=eta,
            generator=generator,
            latents=latents,
+            return_latents=True,
+            max_embeddings_multiples=max_embeddings_multiples,
+            output_type=output_type,
+            return_dict=return_dict,
+            callback=callback,
+            is_cancelled_callback=is_cancelled_callback,
+            callback_steps=callback_steps,
+        )
+
+        latents = torch.nn.functional.interpolate(latents, size=(int(height*resize_scale)//8, int(width*resize_scale)//8))
+
+        return self.__call__(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            height=height*resize_scale,
+            width=width*resize_scale,
+            num_inference_steps=num_inference_steps,
+            guidance_scale=guidance_scale,
+            strength=strength,
+            num_images_per_prompt=num_images_per_prompt,
+            eta=eta,
+            generator=generator,
+            latents=latents,
+            return_latents=False,
            max_embeddings_multiples=max_embeddings_multiples,
            output_type=output_type,
            return_dict=return_dict,
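Taken together, `text2img` now runs a two-pass, highres-fix-style flow: a full `__call__` that returns latents, a nearest-neighbour latent upscale by `resize_scale`, then a second `__call__` at the larger size whose `strength` controls how much the upscaled latents are re-noised and re-denoised. A hedged sketch of the intended call pattern; the model id and prompt are placeholders, this assumes the patched pipeline.py is what gets loaded as the custom pipeline, and since the commit is labelled a test refactor the behaviour is unverified:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",          # placeholder model id
        custom_pipeline="lpw_stable_diffusion",    # assumes this patched pipeline.py
        torch_dtype=torch.float16,
    ).to("cuda")

    # Pass 1 renders 512 x 512 latents; pass 2 re-denoises them after a
    # 1.2x latent upscale, so the final image comes out at 608 x 608.
    result = pipe.text2img(
        "a placeholder prompt",
        height=512,
        width=512,
        num_inference_steps=50,
        strength=0.6,       # re-noising amount for the second pass
        resize_scale=1.2,   # latent upscale factor between passes
    )
    image = result.images[0]

Note that the second `__call__` receives `height=height*resize_scale`, a non-integer value (614.4 here); only the integer latent size computed before the call is guaranteed meaningful, and whether the pipeline's input checks tolerate the float is untested in this commit.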