Update diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py
diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py
CHANGED
@@ -12,13 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
-import numpy as np
-import PIL.Image
 import torch
-
+import PIL.Image
+import numpy as np
 
-
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
 
 EXAMPLE_DOC_STRING = """
     Examples:
@@ -98,15 +96,11 @@ def prepare_mask_and_masked_image(image, mask):
     """
     if isinstance(image, torch.Tensor):
         if not isinstance(mask, torch.Tensor):
-            raise TypeError(
-                f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
-            )
+            raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
 
         # Batch single image
         if image.ndim == 3:
-            assert (
-                image.shape[0] == 3
-            ), "Image outside a batch should be of shape (3, H, W)"
+            assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
             image = image.unsqueeze(0)
 
         # Batch and add channel dim for single mask
@@ -123,15 +117,9 @@ def prepare_mask_and_masked_image(image, mask):
             else:
                 mask = mask.unsqueeze(1)
 
-        assert (
-            image.ndim == 4 and mask.ndim == 4
-        ), "Image and Mask must have 4 dimensions"
-        assert (
-            image.shape[-2:] == mask.shape[-2:]
-        ), "Image and Mask must have the same spatial dimensions"
-        assert (
-            image.shape[0] == mask.shape[0]
-        ), "Image and Mask must have the same batch size"
+        assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+        assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+        assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
 
         # Check image is in [-1, 1]
         if image.min() < -1 or image.max() > 1:
@@ -148,9 +136,7 @@ def prepare_mask_and_masked_image(image, mask):
         # Image as float32
         image = image.to(dtype=torch.float32)
     elif isinstance(mask, torch.Tensor):
-        raise TypeError(
-            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
-        )
+        raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
     else:
         # preprocess image
         if isinstance(image, (PIL.Image.Image, np.ndarray)):
@@ -170,9 +156,7 @@ def prepare_mask_and_masked_image(image, mask):
             mask = [mask]
 
         if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-            mask = np.concatenate(
-                [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
-            )
+            mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
             mask = mask.astype(np.float32) / 255.0
         elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
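For orientation, the tensor contract these checks enforce can be exercised directly with throwaway inputs; a minimal sketch (the 512x512 size and the solid test images are arbitrary, and prepare_mask_and_masked_image is the helper defined in this file):

import PIL.Image

init_image = PIL.Image.new("RGB", (512, 512), color=(128, 128, 128))
mask_image = PIL.Image.new("L", (512, 512), color=255)  # white = region to repaint

mask, masked_image = prepare_mask_and_masked_image(init_image, mask_image)

# Per the asserts above, both outputs are 4-D with matching batch and spatial dims:
# mask:         torch.Size([1, 1, 512, 512])  -- single channel, scaled by / 255.0
# masked_image: torch.Size([1, 3, 512, 512])  -- RGB image tensor
print(mask.shape, masked_image.shape)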
@@ -185,10 +169,7 @@ def prepare_mask_and_masked_image(image, mask):
 
     return mask, masked_image
 
-
-class StableDiffusionControlNetInpaintPipeline(
-    StableDiffusionControlNetPipeline
-):
+class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
     r"""
     Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
 
@@ -217,28 +198,15 @@ class StableDiffusionControlNetInpaintPipeline(
         feature_extractor ([`CLIPFeatureExtractor`]):
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """
-
+
     def prepare_mask_latents(
-        self,
-        mask,
-        masked_image,
-        batch_size,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        do_classifier_free_guidance,
+        self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
     ):
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
         mask = torch.nn.functional.interpolate(
-            mask,
-            size=(
-                height // self.vae_scale_factor,
-                width // self.vae_scale_factor,
-            ),
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
         )
         mask = mask.to(device=device, dtype=dtype)
 
@@ -247,19 +215,13 @@ class StableDiffusionControlNetInpaintPipeline(
         # encode the mask image into latents space so we can concatenate it to the latents
         if isinstance(generator, list):
             masked_image_latents = [
-                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
-                    generator=generator[i]
-                )
+                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                 for i in range(batch_size)
             ]
             masked_image_latents = torch.cat(masked_image_latents, dim=0)
         else:
-            masked_image_latents = self.vae.encode(
-                masked_image
-            ).latent_dist.sample(generator=generator)
-            masked_image_latents = (
-                self.vae.config.scaling_factor * masked_image_latents
-            )
+            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+            masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
 
         # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
         if mask.shape[0] < batch_size:
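As a quick size check on what prepare_mask_latents produces, with purely illustrative numbers (vae_scale_factor is 8 for the standard Stable Diffusion VAE):

height, width = 512, 512
vae_scale_factor = 8  # 2 ** (len(vae.config.block_out_channels) - 1) for the SD VAE

# The mask is downsampled with F.interpolate onto the latent grid:
mask_latent_hw = (height // vae_scale_factor, width // vae_scale_factor)  # (64, 64)

# The masked image goes through vae.encode(...).latent_dist.sample(), giving 4-channel
# latents at the same 64x64 resolution, scaled by vae.config.scaling_factor
# (0.18215 for SD 1.x checkpoints).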
@@ -277,35 +239,24 @@ class StableDiffusionControlNetInpaintPipeline(
                     f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                     " Make sure the number of images that you pass is divisible by the total requested batch size."
                 )
-            masked_image_latents = masked_image_latents.repeat(
-                batch_size // masked_image_latents.shape[0], 1, 1, 1
-            )
+            masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
 
         mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         masked_image_latents = (
-            torch.cat([masked_image_latents] * 2)
-            if do_classifier_free_guidance
-            else masked_image_latents
+            torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
         )
 
         # aligning device to prevent device errors when concating it with the latent model input
-        masked_image_latents = masked_image_latents.to(
-            device=device, dtype=dtype
-        )
+        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
         return mask, masked_image_latents
-
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        prompt: Union[str, List[str]] = None,
+        prompt: Union[str, List[str]] = None,
         image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        control_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-        ] = None,
+        control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
         mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
@@ -314,17 +265,13 @@ class StableDiffusionControlNetInpaintPipeline(
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
-        generator: Optional[
-            Union[torch.Generator, List[torch.Generator]]
-        ] = None,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[
-            Callable[[int, int, torch.FloatTensor], None]
-        ] = None,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
@@ -346,7 +293,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
-                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
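A minimal usage sketch for this signature. The checkpoints, file names, and prompt are illustrative only; note that because the mask and masked-image latents are concatenated onto the latents later in __call__, the base model needs an inpainting UNet with 9 input channels (e.g. runwayml/stable-diffusion-inpainting):

import torch
import PIL.Image
from diffusers import ControlNetModel

controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

init_image = PIL.Image.open("room.png").convert("RGB").resize((512, 512))
mask_image = PIL.Image.open("room_mask.png").convert("L").resize((512, 512))        # white = repaint
control_image = PIL.Image.open("room_canny.png").convert("RGB").resize((512, 512))  # Canny edge map

result = pipe(
    prompt="a modern armchair, photorealistic",
    image=init_image,
    control_image=control_image,
    mask_image=mask_image,
).images[0]
result.save("inpainted.png")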
@@ -415,14 +362,7 @@ class StableDiffusionControlNetInpaintPipeline(
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
-            prompt,
-            control_image,
-            height,
-            width,
-            callback_steps,
-            negative_prompt,
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
         )
 
         # 2. Define call parameters
@@ -452,15 +392,15 @@ class StableDiffusionControlNetInpaintPipeline(
 
         # 4. Prepare image
         control_image = self.prepare_image(
-            control_image,
-            width,
-            height,
-            batch_size * num_images_per_prompt,
-            num_images_per_prompt,
-            device,
-            self.controlnet.dtype,
-        )
-
+            control_image,
+            width,
+            height,
+            batch_size * num_images_per_prompt,
+            num_images_per_prompt,
+            device,
+            self.controlnet.dtype,
+        )
+
         if do_classifier_free_guidance:
             control_image = torch.cat([control_image] * 2)
 
@@ -469,7 +409,7 @@ class StableDiffusionControlNetInpaintPipeline(
         timesteps = self.scheduler.timesteps
 
         # 6. Prepare latent variables
-        num_channels_latents = self.controlnet.in_channels
+        num_channels_latents = self.controlnet.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
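Besides the reformatting, this hunk changes how the latent channel count is obtained: it is now read from the ControlNet's config. Recent diffusers releases deprecate reading config values as direct model attributes, so the config lookup is the forward-compatible spelling; a small sketch (the value is 4 for standard SD ControlNets):

num_channels_latents = controlnet.config.in_channels  # preferred: read from the config object
# num_channels_latents = controlnet.in_channels       # older spelling, deprecated in newer diffusers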
@@ -480,7 +420,7 @@ class StableDiffusionControlNetInpaintPipeline(
             generator,
             latents,
         )
-
+
         # EXTRA: prepare mask latents
         mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
         mask, masked_image_latents = self.prepare_mask_latents(
@@ -499,20 +439,12 @@ class StableDiffusionControlNetInpaintPipeline(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 8. Denoising loop
-        num_warmup_steps = (
-            len(timesteps) - num_inference_steps * self.scheduler.order
-        )
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = (
-                    torch.cat([latents] * 2)
-                    if do_classifier_free_guidance
-                    else latents
-                )
-                latent_model_input = self.scheduler.scale_model_input(
-                    latent_model_input, t
-                )
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
                 down_block_res_samples, mid_block_res_sample = self.controlnet(
                     latent_model_input,
@@ -529,9 +461,7 @@ class StableDiffusionControlNetInpaintPipeline(
                 mid_block_res_sample *= controlnet_conditioning_scale
 
                 # predict the noise residual
-                latent_model_input = torch.cat(
-                    [latent_model_input, mask, masked_image_latents], dim=1
-                )
+                latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                 noise_pred = self.unet(
                     latent_model_input,
                     t,
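This concatenation is what separates the inpainting variant from the plain ControlNet pipeline: the ControlNet itself is still called on the 4-channel latents (above), while the UNet receives the latents, the downsampled mask, and the masked-image latents stacked along the channel dimension. A shape illustration with assumed sizes (batch of 2 from classifier-free guidance, 64x64 latents):

import torch

latent_model_input = torch.randn(2, 4, 64, 64)    # noisy latents (batch doubled for CFG)
mask = torch.randn(2, 1, 64, 64)                  # mask resized to the latent grid
masked_image_latents = torch.randn(2, 4, 64, 64)  # VAE-encoded masked image

unet_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64]) -- needs a 9-channel inpainting UNet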
@@ -544,30 +474,20 @@ class StableDiffusionControlNetInpaintPipeline(
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
-                        noise_pred_text - noise_pred_uncond
-                    )
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
-                latents = self.scheduler.step(
-                    noise_pred, t, latents, **extra_step_kwargs
-                ).prev_sample
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
-                    (i + 1) > num_warmup_steps
-                    and (i + 1) % self.scheduler.order == 0
-                ):
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
                         callback(i, t, latents)
 
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
-        if (
-            hasattr(self, "final_offload_hook")
-            and self.final_offload_hook is not None
-        ):
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.unet.to("cpu")
             self.controlnet.to("cpu")
             torch.cuda.empty_cache()
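The guidance update in this hunk is standard classifier-free guidance: the prediction is pushed from the unconditional branch toward the text-conditioned branch by guidance_scale. For reference, an equivalent way to write the same line (torch.lerp with a weight above 1 extrapolates):

# noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.lerp(noise_pred_uncond, noise_pred_text, guidance_scale)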
@@ -580,9 +500,7 @@ class StableDiffusionControlNetInpaintPipeline(
             image = self.decode_latents(latents)
 
             # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
             # 10. Convert to PIL
             image = self.numpy_to_pil(image)
@@ -591,20 +509,13 @@ class StableDiffusionControlNetInpaintPipeline(
             image = self.decode_latents(latents)
 
             # 9. Run safety checker
-            image, has_nsfw_concept = self.run_safety_checker(
-                image, device, prompt_embeds.dtype
-            )
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
 
         # Offload last model to CPU
-        if (
-            hasattr(self, "final_offload_hook")
-            and self.final_offload_hook is not None
-        ):
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
         if not return_dict:
             return (image, has_nsfw_concept)
 
-        return StableDiffusionPipelineOutput(
-            images=image, nsfw_content_detected=has_nsfw_concept
-        )
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)