# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# modified by Wuvin


from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers import AutoencoderKL, UNet2DConditionModel, StableDiffusionImageVariationPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers, DDPMScheduler
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipelineOutput
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection


class StableDiffusionImage2MVCustomPipeline(
    StableDiffusionImageVariationPipeline
):
    """Image-conditioned Stable Diffusion pipeline with extra condition inputs.

    Extends ``StableDiffusionImageVariationPipeline`` so that, besides the CLIP
    image embedding, the UNet also receives the VAE latents of the conditioning
    image (``condition_latents``) and its CLIP-preprocessed pixels
    (``cond_pixels_clip``). The ``unet`` must therefore be a custom model whose
    forward accepts those keyword arguments (plus ``noisy_condition_input``).
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
        requires_safety_checker: bool = True,
        latents_offset=None,
        noisy_cond_latents=False,
        condition_offset=True,
    ):
        """Register the pipeline components and the latent-offset configuration.

        Args:
            vae, image_encoder, unet, scheduler, safety_checker,
            feature_extractor, requires_safety_checker:
                Forwarded unchanged to
                ``StableDiffusionImageVariationPipeline.__init__``.
            latents_offset: Optional per-latent-channel offset (any iterable of
                numbers). When set, it is subtracted from the condition latents
                (if ``condition_offset``) and added to the denoised latents
                before decoding. Stored as a tuple and recorded in the config.
            noisy_cond_latents: Must be falsy; noisy condition latents are not
                supported.
            condition_offset: Whether ``latents_offset`` is also applied to the
                condition latents produced by :meth:`encode_latents`. Recorded
                in the config.

        Raises:
            NotImplementedError: If ``noisy_cond_latents`` is truthy.
        """
        super().__init__(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            requires_safety_checker=requires_safety_checker,
        )
        latents_offset = tuple(latents_offset) if latents_offset is not None else None
        self.latents_offset = latents_offset
        if latents_offset is not None:
            self.register_to_config(latents_offset=latents_offset)
        if noisy_cond_latents:
            raise NotImplementedError("Noisy condition latents not supported Now.")
        self.condition_offset = condition_offset
        self.register_to_config(condition_offset=condition_offset)

    def _latents_offset_tensor(self, like: torch.Tensor) -> torch.Tensor:
        """Return ``self.latents_offset`` as a ``(1, C, 1, 1)`` tensor matching ``like``.

        The tensor is created on ``like``'s device *and dtype*. A plain
        ``torch.tensor(offset)`` is float32, and combining it with half-precision
        latents would silently promote them to float32, leaving them mismatched
        with a half-precision VAE/UNet.
        """
        offset = torch.tensor(self.latents_offset, device=like.device, dtype=like.dtype)
        return offset[None, :, None, None]

    def encode_latents(self, image: Image.Image, device, dtype, height, width):
        """Encode a PIL image into scaled VAE latents used as a condition.

        The image is converted to RGB and preprocessed by ``self.image_processor``
        at ``height`` x ``width``. It is then moved to ``device``/``dtype``,
        encoded with the VAE and multiplied by the VAE scaling factor. The
        latent distribution's mode is used (no sampling), so the condition is
        deterministic.

        If ``latents_offset`` is configured and ``condition_offset`` is enabled,
        the offset is subtracted from the returned latents.
        """
        images = self.image_processor.preprocess(
            image.convert("RGB"), height=height, width=width
        ).to(device, dtype=dtype)
        # NOTE: .mode() for condition -- deterministic, no sampling noise.
        latents = self.vae.encode(images).latent_dist.mode() * self.vae.config.scaling_factor
        if self.latents_offset is not None and self.condition_offset:
            return latents - self._latents_offset_tensor(latents)
        return latents

    def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free_guidance):
        """Return CLIP image embeddings of shape ``(B * num_images_per_prompt, 1, D)``.

        Non-tensor inputs are first run through ``self.feature_extractor``. With
        classifier-free guidance, an all-zeros unconditional embedding batch is
        prepended, doubling the leading dimension.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_images_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            # NOTE: the same as original code
            negative_prompt_embeds = torch.zeros_like(image_embeddings)
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and image embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_prompt_embeds, image_embeddings])

        return image_embeddings

    @torch.no_grad()
    def __call__(
        self,
        image: Union[Image.Image, List[Image.Image], torch.FloatTensor],
        height: Optional[int] = 1024,
        width: Optional[int] = 1024,
        height_cond: Optional[int] = 512,
        width_cond: Optional[int] = 512,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`Image.Image` or `List[Image.Image]`):
                The single conditioning image, or a one-element list containing it. Larger batches raise
                `NotImplementedError`. The image is used for the CLIP embedding, the CLIP pixels passed to the
                UNet, and the VAE condition latents. Because the latter calls `.convert("RGB")`, it must be a
                PIL image.
            height (`int`, *optional*, defaults to 1024):
                The height in pixels of the generated image. If falsy, falls back to
                `self.unet.config.sample_size * self.vae_scale_factor`.
            width (`int`, *optional*, defaults to 1024):
                The width in pixels of the generated image. If falsy, falls back to
                `self.unet.config.sample_size * self.vae_scale_factor`.
            height_cond (`int`, *optional*, defaults to 512):
                The height at which the conditioning image is VAE-encoded into condition latents.
            width_cond (`int`, *optional*, defaults to 512):
                The width at which the conditioning image is VAE-encoded into condition latents.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Classifier-free guidance weight; guidance is enabled when `guidance_scale > 1`. The unconditional
                branch uses all-zero image embeddings, condition latents and condition pixels.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate for the conditioning image.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
                to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. If not provided, a latents tensor is generated by sampling using the supplied random
                `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image: `"pil"`, `"np"`, or `"latent"` (skips VAE decoding and
                the safety checker).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function called every `callback_steps` steps during inference, with the arguments
                `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function is called. If not specified, the callback is called at
                every step.

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
                otherwise a `tuple` is returned where the first element is a list with the generated images and the
                second element is a list of `bool`s indicating whether the corresponding generated image contains
                "not-safe-for-work" (nsfw) content (or `None` when no safety check ran).

        Raises:
            NotImplementedError: If more than one conditioning image is passed.

        Examples:

        ```py
        # `pipe` is an already-constructed StableDiffusionImage2MVCustomPipeline whose UNet
        # accepts the extra condition inputs (condition_latents, cond_pixels_clip, ...).
        from PIL import Image

        image = Image.open("input.png").convert("RGB")
        out = pipe(image, num_inference_steps=30, guidance_scale=3.0)
        out.images[0].save("result.png")
        ```
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width, callback_steps)

        # 2. Define call parameters -- only a single conditioning image is supported.
        if isinstance(image, Image.Image):
            batch_size = 1
        elif len(image) == 1:
            image = image[0]
            batch_size = 1
        else:
            raise NotImplementedError(
                f"{type(self).__name__} supports a single conditioning image only, got {len(image)}."
            )

        # Use the pipeline's execution device rather than a hard-coded "cuda", so
        # CPU/MPS placement, non-default CUDA devices and model offloading work.
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input image
        emb_image = image
        image_embeddings = self._encode_image(
            emb_image, device, num_images_per_prompt, do_classifier_free_guidance
        ).to(device=self.unet.device, dtype=self.unet.dtype)

        # VAE condition latents of the input image at height_cond x width_cond.
        cond_latents = self.encode_latents(
            image, image_embeddings.device, image_embeddings.dtype, height_cond, width_cond
        )
        if do_classifier_free_guidance:
            # The unconditional branch sees all-zero condition latents.
            cond_latents = torch.cat([torch.zeros_like(cond_latents), cond_latents])

        # CLIP-preprocessed pixels of the input image, forwarded to the UNet.
        image_pixels = self.feature_extractor(images=emb_image, return_tensors="pt").pixel_values
        if do_classifier_free_guidance:
            image_pixels = torch.cat([torch.zeros_like(image_pixels), image_pixels], dim=0)
        # NOTE(review): cond_latents / image_pixels are not repeated for
        # num_images_per_prompt > 1, and image_pixels is left on the feature
        # extractor's output device/dtype -- presumably the custom UNet handles
        # both; confirm.

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.out_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    condition_latents=cond_latents,
                    noisy_condition_input=False,
                    cond_pixels_clip=image_pixels,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)
                        callback(step_idx, t, latents)

        self.maybe_free_model_hooks()

        # Add the configured offset back before decoding (or returning) the latents.
        if self.latents_offset is not None:
            latents = latents + self._latents_offset_tensor(latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, image_embeddings.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        self.maybe_free_model_hooks()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


if __name__ == "__main__":
    pass