saintzeno committed on
Commit cc3d44c
1 Parent(s): 1539d07

Upload pipeline.py

Files changed (1)
  1. pipeline.py +220 -0
pipeline.py ADDED
@@ -0,0 +1,220 @@
import torch
from diffusers import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from typing import Any, Callable, Dict, List, Optional, Union


class MyPipeline(StableDiffusionPipeline):
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        # custom addition: controls saving of intermediate denoising images
        image_saving_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of the [Imagen
                paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages the model to generate images closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`]; it is ignored for other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            image_saving_kwargs (`dict`, *optional*):
                Custom addition: if the `save_denoising_images` key is truthy, the latents are decoded after every
                scheduler step and the first image in the batch is saved to `save_denoising_path` + `_t_<step>.png`.

        Examples:
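
            A minimal sketch (the model id and save path below are placeholders, not part of the original file):

            >>> pipe = MyPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder checkpoint
            >>> pipe = pipe.to("cuda")
            >>> out = pipe(
            ...     "a photo of an astronaut riding a horse on mars",
            ...     image_saving_kwargs={
            ...         "save_denoising_images": True,
            ...         "save_denoising_path": "./steps/astronaut",  # files are written as <prefix>_t_<i>.png
            ...     },
            ... )
            >>> out.images[0].save("astronaut_final.png")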

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
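        # note (added comment): `latents` now has shape (batch_size * num_images_per_prompt,
        # num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)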

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # custom addition: decode and save the intermediate image after this step
                if image_saving_kwargs and image_saving_kwargs.get("save_denoising_images"):
                    image = self.decode_latents(latents)
                    image = self.numpy_to_pil(image)
                    image[0].save(image_saving_kwargs["save_denoising_path"] + "_t_" + str(i) + ".png")
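                # note (added comment): this option runs a full VAE decode on every step, which
                # noticeably slows inference; with batched prompts only the first image is saved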

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
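
Because this pipeline.py lives in a Hub repo, it can also be loaded as a diffusers community pipeline via the `custom_pipeline` argument of `DiffusionPipeline.from_pretrained`. A minimal sketch, assuming a standard Stable Diffusion checkpoint; the custom-pipeline repo id below is a placeholder for the repo this commit belongs to, and the output directory must already exist:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",         # base checkpoint (assumption)
    custom_pipeline="saintzeno/<this-repo>",  # placeholder: the Hub repo containing pipeline.py
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# each denoising step is written to ./steps/demo_t_<i>.png (./steps must exist beforehand)
out = pipe(
    "a photo of an astronaut riding a horse on mars",
    image_saving_kwargs={
        "save_denoising_images": True,
        "save_denoising_path": "./steps/demo",
    },
)
out.images[0].save("demo_final.png")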