AlekseyCalvin commited on
Commit
9a7ca1b
·
verified ·
1 Parent(s): 397b504

Upload pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +757 -0
pipeline.py ADDED
@@ -0,0 +1,757 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import html
4
+ import inspect
5
+ import re
6
+ import urllib.parse as ul
7
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast, CLIPTextModelWithProjection
8
+ from diffusers import FlowMatchEulerDiscreteScheduler, AutoPipelineForImage2Image, FluxPipeline, FluxTransformer2DModel
9
+ from diffusers import StableDiffusion3Pipeline, AutoencoderKL, DiffusionPipeline, ImagePipelineOutput
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin, SD3LoraLoaderMixin
12
+ from diffusers.utils import (
13
+ USE_PEFT_BACKEND,
14
+ is_torch_xla_available,
15
+ logging,
16
+ BACKENDS_MAPPING,
17
+ is_bs4_available,
18
+ is_ftfy_available,
19
+ deprecate,
20
+ replace_example_docstring,
21
+ scale_lora_layers,
22
+ unscale_lora_layers,
23
+ )
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
26
+ from typing import Any, Callable, Dict, List, Optional, Union
27
+ from PIL import Image
28
+ from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps, FluxTransformer2DModel
29
+
30
+ from diffusers.utils import is_torch_xla_available
31
+
32
+ if is_bs4_available():
33
+ from bs4 import BeautifulSoup
34
+
35
+ if is_ftfy_available():
36
+ import ftfy
37
+
38
+ if is_torch_xla_available():
39
+ import torch_xla.core.xla_model as xm
40
+
41
+ XLA_AVAILABLE = True
42
+ else:
43
+ XLA_AVAILABLE = False
44
+
45
+
46
+ # Constants for shift calculation
47
+ BASE_SEQ_LEN = 256
48
+ MAX_SEQ_LEN = 4096
49
+ BASE_SHIFT = 0.5
50
+ MAX_SHIFT = 1.2
51
+
52
+ # Helper functions
53
+ def calculate_timestep_shift(image_seq_len: int) -> float:
54
+ """Calculates the timestep shift (mu) based on the image sequence length."""
55
+ m = (MAX_SHIFT - BASE_SHIFT) / (MAX_SEQ_LEN - BASE_SEQ_LEN)
56
+ b = BASE_SHIFT - m * BASE_SEQ_LEN
57
+ mu = image_seq_len * m + b
58
+ return mu
59
+
60
+ def prepare_timesteps(
61
+ scheduler: FlowMatchEulerDiscreteScheduler,
62
+ num_inference_steps: Optional[int] = None,
63
+ device: Optional[Union[str, torch.device]] = None,
64
+ timesteps: Optional[List[int]] = None,
65
+ sigmas: Optional[List[float]] = None,
66
+ mu: Optional[float] = None,
67
+ ) -> (torch.Tensor, int):
68
+ """Prepares the timesteps for the diffusion process."""
69
+ if timesteps is not None and sigmas is not None:
70
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
71
+
72
+ if timesteps is not None:
73
+ scheduler.set_timesteps(timesteps=timesteps, device=device)
74
+ elif sigmas is not None:
75
+ scheduler.set_timesteps(sigmas=sigmas, device=device)
76
+ else:
77
+ scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
78
+
79
+ timesteps = scheduler.timesteps
80
+ num_inference_steps = len(timesteps)
81
+ return timesteps, num_inference_steps
82
+
83
+ # FLUX pipeline function
84
+ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
85
+ def __init__(
86
+ self,
87
+ scheduler: FlowMatchEulerDiscreteScheduler,
88
+ vae: AutoencoderKL,
89
+ text_encoder: CLIPTextModel,
90
+ tokenizer: CLIPTokenizer,
91
+ text_encoder_2: T5EncoderModel,
92
+ tokenizer_2: T5TokenizerFast,
93
+ transformer: FluxTransformer2DModel,
94
+ ):
95
+ super().__init__()
96
+
97
+ self.register_modules(
98
+ vae=vae,
99
+ text_encoder=text_encoder,
100
+ text_encoder_2=text_encoder_2,
101
+ tokenizer=tokenizer,
102
+ tokenizer_2=tokenizer_2,
103
+ transformer=transformer,
104
+ scheduler=scheduler,
105
+ )
106
+ self.vae_scale_factor = (
107
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
108
+ )
109
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
110
+ self.tokenizer_max_length = (
111
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
112
+ )
113
+ self.default_sample_size = 64
114
+
115
+ def _get_t5_prompt_embeds(
116
+ self,
117
+ prompt: Union[str, List[str]] = None,
118
+ num_images_per_prompt: int = 1,
119
+ max_sequence_length: int = 512,
120
+ device: Optional[torch.device] = None,
121
+ dtype: Optional[torch.dtype] = None,
122
+ ):
123
+ device = device or self._execution_device
124
+ dtype = dtype or self.text_encoder.dtype
125
+
126
+ prompt = [prompt] if isinstance(prompt, str) else prompt
127
+ batch_size = len(prompt)
128
+
129
+ text_inputs = self.tokenizer_2(
130
+ prompt,
131
+ padding="max_length",
132
+ max_length=max_sequence_length,
133
+ truncation=True,
134
+ return_length=True,
135
+ return_overflowing_tokens=True,
136
+ return_tensors="pt",
137
+ )
138
+ text_input_ids = text_inputs.input_ids
139
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
140
+
141
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
142
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
143
+ logger.warning(
144
+ "The following part of your input was truncated because `max_sequence_length` is set to "
145
+ f" {max_sequence_length} tokens: {removed_text}"
146
+ )
147
+
148
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
149
+
150
+ dtype = self.text_encoder_2.dtype
151
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
152
+
153
+ _, seq_len, _ = prompt_embeds.shape
154
+
155
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
156
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
157
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
158
+
159
+ return prompt_embeds
160
+
161
+ def _get_clip_prompt_embeds(
162
+ self,
163
+ prompt: Union[str, List[str]],
164
+ num_images_per_prompt: int = 1,
165
+ device: Optional[torch.device] = None,
166
+ clip_skip: Optional[int] = None,
167
+ ):
168
+ device = device or self._execution_device
169
+
170
+ prompt = [prompt] if isinstance(prompt, str) else prompt
171
+ batch_size = len(prompt)
172
+
173
+ text_inputs = tokenizer(
174
+ prompt,
175
+ padding="max_length",
176
+ max_length=self.tokenizer_max_length,
177
+ truncation=True,
178
+ return_overflowing_tokens=False,
179
+ return_length=False,
180
+ return_tensors="pt",
181
+ )
182
+
183
+ text_input_ids = text_inputs.input_ids
184
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
185
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
186
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
187
+ logger.warning(
188
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
189
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
190
+ )
191
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
192
+
193
+ if clip_skip is None:
194
+ prompt_embeds = prompt_embeds.hidden_states[-2]
195
+ else:
196
+ prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
197
+
198
+ # Use pooled output of CLIPTextModel
199
+ prompt_embeds = prompt_embeds.pooler_output
200
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
201
+
202
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
203
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
204
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
205
+
206
+ return prompt_embeds
207
+
208
+ def encode_prompt(
209
+ self,
210
+ prompt: Union[str, List[str]],
211
+ prompt_2: Union[str, List[str]],
212
+ device: Optional[torch.device] = None,
213
+ num_images_per_prompt: int = 1,
214
+ do_classifier_free_guidance: bool = True,
215
+ negative_prompt: Optional[Union[str, List[str]]] = None,
216
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
217
+ prompt_embeds: Optional[torch.FloatTensor] = None,
218
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
219
+ negative_prompt_2_embed: Optional[torch.Tensor] = None,
220
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
221
+ negative_pooled_prompt_2_embed: Optional[torch.FloatTensor] = None,
222
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
223
+ clip_skip: Optional[int] = None,
224
+ max_sequence_length: int = 512,
225
+ lora_scale: Optional[float] = None,
226
+ ):
227
+ r"""
228
+
229
+ Args:
230
+ prompt (`str` or `List[str]`, *optional*):
231
+ prompt_2 (`str` or `List[str]`, *optional*):
232
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
233
+ used in all text-encoders
234
+ device: (`torch.device`):
235
+ torch device
236
+ num_images_per_prompt (`int`):
237
+ number of images that should be generated per prompt
238
+ prompt_embeds (`torch.FloatTensor`, *optional*):
239
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
240
+ provided, text embeddings will be generated from `prompt` input argument.
241
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
242
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
243
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
244
+ lora_scale (`float`, *optional*):
245
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
246
+ """
247
+ device = device or self._execution_device
248
+
249
+ if device is None:
250
+ device = self._execution_device
251
+
252
+ # set lora scale so that monkey patched LoRA
253
+ # function of text encoder can correctly access it
254
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
255
+ self._lora_scale = lora_scale
256
+
257
+ # dynamically adjust the LoRA scale
258
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
259
+ scale_lora_layers(self.text_encoder, lora_scale)
260
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
261
+ scale_lora_layers(self.text_encoder_2, lora_scale)
262
+
263
+ prompt = [prompt] if isinstance(prompt, str) else prompt
264
+
265
+ if prompt is not None and isinstance(prompt, str):
266
+ batch_size = 1
267
+ elif prompt is not None and isinstance(prompt, list):
268
+ batch_size = len(prompt)
269
+ else:
270
+ batch_size = prompt_embeds.shape[0]
271
+
272
+ if prompt_embeds is None:
273
+ prompt_2 = prompt_2 or prompt
274
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
275
+
276
+ # We only use the pooled prompt output from the CLIPTextModel
277
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
278
+ prompt=prompt,
279
+ device=device,
280
+ num_images_per_prompt=num_images_per_prompt,
281
+ clip_skip=clip_skip,
282
+ )
283
+ prompt_embeds = self._get_t5_prompt_embeds(
284
+ prompt=prompt_2,
285
+ num_images_per_prompt=num_images_per_prompt,
286
+ max_sequence_length=max_sequence_length,
287
+ device=device,
288
+ )
289
+
290
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
291
+ negative_prompt = negative_prompt or ""
292
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
293
+
294
+ # normalize str to list
295
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
296
+ negative_prompt_2 = (
297
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
298
+ )
299
+
300
+ if prompt is not None and type(prompt) is not type(negative_prompt):
301
+ raise TypeError(
302
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
303
+ f" {type(prompt)}."
304
+ )
305
+ elif batch_size != len(negative_prompt):
306
+ raise ValueError(
307
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
308
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
309
+ " the batch size of `prompt`."
310
+ )
311
+
312
+ negative_prompt_embed, negative_pooled_prompt_embed = self._get_clip_prompt_embeds(
313
+ negative_prompt,
314
+ device=device,
315
+ num_images_per_prompt=num_images_per_prompt,
316
+ clip_skip=None,
317
+ )
318
+ negative_clip_prompt_embeds = torch.cat([negative_prompt_embed, negative_prompt_2_embed], dim=-1)
319
+
320
+ t5_negative_prompt_embed = self._get_t5_prompt_embeds(
321
+ prompt=negative_prompt_2,
322
+ num_images_per_prompt=num_images_per_prompt,
323
+ max_sequence_length=max_sequence_length,
324
+ device=device,
325
+ )
326
+
327
+ negative_clip_prompt_embeds = torch.nn.functional.pad(
328
+ negative_clip_prompt_embeds,
329
+ (0, t5_negative_prompt_embed.shape[-1] - negative_clip_prompt_embeds.shape[-1]),
330
+ )
331
+
332
+ negative_prompt_embeds = torch.cat([negative_clip_prompt_embeds, t5_negative_prompt_embed], dim=-2)
333
+ negative_pooled_prompt_embeds = torch.cat(
334
+ [negative_pooled_prompt_embed, negative_pooled_prompt_2_embed], dim=-1
335
+ )
336
+
337
+ if self.text_encoder is not None:
338
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
339
+ # Retrieve the original scale by scaling back the LoRA layers
340
+ unscale_lora_layers(self.text_encoder, lora_scale)
341
+
342
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
343
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
344
+
345
+ return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
346
+
347
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
348
+ def prepare_extra_step_kwargs(self, generator, eta):
349
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
350
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
351
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
352
+ # and should be between [0, 1]
353
+
354
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
355
+ extra_step_kwargs = {}
356
+ if accepts_eta:
357
+ extra_step_kwargs["eta"] = eta
358
+
359
+ # check if the scheduler accepts generator
360
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
361
+ if accepts_generator:
362
+ extra_step_kwargs["generator"] = generator
363
+ return extra_step_kwargs
364
+
365
+ def check_inputs(
366
+ self,
367
+ prompt,
368
+ prompt_2,
369
+ height,
370
+ width,
371
+ negative_prompt=None,
372
+ negative_prompt_2=None,
373
+ prompt_embeds=None,
374
+ negative_prompt_embeds=None,
375
+ pooled_prompt_embeds=None,
376
+ negative_pooled_prompt_embeds=None,
377
+ callback_on_step_end_tensor_inputs=None,
378
+ max_sequence_length=None,
379
+ ):
380
+ if height % 8 != 0 or width % 8 != 0:
381
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
382
+
383
+ if callback_on_step_end_tensor_inputs is not None and not all(
384
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
385
+ ):
386
+ raise ValueError(
387
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
388
+ )
389
+
390
+ if prompt is not None and prompt_embeds is not None:
391
+ raise ValueError(
392
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
393
+ " only forward one of the two."
394
+ )
395
+ elif prompt_2 is not None and prompt_embeds is not None:
396
+ raise ValueError(
397
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
398
+ " only forward one of the two."
399
+ )
400
+ elif prompt is None and prompt_embeds is None:
401
+ raise ValueError(
402
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
403
+ )
404
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
405
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
406
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
407
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
408
+
409
+ if negative_prompt is not None and negative_prompt_embeds is not None:
410
+ raise ValueError(
411
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
412
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
413
+ )
414
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
415
+ raise ValueError(
416
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
417
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
418
+ )
419
+
420
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
421
+ raise ValueError(
422
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
423
+ )
424
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
425
+ raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
426
+
427
+ if max_sequence_length is not None and max_sequence_length > 512:
428
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
429
+
430
+ return prompt_embeds, negative_prompt_embeds
431
+
432
+ @staticmethod
433
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
434
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
435
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
436
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
437
+
438
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
439
+
440
+ latent_image_ids = latent_image_ids.reshape(
441
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
442
+ )
443
+
444
+ return latent_image_ids.to(device=device, dtype=dtype)
445
+
446
+ @staticmethod
447
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
448
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
449
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
450
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
451
+
452
+ return latents
453
+
454
+ @staticmethod
455
+ def _unpack_latents(latents, height, width, vae_scale_factor):
456
+ batch_size, num_patches, channels = latents.shape
457
+
458
+ height = height // vae_scale_factor
459
+ width = width // vae_scale_factor
460
+
461
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
462
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
463
+
464
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
465
+
466
+ return latents
467
+
468
+ def enable_vae_slicing(self):
469
+ r"""
470
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
471
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
472
+ """
473
+ self.vae.enable_slicing()
474
+
475
+ def disable_vae_slicing(self):
476
+ r"""
477
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
478
+ computing decoding in one step.
479
+ """
480
+ self.vae.disable_slicing()
481
+
482
+ def enable_vae_tiling(self):
483
+ r"""
484
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
485
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
486
+ processing larger images.
487
+ """
488
+ self.vae.enable_tiling()
489
+
490
+ def disable_vae_tiling(self):
491
+ r"""
492
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
493
+ computing decoding in one step.
494
+ """
495
+ self.vae.disable_tiling()
496
+
497
+ def prepare_latents(
498
+ self,
499
+ batch_size,
500
+ num_channels_latents,
501
+ height,
502
+ width,
503
+ dtype,
504
+ device,
505
+ generator,
506
+ latents=None,
507
+ ):
508
+ height = 2 * (int(height) // self.vae_scale_factor)
509
+ width = 2 * (int(width) // self.vae_scale_factor)
510
+
511
+ shape = (batch_size, num_channels_latents, height, width)
512
+
513
+ if latents is not None:
514
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
515
+ return latents.to(device=device, dtype=dtype), latent_image_ids
516
+
517
+ if isinstance(generator, list) and len(generator) != batch_size:
518
+ raise ValueError(
519
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
520
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
521
+ )
522
+
523
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
524
+ latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
525
+
526
+ latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
527
+
528
+ return latents, latent_image_ids
529
+
530
+ @property
531
+ def guidance_scale(self):
532
+ return self._guidance_scale
533
+
534
+ @property
535
+ def do_classifier_free_guidance(self):
536
+ return self._guidance_scale > 1
537
+
538
+ @property
539
+ def joint_attention_kwargs(self):
540
+ return self._joint_attention_kwargs
541
+
542
+ @property
543
+ def num_timesteps(self):
544
+ return self._num_timesteps
545
+
546
+ @property
547
+ def interrupt(self):
548
+ return self._interrupt
549
+
550
+ @torch.no_grad()
551
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
552
+ def __call__(
553
+ self,
554
+ prompt: Union[str, List[str]] = None,
555
+ prompt_2: Optional[Union[str, List[str]]] = None,
556
+ height: Optional[int] = None,
557
+ width: Optional[int] = None,
558
+ negative_prompt: Optional[Union[str, List[str]]] = None,
559
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
560
+ num_inference_steps: int = 8,
561
+ timesteps: List[int] = None,
562
+ eta: float = 0.0,
563
+ guidance_scale: float = 3.5,
564
+ lora_scale: Optional[torch.FloatTensor] = None,
565
+ num_images_per_prompt: Optional[int] = 1,
566
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
567
+ latents: Optional[torch.FloatTensor] = None,
568
+ prompt_embeds: Optional[torch.FloatTensor] = None,
569
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
570
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
571
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
572
+ output_type: Optional[str] = "pil",
573
+ return_dict: bool = True,
574
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
575
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
576
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
577
+ clip_skip: Optional[int] = None,
578
+ max_sequence_length: int = 300,
579
+ ):
580
+ height = height or self.default_sample_size * self.vae_scale_factor
581
+ width = width or self.default_sample_size * self.vae_scale_factor
582
+
583
+ # 1. Check inputs
584
+ self.check_inputs(
585
+ prompt,
586
+ prompt_2,
587
+ height,
588
+ width,
589
+ negative_prompt=negative_prompt,
590
+ negative_prompt_2=negative_prompt_2,
591
+ prompt_embeds=prompt_embeds,
592
+ negative_prompt_embeds=negative_prompt_embeds,
593
+ pooled_prompt_embeds=pooled_prompt_embeds,
594
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
595
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
596
+ max_sequence_length=max_sequence_length,
597
+ )
598
+
599
+ self._guidance_scale = guidance_scale
600
+ self._clip_skip = clip_skip
601
+ self._joint_attention_kwargs = joint_attention_kwargs
602
+ self._interrupt = False
603
+
604
+ # 2. Define call parameters
605
+ if prompt is not None and isinstance(prompt, str):
606
+ batch_size = 1
607
+ elif prompt is not None and isinstance(prompt, list):
608
+ batch_size = len(prompt)
609
+ else:
610
+ batch_size = prompt_embeds.shape[0]
611
+
612
+ device = self._execution_device
613
+
614
+ do_classifier_free_guidance = guidance_scale > 1.0
615
+
616
+ lora_scale = (
617
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
618
+ )
619
+
620
+ # 3. Encode prompt
621
+ lora_scale = (
622
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
623
+ )
624
+ (
625
+ prompt_embeds,
626
+ negative_prompt_embeds,
627
+ pooled_prompt_embeds,
628
+ negative_pooled_prompt_embeds,
629
+ text_ids,
630
+ ) = self.encode_prompt(
631
+ prompt=prompt,
632
+ prompt_2=prompt_2,
633
+ negative_prompt=negative_prompt,
634
+ negative_prompt_2=negative_prompt_2,
635
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
636
+ prompt_embeds=prompt_embeds,
637
+ negative_prompt_embeds=negative_prompt_embeds,
638
+ pooled_prompt_embeds=pooled_prompt_embeds,
639
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
640
+ device=device,
641
+ clip_skip=self.clip_skip,
642
+ num_images_per_prompt=num_images_per_prompt,
643
+ max_sequence_length=max_sequence_length,
644
+ lora_scale=lora_scale,
645
+ )
646
+
647
+ if self.do_classifier_free_guidance:
648
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
649
+ pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
650
+
651
+ # 4. Prepare latent variables
652
+ num_channels_latents = self.transformer.config.in_channels // 4
653
+ latents, latent_image_ids = self.prepare_latents(
654
+ batch_size * num_images_per_prompt,
655
+ num_channels_latents,
656
+ height,
657
+ width,
658
+ prompt_embeds.dtype,
659
+ negative_prompt_embeds.dtype,
660
+ device,
661
+ generator,
662
+ latents,
663
+ )
664
+
665
+ # 5. Prepare timesteps
666
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
667
+ image_seq_len = latents.shape[1]
668
+ mu = calculate_timestep_shift(image_seq_len)
669
+ timesteps, num_inference_steps = prepare_timesteps(
670
+ self.scheduler,
671
+ num_inference_steps,
672
+ device,
673
+ timesteps,
674
+ sigmas,
675
+ mu=mu,
676
+ )
677
+ self._num_timesteps = len(timesteps)
678
+
679
+ # Handle guidance
680
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float16).expand(latents.shape[0]) if self.transformer.config.guidance_embeds else None
681
+
682
+ # 6. Denoising loop
683
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
684
+ for i, t in enumerate(timesteps):
685
+ if self.interrupt:
686
+ continue
687
+
688
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
689
+
690
+ timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
691
+
692
+ noise_pred = self.transformer(
693
+ hidden_states=latent_model_input,
694
+ timestep=timestep / 1000,
695
+ guidance=guidance,
696
+ pooled_projections=pooled_prompt_embeds,
697
+ encoder_hidden_states=prompt_embeds,
698
+ txt_ids=text_ids,
699
+ img_ids=latent_image_ids,
700
+ joint_attention_kwargs=self.joint_attention_kwargs,
701
+ return_dict=False,
702
+ )[0]
703
+
704
+ noise_pred_uncond = self.transformer(
705
+ hidden_states=latents,
706
+ timestep=timestep / 1000,
707
+ guidance=guidance,
708
+ pooled_projections=negative_pooled_prompt_embeds,
709
+ encoder_hidden_states=negative_prompt_embeds,
710
+ img_ids=latent_image_ids,
711
+ joint_attention_kwargs=self.joint_attention_kwargs,
712
+ return_dict=False,
713
+ )[0]
714
+
715
+ if self.do_classifier_free_guidance:
716
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
717
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
718
+
719
+ latents_dtype = latents.dtype
720
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
721
+ # Yield intermediate result
722
+ torch.cuda.empty_cache()
723
+
724
+ if latents.dtype != latents_dtype:
725
+ if torch.backends.mps.is_available():
726
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
727
+ latents = latents.to(latents_dtype)
728
+
729
+ if callback_on_step_end is not None:
730
+ callback_kwargs = {}
731
+ for k in callback_on_step_end_tensor_inputs:
732
+ callback_kwargs[k] = locals()[k]
733
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
734
+
735
+ latents = callback_outputs.pop("latents", latents)
736
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
737
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
738
+ negative_pooled_prompt_embeds = callback_outputs.pop(
739
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
740
+ )
741
+
742
+ # call the callback, if provided
743
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
744
+ progress_bar.update()
745
+
746
+ # Final image
747
+ return self._decode_latents_to_image(latents, height, width, output_type)
748
+ self.maybe_free_model_hooks()
749
+ torch.cuda.empty_cache()
750
+
751
+ def _decode_latents_to_image(self, latents, height, width, output_type, vae=None):
752
+ """Decodes the given latents into an image."""
753
+ vae = vae or self.vae
754
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
755
+ latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
756
+ image = vae.decode(latents, return_dict=False)[0]
757
+ return self.image_processor.postprocess(image, output_type=output_type)[0]