jozee committed
Commit 46a9673 · verified · 1 Parent(s): 1f2d1fd

Create pipeline_fill_sd_xl.py

Files changed (1)
  1. pipeline_fill_sd_xl.py +559 -0
pipeline_fill_sd_xl.py ADDED
@@ -0,0 +1,559 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import cv2
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from controlnet_union import ControlNetModel_Union


# Fast approximate preview: project the four SDXL latent channels to RGB with a
# fixed linear map, clean the result up with OpenCV, and upscale it 8x so it
# roughly matches the output resolution.
def latents_to_rgb(latents):
    weights = ((60, -60, 25, -70), (60, -5, 15, -50), (60, 10, -5, -35))

    weights_tensor = torch.t(
        torch.tensor(weights, dtype=latents.dtype).to(latents.device)
    )
    biases_tensor = torch.tensor((150, 140, 130), dtype=latents.dtype).to(
        latents.device
    )
    rgb_tensor = torch.einsum(
        "...lxy,lr -> ...rxy", latents, weights_tensor
    ) + biases_tensor.unsqueeze(-1).unsqueeze(-1)
    image_array = rgb_tensor.clamp(0, 255)[0].byte().cpu().numpy()
    image_array = image_array.transpose(1, 2, 0)  # Change the order of dimensions

    denoised_image = cv2.fastNlMeansDenoisingColored(image_array, None, 10, 10, 7, 21)
    blurred_image = cv2.GaussianBlur(denoised_image, (5, 5), 0)
    final_image = PIL.Image.fromarray(blurred_image)

    width, height = final_image.size
    final_image = final_image.resize(
        (width * 8, height * 8), PIL.Image.Resampling.LANCZOS
    )

    return final_image


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    **kwargs,
):
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps

    return timesteps, num_inference_steps


class StableDiffusionXLFillPipeline(DiffusionPipeline, StableDiffusionMixin):
    model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
    _optional_components = [
        "tokenizer",
        "tokenizer_2",
        "text_encoder",
        "text_encoder_2",
    ]

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        text_encoder_2: CLIPTextModelWithProjection,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: UNet2DConditionModel,
        controlnet: ControlNetModel_Union,
        scheduler: KarrasDiffusionSchedulers,
        force_zeros_for_empty_prompt: bool = True,
    ):
        super().__init__()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            controlnet=controlnet,
            scheduler=scheduler,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
        )
        self.control_image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor,
            do_convert_rgb=True,
            do_normalize=False,
        )

        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt
        )

    def encode_prompt(
        self,
        prompt: str,
        device: Optional[torch.device] = None,
        do_classifier_free_guidance: bool = True,
    ):
        device = device or self._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt is not None:
            batch_size = len(prompt)

        # Define tokenizers and text encoders
        tokenizers = (
            [self.tokenizer, self.tokenizer_2]
            if self.tokenizer is not None
            else [self.tokenizer_2]
        )
        text_encoders = (
            [self.text_encoder, self.text_encoder_2]
            if self.text_encoder is not None
            else [self.text_encoder_2]
        )

        prompt_2 = prompt
        prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

        # Encode the prompt with both tokenizer/text-encoder pairs
        prompt_embeds_list = []
        prompts = [prompt, prompt_2]
        for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
            text_inputs = tokenizer(
                prompt,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )

            text_input_ids = text_inputs.input_ids

            prompt_embeds = text_encoder(
                text_input_ids.to(device), output_hidden_states=True
            )

            # We are only interested in the pooled output of the final text encoder
            pooled_prompt_embeds = prompt_embeds[0]
            prompt_embeds = prompt_embeds.hidden_states[-2]
            prompt_embeds_list.append(prompt_embeds)

        prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)

        # get unconditional embeddings for classifier free guidance
        zero_out_negative_prompt = True
        negative_prompt_embeds = None
        negative_pooled_prompt_embeds = None

        if do_classifier_free_guidance and zero_out_negative_prompt:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
        elif do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = ""
            negative_prompt_2 = negative_prompt

            # normalize str to list
            negative_prompt = (
                batch_size * [negative_prompt]
                if isinstance(negative_prompt, str)
                else negative_prompt
            )
            negative_prompt_2 = (
                batch_size * [negative_prompt_2]
                if isinstance(negative_prompt_2, str)
                else negative_prompt_2
            )

            uncond_tokens: List[str]
            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = [negative_prompt, negative_prompt_2]

            negative_prompt_embeds_list = []
            for negative_prompt, tokenizer, text_encoder in zip(
                uncond_tokens, tokenizers, text_encoders
            ):
                max_length = prompt_embeds.shape[1]
                uncond_input = tokenizer(
                    negative_prompt,
                    padding="max_length",
                    max_length=max_length,
                    truncation=True,
                    return_tensors="pt",
                )

                negative_prompt_embeds = text_encoder(
                    uncond_input.input_ids.to(device),
                    output_hidden_states=True,
                )
                # We are only interested in the pooled output of the final text encoder
                negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                negative_prompt_embeds_list.append(negative_prompt_embeds)

            negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, 1, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * 1, seq_len, -1)

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            if self.text_encoder_2 is not None:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=self.text_encoder_2.dtype, device=device
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.to(
                    dtype=self.unet.dtype, device=device
                )

            negative_prompt_embeds = negative_prompt_embeds.repeat(1, 1, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(
                batch_size * 1, seq_len, -1
            )

        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, 1).view(bs_embed * 1, -1)
        if do_classifier_free_guidance:
            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
                1, 1
            ).view(bs_embed * 1, -1)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

    def check_inputs(
        self,
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
        image,
        controlnet_conditioning_scale=1.0,
    ):
        if prompt_embeds is None:
            raise ValueError(
                "Provide `prompt_embeds`. Cannot leave `prompt_embeds` undefined."
            )

        if negative_prompt_embeds is None:
            raise ValueError(
                "Provide `negative_prompt_embeds`. Cannot leave `negative_prompt_embeds` undefined."
            )

        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )

        if prompt_embeds is not None and pooled_prompt_embeds is None:
            raise ValueError(
                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
            )

        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
            raise ValueError(
                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
            )

        # Check `image`
        is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
            self.controlnet, torch._dynamo.eval_frame.OptimizedModule
        )
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(image, PIL.Image.Image):
                raise TypeError(
                    f"image must be passed and has to be a PIL image, but is {type(image)}"
                )

        else:
            assert False

        # Check `controlnet_conditioning_scale`
        if (
            isinstance(self.controlnet, ControlNetModel_Union)
            or is_compiled
            and isinstance(self.controlnet._orig_mod, ControlNetModel_Union)
        ):
            if not isinstance(controlnet_conditioning_scale, float):
                raise TypeError(
                    "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
                )
        else:
            assert False

    def prepare_image(self, image, device, dtype, do_classifier_free_guidance=False):
        image = self.control_image_processor.preprocess(image).to(dtype=torch.float32)

        image_batch_size = image.shape[0]

        image = image.repeat_interleave(image_batch_size, dim=0)
        image = image.to(device=device, dtype=dtype)

        if do_classifier_free_guidance:
            image = torch.cat([image] * 2)

        return image

    def prepare_latents(
        self, batch_size, num_channels_latents, height, width, dtype, device
    ):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )

        latents = randn_tensor(shape, device=device, dtype=dtype)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        return self._guidance_scale

    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @torch.no_grad()
    def __call__(
        self,
        prompt_embeds: torch.Tensor,
        negative_prompt_embeds: torch.Tensor,
        pooled_prompt_embeds: torch.Tensor,
        negative_pooled_prompt_embeds: torch.Tensor,
        image: PipelineImageInput = None,
        num_inference_steps: int = 8,
        guidance_scale: float = 1.5,
        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
    ):
        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
            image,
            controlnet_conditioning_scale,
        )

        self._guidance_scale = guidance_scale

        # 2. Define call parameters
        batch_size = 1
        device = self._execution_device

        # 4. Prepare image
        if isinstance(self.controlnet, ControlNetModel_Union):
            image = self.prepare_image(
                image=image,
                device=device,
                dtype=self.controlnet.dtype,
                do_classifier_free_guidance=self.do_classifier_free_guidance,
            )
            height, width = image.shape[-2:]
        else:
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
        )

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds

        # SDXL micro-conditioning ids: (original_size, crops_coords_top_left, target_size);
        # both sizes are the control image size and the crop offset is (0, 0).
        add_time_ids = negative_add_time_ids = torch.tensor(
            image.shape[-2:] + torch.Size([0, 0]) + image.shape[-2:]
        ).unsqueeze(0)

        if self.do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0
            )
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size, 1)

        controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]
        union_control_type = (
            torch.Tensor([0, 0, 0, 0, 0, 0, 1, 0])
            .to(device, dtype=prompt_embeds.dtype)
            .repeat(batch_size * 2, 1)
        )

        added_cond_kwargs = {
            "text_embeds": add_text_embeds,
            "time_ids": add_time_ids,
            "control_type": union_control_type,
        }

        controlnet_prompt_embeds = prompt_embeds
        controlnet_added_cond_kwargs = added_cond_kwargs

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # controlnet(s) inference
                control_model_input = latent_model_input

                down_block_res_samples, mid_block_res_sample = self.controlnet(
                    control_model_input,
                    t,
                    encoder_hidden_states=controlnet_prompt_embeds,
                    controlnet_cond_list=controlnet_image_list,
                    conditioning_scale=controlnet_conditioning_scale,
                    guess_mode=False,
                    added_cond_kwargs=controlnet_added_cond_kwargs,
                    return_dict=False,
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=None,
                    cross_attention_kwargs={},
                    down_block_additional_residuals=down_block_res_samples,
                    mid_block_additional_residual=mid_block_res_sample,
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]

                if i == 2:
                    # After the first few steps, keep only the conditional batch and set
                    # the guidance scale to 0 so classifier-free guidance is skipped for
                    # the remaining steps.
                    prompt_embeds = prompt_embeds[-1:]
                    add_text_embeds = add_text_embeds[-1:]
                    add_time_ids = add_time_ids[-1:]
                    union_control_type = union_control_type[-1:]

                    added_cond_kwargs = {
                        "text_embeds": add_text_embeds,
                        "time_ids": add_time_ids,
                        "control_type": union_control_type,
                    }

                    controlnet_prompt_embeds = prompt_embeds
                    controlnet_added_cond_kwargs = added_cond_kwargs

                    image = image[-1:]
                    controlnet_image_list = [0, 0, 0, 0, 0, 0, image, 0]

                    self._guidance_scale = 0.0

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    yield latents_to_rgb(latents)

        latents = latents / self.vae.config.scaling_factor
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image)[0]

        # Offload all models
        self.maybe_free_model_hooks()

        yield image
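
For reference, here is a minimal, hypothetical usage sketch of the pipeline defined above; it is not part of the commit. The two hub IDs, the `ControlNetModel_Union.from_pretrained` call, and the input file path are assumptions, noted again in the comments. `__call__` takes precomputed prompt embeddings and is a generator: it yields low-resolution RGB previews during denoising and the full-resolution decoded PIL image as its final value.

# Hypothetical usage sketch -- assumptions: both hub IDs below are placeholders,
# ControlNetModel_Union follows the standard diffusers from_pretrained API, and
# "masked_input.png" is a PIL image prepared as the fill/outpaint control condition.
import PIL.Image
import torch

from controlnet_union import ControlNetModel_Union
from pipeline_fill_sd_xl import StableDiffusionXLFillPipeline

controlnet = ControlNetModel_Union.from_pretrained(
    "xinsir/controlnet-union-sdxl-1.0", torch_dtype=torch.float16
)
pipe = StableDiffusionXLFillPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# __call__ expects precomputed embeddings, so encode the prompt first.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt("a sunny park", device="cuda", do_classifier_free_guidance=True)

fill_image = PIL.Image.open("masked_input.png").convert("RGB").resize((1024, 1024))

result = None
for output in pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    image=fill_image,
    num_inference_steps=8,
    guidance_scale=1.5,
):
    result = output  # intermediate RGB previews first, final decoded PIL image last

result.save("filled.png")

Because the call is a generator, a UI can show each preview as it arrives, while the single (and comparatively expensive) VAE decode runs only once at the end.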