AlekseyCalvin committed
Commit ea6ea68 · verified · 1 Parent(s): eb0e115

Upload pipeline.py

Files changed (1):
  1. pipeline.py +678 -0
pipeline.py ADDED
@@ -0,0 +1,678 @@
+ import torch
+ import numpy as np
+ import html
+ import inspect
+ import re
+ import urllib.parse as ul
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
+ from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
+ from diffusers.image_processor import VaeImageProcessor
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
+ from diffusers.utils import (
+     USE_PEFT_BACKEND,
+     is_torch_xla_available,
+     logging,
+     BACKENDS_MAPPING,
+     deprecate,
+     replace_example_docstring,
+     scale_lora_layers,
+     unscale_lora_layers,
+ )
+ from diffusers.utils.torch_utils import randn_tensor
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
+ from typing import Any, Callable, Dict, List, Optional, Union
+ from PIL import Image
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel
+ from diffusers.utils import is_torch_xla_available
+
+ if is_torch_xla_available():
+     import torch_xla.core.xla_model as xm
+
+     XLA_AVAILABLE = True
+ else:
+     XLA_AVAILABLE = False
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+ # Constants for shift calculation
+ BASE_SEQ_LEN = 256
+ MAX_SEQ_LEN = 4096
+ BASE_SHIFT = 0.5
+ MAX_SHIFT = 1.2
+
+ # Helper functions
+ def calculate_timestep_shift(image_seq_len: int) -> float:
+     """Calculates the timestep shift (mu) based on the image sequence length."""
+     m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
+     b = BASE_SHIFT - m * BASE_SEQ_LEN
+     mu = image_seq_len * m + b
+     return mu
+
+ def prepare_timesteps(
+     scheduler: FlowMatchEulerDiscreteScheduler,
+     num_inference_steps: Optional[int] = None,
+     device: Optional[Union[str, torch.device]] = None,
+     timesteps: Optional[List[int]] = None,
+     sigmas: Optional[List[float]] = None,
+     mu: Optional[float] = None,
+ ) -> (torch.Tensor, int):
+     """Prepares the timesteps for the diffusion process."""
+     if timesteps is not None and sigmas is not None:
+         raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
+
+     if timesteps is not None:
+         scheduler.set_timesteps(timesteps=timesteps, device=device)
+     elif sigmas is not None:
+         # forward `mu` so schedulers configured with dynamic shifting can apply it to custom sigmas
+         scheduler.set_timesteps(sigmas=sigmas, device=device, mu=mu)
+     else:
+         scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
+
+     timesteps = scheduler.timesteps
+     num_inference_steps = len(timesteps)
+     return timesteps, num_inference_steps
+
+ # FLUX pipeline class
+ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
+     def __init__(
+         self,
+         scheduler: FlowMatchEulerDiscreteScheduler,
+         vae: AutoencoderKL,
+         text_encoder: CLIPTextModel,
+         tokenizer: CLIPTokenizer,
+         text_encoder_2: T5EncoderModel,
+         tokenizer_2: T5TokenizerFast,
+         transformer: FluxTransformer2DModel,
+     ):
+         super().__init__()
+
+         self.register_modules(
+             vae=vae,
+             text_encoder=text_encoder,
+             text_encoder_2=text_encoder_2,
+             tokenizer=tokenizer,
+             tokenizer_2=tokenizer_2,
+             transformer=transformer,
+             scheduler=scheduler,
+         )
+         self.vae_scale_factor = (
+             2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+         )
+         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+         self.tokenizer_max_length = (
+             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+         )
+         self.default_sample_size = 64
+
+     def _get_t5_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]] = None,
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         device: Optional[torch.device] = None,
+         dtype: Optional[torch.dtype] = None,
+     ):
+         device = device or self._execution_device
+         dtype = dtype or self.text_encoder.dtype
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         text_inputs = self.tokenizer_2(
+             prompt,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             return_length=True,
+             return_overflowing_tokens=True,
+             return_tensors="pt",
+         )
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because `max_sequence_length` is set to "
+                 f" {max_sequence_length} tokens: {removed_text}"
+             )
+
+         prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+         dtype = self.text_encoder_2.dtype
+         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+         _, seq_len, _ = prompt_embeds.shape
+
+         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+         return prompt_embeds
+
+     def _get_clip_prompt_embeds(
+         self,
+         prompt: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         device: Optional[torch.device] = None,
+     ):
+         device = device or self._execution_device
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         batch_size = len(prompt)
+
+         text_inputs = self.tokenizer(
+             prompt,
+             padding="max_length",
+             max_length=self.tokenizer_max_length,
+             truncation=True,
+             return_overflowing_tokens=False,
+             return_length=False,
+             return_tensors="pt",
+         )
+
+         text_input_ids = text_inputs.input_ids
+         untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+         if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+             removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+             logger.warning(
+                 "The following part of your input was truncated because CLIP can only handle sequences up to"
+                 f" {self.tokenizer_max_length} tokens: {removed_text}"
+             )
+         prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+         # Use pooled output of CLIPTextModel
+         prompt_embeds = prompt_embeds.pooler_output
+         prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+         # duplicate text embeddings for each generation per prompt, using mps friendly method
+         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+         return prompt_embeds
+
+     def encode_prompt(
+         self,
+         prompt: Union[str, List[str]],
+         prompt_2: Union[str, List[str]],
+         num_images_per_prompt: int = 1,
+         max_sequence_length: int = 512,
+         do_classifier_free_guidance: bool = True,
+         device: Optional[torch.device] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.Tensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         lora_scale: Optional[float] = None,
+     ):
+         device = device or self._execution_device
+
+         # set lora scale so that monkey patched LoRA
+         # function of text encoder can correctly access it
+         if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+             self._lora_scale = lora_scale
+
+             # dynamically adjust the LoRA scale
+             if self.text_encoder is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(self.text_encoder, lora_scale)
+             if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                 scale_lora_layers(self.text_encoder_2, lora_scale)
+
+         prompt = [prompt] if isinstance(prompt, str) else prompt
+         if prompt is not None:
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         if prompt_embeds is None:
+             prompt_2 = prompt_2 or prompt
+             prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+             pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                 prompt=prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+             )
+             prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=prompt_2,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+             )
+
+         if do_classifier_free_guidance and negative_prompt_embeds is None:
+             negative_prompt = negative_prompt or ""
+             negative_prompt_2 = negative_prompt_2 or negative_prompt
+
+             # normalize str to list
+             negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+             negative_prompt_2 = (
+                 batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+             )
+
+             if prompt is not None and type(prompt) is not type(negative_prompt):
+                 raise TypeError(
+                     f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                     f" {type(prompt)}."
+                 )
+             elif batch_size != len(negative_prompt):
+                 raise ValueError(
+                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                     f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                     " the batch size of `prompt`."
+                 )
+
+             # mirror the positive path: pooled embeddings from CLIP, sequence embeddings from T5
+             negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                 prompt=negative_prompt,
+                 device=device,
+                 num_images_per_prompt=num_images_per_prompt,
+             )
+             negative_prompt_embeds = self._get_t5_prompt_embeds(
+                 prompt=negative_prompt_2,
+                 num_images_per_prompt=num_images_per_prompt,
+                 max_sequence_length=max_sequence_length,
+                 device=device,
+             )
+
+         if self.text_encoder is not None:
+             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder, lora_scale)
+
+         if self.text_encoder_2 is not None:
+             if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                 # Retrieve the original scale by scaling back the LoRA layers
+                 unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+         dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+         text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+         # `_get_clip_prompt_embeds` already duplicates the pooled embeddings per `num_images_per_prompt`,
+         # so no further repeat is needed here.
+         return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
+
+     def check_inputs(
+         self,
+         prompt,
+         prompt_2,
+         height,
+         width,
+         negative_prompt=None,
+         negative_prompt_2=None,
+         prompt_embeds=None,
+         negative_prompt_embeds=None,
+         pooled_prompt_embeds=None,
+         negative_pooled_prompt_embeds=None,
+         max_sequence_length=None,
+     ):
+         if height % 8 != 0 or width % 8 != 0:
+             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+         if prompt is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt_2 is not None and prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                 " only forward one of the two."
+             )
+         elif prompt is None and prompt_embeds is None:
+             raise ValueError(
+                 "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+             )
+         elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+             raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+
+         if prompt_embeds is not None and pooled_prompt_embeds is None:
+             raise ValueError(
+                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+             )
+         if negative_prompt is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+         elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+             raise ValueError(
+                 f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                 f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+             )
+         if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+             raise ValueError("Must provide `negative_pooled_prompt_embeds` when specifying `negative_prompt_embeds`.")
+
+         if max_sequence_length is not None and max_sequence_length > 512:
+             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
+
+     @staticmethod
+     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+         latent_image_ids = torch.zeros(height // 2, width // 2, 3)
+         latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
+         latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+
+         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+         latent_image_ids = latent_image_ids.reshape(
+             latent_image_id_height * latent_image_id_width, latent_image_id_channels
+         )
+
+         return latent_image_ids.to(device=device, dtype=dtype)
+
+     @staticmethod
+     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+         latents = latents.permute(0, 2, 4, 1, 3, 5)
+         latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+         return latents
+
+     @staticmethod
+     def _unpack_latents(latents, height, width, vae_scale_factor):
+         batch_size, num_patches, channels = latents.shape
+
+         height = height // vae_scale_factor
+         width = width // vae_scale_factor
+
+         latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+         latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+         latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+         return latents
+
+     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+     def prepare_extra_step_kwargs(self, generator, eta):
+         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+         # and should be between [0, 1]
+
+         accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         extra_step_kwargs = {}
+         if accepts_eta:
+             extra_step_kwargs["eta"] = eta
+
+         # check if the scheduler accepts generator
+         accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+         if accepts_generator:
+             extra_step_kwargs["generator"] = generator
+         return extra_step_kwargs
+
+     def enable_vae_slicing(self):
+         r"""
+         Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+         compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+         """
+         self.vae.enable_slicing()
+
+     def disable_vae_slicing(self):
+         r"""
+         Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_slicing()
+
+     def enable_vae_tiling(self):
+         r"""
+         Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+         compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
+         allow processing larger images.
+         """
+         self.vae.enable_tiling()
+
+     def disable_vae_tiling(self):
+         r"""
+         Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+         computing decoding in one step.
+         """
+         self.vae.disable_tiling()
+
+     def prepare_latents(
+         self,
+         batch_size,
+         num_channels_latents,
+         height,
+         width,
+         dtype,
+         device,
+         generator,
+         latents=None,
+     ):
+         height = 2 * (int(height) // self.vae_scale_factor)
+         width = 2 * (int(width) // self.vae_scale_factor)
+
+         shape = (batch_size, num_channels_latents, height, width)
+
+         if latents is not None:
+             latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+             return latents.to(device=device, dtype=dtype), latent_image_ids
+
+         if isinstance(generator, list) and len(generator) != batch_size:
+             raise ValueError(
+                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+             )
+
+         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+         latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+
+         return latents, latent_image_ids
+
+     @property
+     def guidance_scale(self):
+         return self._guidance_scale
+
+     @property
+     def do_classifier_free_guidance(self):
+         return self._guidance_scale > 1
+
+     @property
+     def joint_attention_kwargs(self):
+         return self._joint_attention_kwargs
+
+     @property
+     def num_timesteps(self):
+         return self._num_timesteps
+
+     @property
+     def interrupt(self):
+         return self._interrupt
+
+     @torch.no_grad()
+     @torch.inference_mode()
+     def __call__(
+         self,
+         prompt: Union[str, List[str]] = None,
+         prompt_2: Optional[Union[str, List[str]]] = None,
+         height: Optional[int] = None,
+         width: Optional[int] = None,
+         negative_prompt: Optional[Union[str, List[str]]] = None,
+         negative_prompt_2: Optional[Union[str, List[str]]] = None,
+         num_inference_steps: int = 8,
+         timesteps: List[int] = None,
+         eta: Optional[float] = 0.0,
+         guidance_scale: float = 3.5,
+         device: Optional[int] = None,
+         num_images_per_prompt: Optional[int] = 1,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+         latents: Optional[torch.FloatTensor] = None,
+         prompt_embeds: Optional[torch.FloatTensor] = None,
+         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+         output_type: Optional[str] = "pil",
+         return_dict: bool = True,
+         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+         max_sequence_length: int = 300,
+         **kwargs,
+     ):
+         height = height or self.default_sample_size * self.vae_scale_factor
+         width = width or self.default_sample_size * self.vae_scale_factor
+
+         # 1. Check inputs
+         self.check_inputs(
+             prompt,
+             prompt_2,
+             height,
+             width,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             max_sequence_length=max_sequence_length,
+         )
+
+         self._guidance_scale = guidance_scale
+         self._joint_attention_kwargs = joint_attention_kwargs
+         self._interrupt = False
+
+         do_classifier_free_guidance = guidance_scale > 1.0
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         device = self._execution_device
+
+         lora_scale = (
+             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+         )
+         # 3. Encode prompts (T5 sequence embeddings, CLIP pooled embeddings, text ids, and negative counterparts)
+         (
+             prompt_embeds,
+             pooled_prompt_embeds,
+             text_ids,
+             negative_prompt_embeds,
+             negative_pooled_prompt_embeds,
+         ) = self.encode_prompt(
+             prompt=prompt,
+             prompt_2=prompt_2,
+             negative_prompt=negative_prompt,
+             negative_prompt_2=negative_prompt_2,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+             device=device,
+             num_images_per_prompt=num_images_per_prompt,
+             max_sequence_length=max_sequence_length,
+             lora_scale=lora_scale,
+         )
+
+         if self.do_classifier_free_guidance:
+             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
+
+         # 4. Prepare latent variables
+         num_channels_latents = self.transformer.config.in_channels // 4
+         latents, latent_image_ids = self.prepare_latents(
+             batch_size * num_images_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             prompt_embeds.dtype,
+             device,
+             generator,
+             latents,
+         )
+
+         # 5. Prepare timesteps
+         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+         image_seq_len = latents.shape[1]
+         mu = calculate_timestep_shift(image_seq_len)
+         timesteps, num_inference_steps = prepare_timesteps(
+             self.scheduler,
+             num_inference_steps,
+             device,
+             timesteps,
+             sigmas,
+             mu=mu,
+         )
+         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+         self._num_timesteps = len(timesteps)
+
+         # 6. Denoising loop
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                 if self.transformer.config.guidance_embeds:
+                     guidance = torch.tensor([guidance_scale], device=device)
+                     guidance = guidance.expand(latent_model_input.shape[0])
+                 else:
+                     guidance = None
+
+                 noise_pred = self.transformer(
+                     hidden_states=latent_model_input,
+                     timestep=timestep / 1000,
+                     guidance=guidance,
+                     pooled_projections=pooled_prompt_embeds,
+                     encoder_hidden_states=prompt_embeds,
+                     txt_ids=text_ids,
+                     img_ids=latent_image_ids,
+                     joint_attention_kwargs=self.joint_attention_kwargs,
+                     return_dict=False,
+                 )[0]
+
+                 if self.do_classifier_free_guidance:
+                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents_dtype = latents.dtype
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if latents.dtype != latents_dtype:
+                     if torch.backends.mps.is_available():
+                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
+                         latents = latents.to(latents_dtype)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+         # Final image: decode before releasing model hooks and clearing the CUDA cache
+         image = self._decode_latents_to_image(latents, height, width, output_type)
+         self.maybe_free_model_hooks()
+         torch.cuda.empty_cache()
+         return image
+
+     def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
+         """Decodes the given latents into an image."""
+         vae = vae or self.vae
+         latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
+         latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
+         image = vae.decode(latents, return_dict=False)[0]
+         return self.image_processor.postprocess(image, output_type=output_type)[0]
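
A minimal usage sketch for the uploaded pipeline, assuming a FLUX-style checkpoint whose components match the __init__ signature above; the repository id, prompts, dtype, and output path below are illustrative placeholders, not part of this commit:

    import torch
    from pipeline import FluxWithCFGPipeline

    # Hypothetical checkpoint id; substitute the FLUX-style repository this file accompanies.
    pipe = FluxWithCFGPipeline.from_pretrained(
        "path/to/flux-style-checkpoint", torch_dtype=torch.bfloat16
    )
    pipe.to("cuda")

    # guidance_scale > 1 enables the true classifier-free guidance branch added by this pipeline.
    image = pipe(
        prompt="a watercolor lighthouse at dawn",
        negative_prompt="blurry, low quality",
        guidance_scale=3.5,
        num_inference_steps=8,
        height=1024,
        width=1024,
    )
    image.save("output.png")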