|
|
|
|
import numpy as np |
|
import cv2 |
|
import torch |
|
import random |
|
from PIL import Image, ImageDraw, ImageFont |
|
import copy |
|
from typing import Optional, Union, Tuple, List, Callable, Dict, Any |
|
from tqdm.notebook import tqdm |
|
from diffusers.utils import BaseOutput, logging |
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps |
|
from diffusers.models.unet_2d_blocks import ( |
|
CrossAttnDownBlock2D, |
|
CrossAttnUpBlock2D, |
|
DownBlock2D, |
|
UNetMidBlock2DCrossAttn, |
|
UpBlock2D, |
|
get_down_block, |
|
get_up_block, |
|
) |
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput, logger |
|
from copy import deepcopy |
|
import json |
|
|
|
import inspect |
|
import os |
|
import warnings |
|
|
|
|
|
import PIL.Image |
|
|
import torch.nn.functional as F |
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|
|
|
from diffusers.image_processor import VaeImageProcessor |
|
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin |
|
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel |
|
from diffusers.schedulers import KarrasDiffusionSchedulers |
|
from diffusers.utils.torch_utils import is_compiled_module |
|
|
|
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput |
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker |
|
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel |
|
from tqdm import tqdm |
|
from controlnet_aux import HEDdetector, OpenposeDetector |
|
import time |
|
|
|
def seed_everything(seed): |
|
torch.manual_seed(seed) |
|
torch.cuda.manual_seed(seed) |
|
random.seed(seed) |
|
np.random.seed(seed) |
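# Note: seed_everything seeds Python, NumPy and the current CUDA device. For
# multi-GPU runs one would typically also call torch.cuda.manual_seed_all(seed),
# and fully deterministic cuDNN kernels additionally need
# torch.backends.cudnn.deterministic = True (both standard PyTorch APIs, not
# used elsewhere in this file).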
|
|
|
def get_promptls(prompt_path): |
|
with open(prompt_path) as f: |
|
prompt_ls = json.load(f) |
|
prompt_ls = [prompt['caption'].replace('/','_') for prompt in prompt_ls] |
|
return prompt_ls |
|
|
|
def load_512(image_path, left=0, right=0, top=0, bottom=0): |
|
|
|
if isinstance(image_path, str):
    image = np.array(Image.open(image_path))
    if image.ndim == 3 and image.shape[2] > 3:
        # Drop the alpha channel of RGBA inputs.
        image = image[:, :, :3]
    elif image.ndim == 2:
        # Promote grayscale inputs to 3 channels so the crop/resize code below
        # and cv2.Canny always see an (h, w, 3) uint8 array.
        image = np.stack([image] * 3, axis=-1).astype('uint8')
else:
    image = image_path
|
h, w, c = image.shape |
|
left = min(left, w-1) |
|
right = min(right, w - left - 1) |
|
top = min(top, h - 1)
|
bottom = min(bottom, h - top - 1) |
|
image = image[top:h-bottom, left:w-right] |
|
h, w, c = image.shape |
|
if h < w: |
|
offset = (w - h) // 2 |
|
image = image[:, offset:offset + h] |
|
elif w < h: |
|
offset = (h - w) // 2 |
|
image = image[offset:offset + w] |
|
image = np.array(Image.fromarray(image).resize((512, 512))) |
|
return image |
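# Minimal usage sketch (the file name below is a placeholder): load_512 reads an
# image, optionally crops margins, center-crops to a square and resizes to 512x512.
#
#   img = load_512("example.jpg")           # -> np.ndarray of shape (512, 512, 3)
#   img = load_512("example.jpg", left=50)  # crop 50 columns from the left first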
|
|
|
def get_canny(image_path): |
|
image = load_512( |
|
image_path |
|
) |
|
image = np.array(image) |
|
|
|
|
|
image = cv2.Canny(image, 100, 200) |
|
image = image[:, :, None] |
|
image = np.concatenate([image, image, image], axis=2) |
|
canny_image = Image.fromarray(image) |
|
return canny_image |
|
|
|
|
|
def get_scribble(image_path, hed): |
|
image = load_512( |
|
image_path |
|
) |
|
image = hed(image, scribble=True) |
|
|
|
return image |
|
|
|
def get_cocoimages(prompt_path): |
|
data_ls = [] |
|
with open(prompt_path) as f: |
|
prompt_ls = json.load(f) |
|
img_path = 'COCO2017-val/val2017' |
|
for prompt in tqdm(prompt_ls): |
|
caption = prompt['caption'].replace('/','_') |
|
image_id = str(prompt['image_id']) |
|
image_id = image_id.zfill(12) + '.jpg'
|
image_path = os.path.join(img_path, image_id) |
|
try:
    image = get_canny(image_path)
except Exception:
    # Skip images that cannot be read or processed.
    continue
|
curr_data = {'image':image, 'prompt':caption} |
|
data_ls.append(curr_data) |
|
return data_ls |
|
|
|
def get_cocoimages2(prompt_path): |
|
"""scribble condition |
|
""" |
|
data_ls = [] |
|
with open(prompt_path) as f: |
|
prompt_ls = json.load(f) |
|
img_path = 'COCO2017-val/val2017' |
|
hed = HEDdetector.from_pretrained('ControlNet/detector_weights/annotator', filename='network-bsds500.pth') |
|
for prompt in tqdm(prompt_ls): |
|
caption = prompt['caption'].replace('/','_') |
|
image_id = str(prompt['image_id']) |
|
image_id = image_id.zfill(12) + '.jpg'
|
image_path = os.path.join(img_path, image_id) |
|
try:
    image = get_scribble(image_path, hed)
except Exception:
    # Skip images that cannot be read or processed.
    continue
|
curr_data = {'image':image, 'prompt':caption} |
|
data_ls.append(curr_data) |
|
return data_ls |
|
|
|
def warpped_feature(sample, step): |
|
""" |
|
sample: batch_size*dim*h*w, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size |
|
step: timestep span |
|
""" |
|
bs, dim, h, w = sample.shape |
|
uncond_fea, cond_fea = sample.chunk(2) |
|
uncond_fea = uncond_fea.repeat(step,1,1,1) |
|
cond_fea = cond_fea.repeat(step,1,1,1) |
|
return torch.cat([uncond_fea, cond_fea]) |
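# Shape sketch: with a classifier-free-guidance batch of 2 (1 unconditional +
# 1 conditional sample) and step=3 merged timesteps, a (2, C, H, W) feature
# becomes (6, C, H, W), ordered [u, u, u, c, c, c]. The copies are identical and
# only diverge later through the per-timestep time embeddings, so one decoder
# pass can serve all merged timesteps.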
|
|
|
def warpped_skip_feature(block_samples, step): |
|
down_block_res_samples = [] |
|
for sample in block_samples: |
|
sample_expand = warpped_feature(sample, step) |
|
down_block_res_samples.append(sample_expand) |
|
return tuple(down_block_res_samples) |
|
|
|
def warpped_text_emb(text_emb, step): |
|
""" |
|
text_emb: batch_size*77*768, uncond: 0 - batch_size//2, cond: batch_size//2 - batch_size |
|
step: timestep span |
|
""" |
|
bs, token_len, dim = text_emb.shape |
|
uncond_fea, cond_fea = text_emb.chunk(2) |
|
uncond_fea = uncond_fea.repeat(step,1,1) |
|
cond_fea = cond_fea.repeat(step,1,1) |
|
return torch.cat([uncond_fea, cond_fea]) |
|
|
|
def warpped_timestep(timesteps, bs): |
|
""" |
|
timesteps: list, such as [981, 961, 941]
|
""" |
|
semi_bs = bs//2 |
|
ts = [] |
|
for timestep in timesteps: |
|
timestep = timestep[None] |
|
texp = timestep.expand(semi_bs) |
|
ts.append(texp) |
|
timesteps = torch.cat(ts) |
|
return timesteps.repeat(2,1).reshape(-1) |
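# Example: warpped_timestep([tensor(981), tensor(961), tensor(941)], bs=2)
# (so bs//2 == 1) returns the flat tensor [981, 961, 941, 981, 961, 941]: the
# first half pairs with the unconditional copies produced by warpped_feature,
# the second half with the conditional ones.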
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): |
|
""" |
|
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and |
|
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 |
|
""" |
|
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) |
|
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) |
|
|
|
noise_pred_rescaled = noise_cfg * (std_text / std_cfg) |
|
|
|
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg |
|
return noise_cfg |
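# This mirrors the rescaling used in diffusers' Stable Diffusion pipeline; the
# cited paper suggests a guidance_rescale of roughly 0.7 when it is enabled,
# while guidance_rescale=0.0 (the default everywhere below) leaves noise_cfg
# unchanged.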
|
|
|
def register_normal_pipeline(pipe): |
|
def new_call(self): |
|
@torch.no_grad() |
|
def call( |
|
prompt: Union[str, List[str]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.0, |
|
clip_skip: Optional[int] = None, |
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
|
**kwargs, |
|
): |
|
|
|
callback = kwargs.pop("callback", None) |
|
callback_steps = kwargs.pop("callback_steps", None) |
|
|
|
|
|
|
|
height = height or self.unet.config.sample_size * self.vae_scale_factor |
|
width = width or self.unet.config.sample_size * self.vae_scale_factor |
|
|
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
height, |
|
width, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
callback_on_step_end_tensor_inputs, |
|
) |
|
|
|
self._guidance_scale = guidance_scale |
|
self._guidance_rescale = guidance_rescale |
|
self._clip_skip = clip_skip |
|
self._cross_attention_kwargs = cross_attention_kwargs |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
lora_scale = ( |
|
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None |
|
) |
|
|
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
self.do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
lora_scale=lora_scale, |
|
clip_skip=self.clip_skip, |
|
) |
|
|
|
|
|
|
|
if self.do_classifier_free_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
timestep_cond = None |
|
if self.unet.config.time_cond_proj_dim is not None: |
|
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) |
|
timestep_cond = self.get_guidance_scale_embedding( |
|
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim |
|
).to(device=device, dtype=latents.dtype) |
|
|
|
|
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
|
self._num_timesteps = len(timesteps) |
|
init_latents = latents.detach().clone() |
|
with self.progress_bar(total=num_inference_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
# Blend a small fraction of the initial noise latent back in during the second
# half of sampling (t < 500 out of 1000), as done throughout this codebase.
if t / 1000 < 0.5:
    latents = latents + 0.003 * init_latents
|
setattr(self.unet, 'order', i) |
|
|
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=self.cross_attention_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if self.do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: |
|
|
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) |
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
if callback_on_step_end is not None: |
|
callback_kwargs = {} |
|
for k in callback_on_step_end_tensor_inputs: |
|
callback_kwargs[k] = locals()[k] |
|
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) |
|
|
|
latents = callback_outputs.pop("latents", latents) |
|
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) |
|
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, t, latents) |
|
|
|
if not output_type == "latent": |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ |
|
0 |
|
] |
|
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
else: |
|
image = latents |
|
has_nsfw_concept = None |
|
|
|
if has_nsfw_concept is None: |
|
do_denormalize = [True] * image.shape[0] |
|
else: |
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (image, has_nsfw_concept) |
|
|
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
|
return call |
|
pipe.call = new_call(pipe) |
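# Usage sketch (assumptions: `pipe` is a standard diffusers StableDiffusionPipeline;
# the model id is only an example):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
#   register_normal_forward(pipe.unet)   # optional: baseline UNet forward defined below
#   register_normal_pipeline(pipe)       # baseline sampling loop exposed as pipe.call
#   image = pipe.call(prompt="a photograph of an astronaut riding a horse").images[0]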
|
|
|
|
|
def register_parallel_pipeline(pipe): |
|
def new_call(self): |
|
@torch.no_grad() |
|
def call( |
|
prompt: Union[str, List[str]] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
guidance_rescale: float = 0.0, |
|
clip_skip: Optional[int] = None, |
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, |
|
callback_on_step_end_tensor_inputs: List[str] = ["latents"], |
|
**kwargs, |
|
): |
|
|
|
callback = kwargs.pop("callback", None) |
|
callback_steps = kwargs.pop("callback_steps", None) |
|
|
|
|
|
|
|
height = height or self.unet.config.sample_size * self.vae_scale_factor |
|
width = width or self.unet.config.sample_size * self.vae_scale_factor |
|
|
|
|
|
|
|
self.check_inputs( |
|
prompt, |
|
height, |
|
width, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
callback_on_step_end_tensor_inputs, |
|
) |
|
|
|
self._guidance_scale = guidance_scale |
|
self._guidance_rescale = guidance_rescale |
|
self._clip_skip = clip_skip |
|
self._cross_attention_kwargs = cross_attention_kwargs |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
lora_scale = ( |
|
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None |
|
) |
|
|
|
prompt_embeds, negative_prompt_embeds = self.encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
self.do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
lora_scale=lora_scale, |
|
clip_skip=self.clip_skip, |
|
) |
|
|
|
|
|
|
|
if self.do_classifier_free_guidance: |
|
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
|
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
timestep_cond = None |
|
if self.unet.config.time_cond_proj_dim is not None: |
|
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) |
|
timestep_cond = self.get_guidance_scale_embedding( |
|
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim |
|
).to(device=device, dtype=latents.dtype) |
|
|
|
|
|
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order |
|
self._num_timesteps = len(timesteps) |
|
init_latents = latents.detach().clone() |
|
|
|
all_steps = len(self.scheduler.timesteps) |
|
curr_span = 1 |
|
curr_step = 0 |
|
|
|
|
|
idx = 1 |
|
keytime = [0,1,2,3,5,10,15,25,35] |
|
keytime.append(all_steps) |
|
while curr_step<all_steps: |
|
refister_time(self.unet, curr_step) |
|
|
|
merge_span = curr_span |
|
if merge_span>0: |
|
time_ls = [] |
|
for i in range(curr_step, curr_step+merge_span): |
|
if i<all_steps: |
|
time_ls.append(self.scheduler.timesteps[i]) |
|
else: |
|
break |
|
|
|
|
|
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents |
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
time_ls, |
|
encoder_hidden_states=prompt_embeds, |
|
timestep_cond=timestep_cond, |
|
cross_attention_kwargs=self.cross_attention_kwargs, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if self.do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: |
|
|
|
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) |
|
|
|
|
|
|
|
step_span = len(time_ls) |
|
bs = noise_pred.shape[0] |
|
bs_perstep = bs//step_span |
|
|
|
denoised_latent = latents |
|
for i, timestep in enumerate(time_ls): |
|
# Same correction as the non-parallel loop: blend a small fraction of the
# initial noise latent back in during the second half of sampling.
if timestep / 1000 < 0.5:
    denoised_latent = denoised_latent + 0.003 * init_latents
|
curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep] |
|
denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
latents = denoised_latent |
|
|
|
curr_step += curr_span |
|
idx += 1 |
|
|
|
if curr_step<all_steps: |
|
curr_span = keytime[idx] - keytime[idx-1] |
|
|
|
|
|
if not output_type == "latent": |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ |
|
0 |
|
] |
|
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
else: |
|
image = latents |
|
has_nsfw_concept = None |
|
|
|
if has_nsfw_concept is None: |
|
do_denormalize = [True] * image.shape[0] |
|
else: |
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
|
|
|
|
|
self.maybe_free_model_hooks() |
|
|
|
if not return_dict: |
|
return (image, has_nsfw_concept) |
|
|
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
|
return call |
|
pipe.call = new_call(pipe) |
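# Usage sketch (assumption: `pipe` is a diffusers StableDiffusionPipeline). The
# parallel loop hands the UNet a *list* of timesteps, so the UNet must first be
# patched with register_faster_forward (defined below):
#
#   register_faster_forward(pipe.unet, mod='50ls')
#   register_parallel_pipeline(pipe)
#   image = pipe.call(prompt="a photograph of an astronaut riding a horse",
#                     num_inference_steps=50).images[0]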
|
|
|
def register_faster_forward(model, mod='50ls'):
|
def faster_forward(self): |
|
def forward( |
|
sample: torch.FloatTensor, |
|
timestep: Union[torch.Tensor, float, int], |
|
encoder_hidden_states: torch.Tensor, |
|
class_labels: Optional[torch.Tensor] = None, |
|
timestep_cond: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, |
|
mid_block_additional_residual: Optional[torch.Tensor] = None, |
|
return_dict: bool = True, |
|
) -> Union[UNet2DConditionOutput, Tuple]: |
|
r""" |
|
Args: |
|
sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor |
|
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps |
|
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states |
|
return_dict (`bool`, *optional*, defaults to `True`): |
|
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. |
|
cross_attention_kwargs (`dict`, *optional*): |
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under |
|
`self.processor` in |
|
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). |
|
|
|
Returns: |
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: |
|
[`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When |
|
returning a tuple, the first element is the sample tensor. |
|
""" |
|
|
|
|
|
|
|
|
|
default_overall_up_factor = 2**self.num_upsamplers |
|
|
|
|
|
forward_upsample_size = False |
|
upsample_size = None |
|
|
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): |
|
logger.info("Forward upsample size to force interpolation output size.") |
|
forward_upsample_size = True |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 |
|
attention_mask = attention_mask.unsqueeze(1) |
|
|
|
|
|
if self.config.center_input_sample: |
|
sample = 2 * sample - 1.0 |
|
|
|
|
|
if isinstance(timestep, list): |
|
timesteps = timestep[0] |
|
step = len(timestep) |
|
else: |
|
timesteps = timestep |
|
step = 1 |
|
if not torch.is_tensor(timesteps) and (not isinstance(timesteps,list)): |
|
|
|
|
|
is_mps = sample.device.type == "mps" |
|
if isinstance(timestep, float): |
|
dtype = torch.float32 if is_mps else torch.float64 |
|
else: |
|
dtype = torch.int32 if is_mps else torch.int64 |
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) |
|
elif (not isinstance(timesteps,list)) and len(timesteps.shape) == 0: |
|
timesteps = timesteps[None].to(sample.device) |
|
|
|
if (not isinstance(timesteps,list)) and len(timesteps.shape) == 1: |
|
|
|
timesteps = timesteps.expand(sample.shape[0]) |
|
elif isinstance(timesteps, list): |
|
|
|
timesteps = warpped_timestep(timesteps, sample.shape[0]).to(sample.device) |
|
t_emb = self.time_proj(timesteps) |
|
|
|
|
|
|
|
|
|
t_emb = t_emb.to(dtype=self.dtype) |
|
|
|
emb = self.time_embedding(t_emb, timestep_cond) |
|
|
|
if self.class_embedding is not None: |
|
if class_labels is None: |
|
raise ValueError("class_labels should be provided when num_class_embeds > 0") |
|
|
|
if self.config.class_embed_type == "timestep": |
|
class_labels = self.time_proj(class_labels) |
|
|
|
|
|
|
|
class_labels = class_labels.to(dtype=sample.dtype) |
|
|
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) |
|
|
|
if self.config.class_embeddings_concat: |
|
emb = torch.cat([emb, class_emb], dim=-1) |
|
else: |
|
emb = emb + class_emb |
|
|
|
if self.config.addition_embed_type == "text": |
|
aug_emb = self.add_embedding(encoder_hidden_states) |
|
emb = emb + aug_emb |
|
|
|
if self.time_embed_act is not None: |
|
emb = self.time_embed_act(emb) |
|
|
|
if self.encoder_hid_proj is not None: |
|
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) |
|
|
|
|
|
order = self.order |
|
|
|
ipow = int(np.sqrt(9 + 8*order)) |
|
cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] |
|
if isinstance(mod, int): |
|
cond = order % mod == 0 |
|
elif mod == "pro": |
|
cond = ipow * ipow == (9 + 8 * order) |
|
elif mod == "50ls": |
|
cond = order in [0, 1, 2, 3, 5, 10, 15, 25, 35] |
|
elif mod == "50ls2": |
|
cond = order in [0, 10, 11, 12, 15, 20, 25, 30,35,45] |
|
elif mod == "50ls3": |
|
cond = order in [0, 20, 25, 30,35,45,46,47,48,49] |
|
elif mod == "50ls4": |
|
cond = order in [0, 9, 13, 14, 15, 28, 29, 32, 36,45] |
|
elif mod == "100ls": |
|
cond = order > 85 or order < 10 or order % 5 == 0 |
|
elif mod == "75ls": |
|
cond = order > 65 or order < 10 or order % 5 == 0 |
|
elif mod == "s2": |
|
cond = order < 20 or order > 40 or order % 2 == 0 |
|
|
|
if cond:
    # Key timestep: run the full encoder (conv_in, down blocks, mid block)
    # and cache its outputs for reuse at the following non-key steps.
|
|
|
sample = self.conv_in(sample) |
|
|
|
|
|
down_block_res_samples = (sample,) |
|
for downsample_block in self.down_blocks: |
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: |
|
sample, res_samples = downsample_block( |
|
hidden_states=sample, |
|
temb=emb, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
) |
|
else: |
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) |
|
|
|
down_block_res_samples += res_samples |
|
|
|
if down_block_additional_residuals is not None: |
|
new_down_block_res_samples = () |
|
|
|
for down_block_res_sample, down_block_additional_residual in zip( |
|
down_block_res_samples, down_block_additional_residuals |
|
): |
|
down_block_res_sample = down_block_res_sample + down_block_additional_residual |
|
new_down_block_res_samples += (down_block_res_sample,) |
|
|
|
down_block_res_samples = new_down_block_res_samples |
|
|
|
|
|
if self.mid_block is not None: |
|
sample = self.mid_block( |
|
sample, |
|
emb, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
) |
|
|
|
if mid_block_additional_residual is not None: |
|
sample = sample + mid_block_additional_residual |
|
|
|
|
|
|
|
setattr(self, 'skip_feature', deepcopy(down_block_res_samples)) |
|
setattr(self, 'toup_feature', sample.detach().clone()) |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(timestep, list): |
|
|
|
timesteps = warpped_timestep(timestep, sample.shape[0]).to(sample.device) |
|
t_emb = self.time_proj(timesteps) |
|
|
|
|
|
|
|
|
|
t_emb = t_emb.to(dtype=self.dtype) |
|
|
|
emb = self.time_embedding(t_emb, timestep_cond) |
|
|
|
|
|
|
|
down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) |
|
sample = warpped_feature(sample, step) |
|
|
|
|
|
encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) |
|
|
|
|
|
|
|
|
|
|
|
else: |
|
down_block_res_samples = self.skip_feature |
|
sample = self.toup_feature |
|
|
|
|
|
down_block_res_samples = warpped_skip_feature(down_block_res_samples, step) |
|
sample = warpped_feature(sample, step) |
|
encoder_hidden_states = warpped_text_emb(encoder_hidden_states, step) |
|
|
|
|
|
|
|
for i, upsample_block in enumerate(self.up_blocks): |
|
is_final_block = i == len(self.up_blocks) - 1 |
|
|
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] |
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] |
|
|
|
|
|
|
|
if not is_final_block and forward_upsample_size: |
|
upsample_size = down_block_res_samples[-1].shape[2:] |
|
|
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: |
|
sample = upsample_block( |
|
hidden_states=sample, |
|
temb=emb, |
|
res_hidden_states_tuple=res_samples, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
upsample_size=upsample_size, |
|
attention_mask=attention_mask, |
|
) |
|
else: |
|
sample = upsample_block( |
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size |
|
) |
|
|
|
|
|
if self.conv_norm_out: |
|
sample = self.conv_norm_out(sample) |
|
sample = self.conv_act(sample) |
|
sample = self.conv_out(sample) |
|
|
|
if not return_dict: |
|
return (sample,) |
|
|
|
return UNet2DConditionOutput(sample=sample) |
|
return forward |
|
if model.__class__.__name__ == 'UNet2DConditionModel': |
|
model.forward = faster_forward(model) |
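# The `mod` argument selects which denoising steps are "key" steps that rerun the
# UNet encoder; all other steps reuse the cached skip/mid features. '50ls' keys
# steps [0, 1, 2, 3, 5, 10, 15, 25, 35] of a 50-step schedule, an integer keys
# every `mod`-th step, and 'pro' keys steps where 9 + 8*order is a perfect square
# (a progressively widening spacing). The remaining string options are alternative
# hand-picked schedules for 50-, 75- and 100-step sampling.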
|
|
|
def register_normal_forward(model): |
|
def normal_forward(self): |
|
def forward( |
|
sample: torch.FloatTensor, |
|
timestep: Union[torch.Tensor, float, int], |
|
encoder_hidden_states: torch.Tensor, |
|
class_labels: Optional[torch.Tensor] = None, |
|
timestep_cond: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, |
|
mid_block_additional_residual: Optional[torch.Tensor] = None, |
|
return_dict: bool = True, |
|
) -> Union[UNet2DConditionOutput, Tuple]: |
|
|
|
|
|
|
|
|
|
default_overall_up_factor = 2**self.num_upsamplers |
|
|
|
|
|
forward_upsample_size = False |
|
upsample_size = None |
|
|
|
|
|
|
|
|
|
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): |
|
logger.info("Forward upsample size to force interpolation output size.") |
|
forward_upsample_size = True |
|
|
|
|
|
if attention_mask is not None: |
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 |
|
attention_mask = attention_mask.unsqueeze(1) |
|
|
|
|
|
if self.config.center_input_sample: |
|
sample = 2 * sample - 1.0 |
|
|
|
|
|
timesteps = timestep |
|
if not torch.is_tensor(timesteps): |
|
|
|
|
|
is_mps = sample.device.type == "mps" |
|
if isinstance(timestep, float): |
|
dtype = torch.float32 if is_mps else torch.float64 |
|
else: |
|
dtype = torch.int32 if is_mps else torch.int64 |
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) |
|
elif len(timesteps.shape) == 0: |
|
timesteps = timesteps[None].to(sample.device) |
|
|
|
|
|
timesteps = timesteps.expand(sample.shape[0]) |
|
|
|
t_emb = self.time_proj(timesteps) |
|
|
|
|
|
|
|
|
|
t_emb = t_emb.to(dtype=self.dtype) |
|
|
|
emb = self.time_embedding(t_emb, timestep_cond) |
|
|
|
if self.class_embedding is not None: |
|
if class_labels is None: |
|
raise ValueError("class_labels should be provided when num_class_embeds > 0") |
|
|
|
if self.config.class_embed_type == "timestep": |
|
class_labels = self.time_proj(class_labels) |
|
|
|
|
|
|
|
class_labels = class_labels.to(dtype=sample.dtype) |
|
|
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) |
|
|
|
if self.config.class_embeddings_concat: |
|
emb = torch.cat([emb, class_emb], dim=-1) |
|
else: |
|
emb = emb + class_emb |
|
|
|
if self.config.addition_embed_type == "text": |
|
aug_emb = self.add_embedding(encoder_hidden_states) |
|
emb = emb + aug_emb |
|
|
|
if self.time_embed_act is not None: |
|
emb = self.time_embed_act(emb) |
|
|
|
if self.encoder_hid_proj is not None: |
|
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) |
|
|
|
|
|
sample = self.conv_in(sample) |
|
|
|
|
|
down_block_res_samples = (sample,) |
|
for i, downsample_block in enumerate(self.down_blocks): |
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: |
|
sample, res_samples = downsample_block( |
|
hidden_states=sample, |
|
temb=emb, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
) |
|
else: |
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) |
|
|
|
|
|
|
|
down_block_res_samples += res_samples |
|
|
|
if down_block_additional_residuals is not None: |
|
new_down_block_res_samples = () |
|
|
|
for down_block_res_sample, down_block_additional_residual in zip( |
|
down_block_res_samples, down_block_additional_residuals |
|
): |
|
down_block_res_sample = down_block_res_sample + down_block_additional_residual |
|
new_down_block_res_samples += (down_block_res_sample,) |
|
|
|
down_block_res_samples = new_down_block_res_samples |
|
|
|
|
|
if self.mid_block is not None: |
|
sample = self.mid_block( |
|
sample, |
|
emb, |
|
encoder_hidden_states=encoder_hidden_states, |
|
attention_mask=attention_mask, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
) |
|
|
|
if mid_block_additional_residual is not None: |
|
sample = sample + mid_block_additional_residual |
|
|
|
for i, upsample_block in enumerate(self.up_blocks): |
|
is_final_block = i == len(self.up_blocks) - 1 |
|
|
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] |
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] |
|
|
|
|
|
|
|
if not is_final_block and forward_upsample_size: |
|
upsample_size = down_block_res_samples[-1].shape[2:] |
|
|
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: |
|
sample = upsample_block( |
|
hidden_states=sample, |
|
temb=emb, |
|
res_hidden_states_tuple=res_samples, |
|
encoder_hidden_states=encoder_hidden_states, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
upsample_size=upsample_size, |
|
attention_mask=attention_mask, |
|
) |
|
else: |
|
sample = upsample_block( |
|
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size |
|
) |
|
|
|
|
|
|
|
|
|
if self.conv_norm_out: |
|
sample = self.conv_norm_out(sample) |
|
sample = self.conv_act(sample) |
|
sample = self.conv_out(sample) |
|
|
|
if not return_dict: |
|
return (sample,) |
|
|
|
return UNet2DConditionOutput(sample=sample) |
|
return forward |
|
if model.__class__.__name__ == 'UNet2DConditionModel': |
|
model.forward = normal_forward(model) |
|
|
|
def refister_time(unet, t):
    """Record the current denoising step index on the UNet (read back as `self.order` inside the patched forward)."""
    setattr(unet, 'order', t)
|
|
|
|
|
|
|
def register_controlnet_pipeline2(pipe): |
|
def new_call(self): |
|
@torch.no_grad() |
|
|
|
def call( |
|
prompt: Union[str, List[str]] = None, |
|
image: Union[ |
|
torch.FloatTensor, |
|
PIL.Image.Image, |
|
np.ndarray, |
|
List[torch.FloatTensor], |
|
List[PIL.Image.Image], |
|
List[np.ndarray], |
|
] = None, |
|
height: Optional[int] = None, |
|
width: Optional[int] = None, |
|
num_inference_steps: int = 50, |
|
guidance_scale: float = 7.5, |
|
negative_prompt: Optional[Union[str, List[str]]] = None, |
|
num_images_per_prompt: Optional[int] = 1, |
|
eta: float = 0.0, |
|
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, |
|
latents: Optional[torch.FloatTensor] = None, |
|
prompt_embeds: Optional[torch.FloatTensor] = None, |
|
negative_prompt_embeds: Optional[torch.FloatTensor] = None, |
|
output_type: Optional[str] = "pil", |
|
return_dict: bool = True, |
|
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
|
callback_steps: int = 1, |
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None, |
|
controlnet_conditioning_scale: Union[float, List[float]] = 1.0, |
|
guess_mode: bool = False, |
|
): |
|
|
|
self.check_inputs( |
|
prompt, |
|
image, |
|
callback_steps, |
|
negative_prompt, |
|
prompt_embeds, |
|
negative_prompt_embeds, |
|
controlnet_conditioning_scale, |
|
) |
|
|
|
|
|
if prompt is not None and isinstance(prompt, str): |
|
batch_size = 1 |
|
elif prompt is not None and isinstance(prompt, list): |
|
batch_size = len(prompt) |
|
else: |
|
batch_size = prompt_embeds.shape[0] |
|
|
|
device = self._execution_device |
|
|
|
|
|
|
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
|
|
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet |
|
|
|
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): |
|
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) |
|
|
|
global_pool_conditions = ( |
|
controlnet.config.global_pool_conditions |
|
if isinstance(controlnet, ControlNetModel) |
|
else controlnet.nets[0].config.global_pool_conditions |
|
) |
|
guess_mode = guess_mode or global_pool_conditions |
|
|
|
|
|
text_encoder_lora_scale = ( |
|
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None |
|
) |
|
prompt_embeds = self._encode_prompt( |
|
prompt, |
|
device, |
|
num_images_per_prompt, |
|
do_classifier_free_guidance, |
|
negative_prompt, |
|
prompt_embeds=prompt_embeds, |
|
negative_prompt_embeds=negative_prompt_embeds, |
|
lora_scale=text_encoder_lora_scale, |
|
) |
|
|
|
|
|
if isinstance(controlnet, ControlNetModel): |
|
image = self.prepare_image( |
|
image=image, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_images_per_prompt, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
height, width = image.shape[-2:] |
|
elif isinstance(controlnet, MultiControlNetModel): |
|
images = [] |
|
|
|
for image_ in image: |
|
image_ = self.prepare_image( |
|
image=image_, |
|
width=width, |
|
height=height, |
|
batch_size=batch_size * num_images_per_prompt, |
|
num_images_per_prompt=num_images_per_prompt, |
|
device=device, |
|
dtype=controlnet.dtype, |
|
do_classifier_free_guidance=do_classifier_free_guidance, |
|
guess_mode=guess_mode, |
|
) |
|
|
|
images.append(image_) |
|
|
|
image = images |
|
height, width = image[0].shape[-2:] |
|
else: |
|
assert False |
|
|
|
|
|
self.scheduler.set_timesteps(num_inference_steps, device=device) |
|
timesteps = self.scheduler.timesteps |
|
|
|
|
|
num_channels_latents = self.unet.config.in_channels |
|
latents = self.prepare_latents( |
|
batch_size * num_images_per_prompt, |
|
num_channels_latents, |
|
height, |
|
width, |
|
prompt_embeds.dtype, |
|
device, |
|
generator, |
|
latents, |
|
) |
|
self.init_latent = latents.detach().clone() |
|
|
|
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) |
|
|
|
|
|
|
|
all_steps = len(self.scheduler.timesteps) |
|
curr_span = 1 |
|
curr_step = 0 |
|
|
|
|
|
idx = 1 |
|
keytime = [0,1,2,3,5,10,15,25,35,50] |
|
|
|
while curr_step<all_steps: |
|
|
|
|
|
refister_time(self.unet, curr_step) |
|
|
|
merge_span = curr_span |
|
if merge_span>0: |
|
time_ls = [] |
|
for i in range(curr_step, curr_step+merge_span): |
|
if i<all_steps: |
|
time_ls.append(self.scheduler.timesteps[i]) |
|
else: |
|
break |
|
|
|
|
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, time_ls[0]) |
|
|
|
if curr_step in [0,1,2,3,5,10,15,25,35]: |
|
|
|
control_model_input = latent_model_input |
|
controlnet_prompt_embeds = prompt_embeds |
|
|
|
down_block_res_samples, mid_block_res_sample = self.controlnet( |
|
control_model_input, |
|
time_ls[0], |
|
encoder_hidden_states=controlnet_prompt_embeds, |
|
controlnet_cond=image, |
|
conditioning_scale=controlnet_conditioning_scale, |
|
guess_mode=guess_mode, |
|
return_dict=False, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
else: |
|
down_block_res_samples = None |
|
mid_block_res_sample = None |
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
time_ls, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
down_block_additional_residuals=down_block_res_samples, |
|
mid_block_additional_residual=mid_block_res_sample, |
|
return_dict=False, |
|
)[0] |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
|
|
if isinstance(time_ls, list): |
|
step_span = len(time_ls) |
|
bs = noise_pred.shape[0] |
|
bs_perstep = bs//step_span |
|
|
|
denoised_latent = latents |
|
for i, timestep in enumerate(time_ls): |
|
curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep] |
|
denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent, **extra_step_kwargs, return_dict=False)[0] |
|
|
|
latents = denoised_latent |
|
|
|
curr_step += curr_span |
|
idx += 1 |
|
if curr_step<all_steps: |
|
curr_span = keytime[idx] - keytime[idx-1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
|
self.unet.to("cpu") |
|
self.controlnet.to("cpu") |
|
torch.cuda.empty_cache() |
|
|
|
if not output_type == "latent": |
|
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] |
|
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) |
|
else: |
|
image = latents |
|
has_nsfw_concept = None |
|
|
|
if has_nsfw_concept is None: |
|
do_denormalize = [True] * image.shape[0] |
|
else: |
|
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] |
|
|
|
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) |
|
|
|
|
|
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: |
|
self.final_offload_hook.offload() |
|
|
|
if not return_dict: |
|
return (image, has_nsfw_concept) |
|
|
|
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) |
|
return call |
|
pipe.call = new_call(pipe) |
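# Usage sketch (assumption: `pipe` is a diffusers StableDiffusionControlNetPipeline):
#
#   register_faster_forward(pipe.unet, mod='50ls')
#   register_controlnet_pipeline2(pipe)
#   cond = get_canny("example.jpg")                      # Canny edge map as a PIL image
#   image = pipe.call(prompt="a red sports car", image=cond).images[0]
#
# ControlNet itself is only evaluated at the key steps [0, 1, 2, 3, 5, 10, 15, 25, 35];
# non-key steps pass no ControlNet residuals to the UNet.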
|
|
|
@torch.no_grad() |
|
def multistep_pre(self, noise_pred, t, x): |
|
step_span = len(t) |
|
bs = noise_pred.shape[0] |
|
bs_perstep = bs//step_span |
|
|
|
denoised_latent = x |
|
for i, timestep in enumerate(t): |
|
curr_noise = noise_pred[i*bs_perstep:(i+1)*bs_perstep] |
|
denoised_latent = self.scheduler.step(curr_noise, timestep, denoised_latent)['prev_sample'] |
|
return denoised_latent |
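# multistep_pre applies one scheduler.step per merged timestep, feeding it the
# matching slice of the stacked noise prediction (the same chunking used inline
# in the parallel pipelines above).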
|
|
|
def register_t2v(model): |
|
def new_back(self): |
|
def backward_loop( |
|
latents, |
|
timesteps, |
|
prompt_embeds, |
|
guidance_scale, |
|
callback, |
|
callback_steps, |
|
num_warmup_steps, |
|
extra_step_kwargs, |
|
cross_attention_kwargs=None,): |
|
do_classifier_free_guidance = guidance_scale > 1.0 |
|
num_steps = (len(timesteps) - num_warmup_steps) // self.scheduler.order |
|
|
if num_steps<10: |
|
with self.progress_bar(total=num_steps) as progress_bar: |
|
for i, t in enumerate(timesteps): |
|
setattr(self.unet, 'order', i) |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
|
|
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
t, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample |
|
|
|
|
|
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): |
|
progress_bar.update() |
|
if callback is not None and i % callback_steps == 0: |
|
step_idx = i // getattr(self.scheduler, "order", 1) |
|
callback(step_idx, t, latents) |
|
|
|
else: |
|
all_timesteps = len(timesteps) |
|
curr_step = 0 |
|
|
|
while curr_step<all_timesteps: |
|
refister_time(self.unet, curr_step) |
|
|
|
time_ls = [] |
|
time_ls.append(timesteps[curr_step]) |
|
curr_step += 1 |
|
cond = curr_step in [0,1,2,3,5,10,15,25,35] |
|
|
|
while (not cond) and (curr_step<all_timesteps): |
|
time_ls.append(timesteps[curr_step]) |
|
curr_step += 1 |
|
cond = curr_step in [0,1,2,3,5,10,15,25,35] |
|
|
|
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents |
|
|
|
noise_pred = self.unet( |
|
latent_model_input, |
|
time_ls, |
|
encoder_hidden_states=prompt_embeds, |
|
cross_attention_kwargs=cross_attention_kwargs, |
|
).sample |
|
|
|
|
|
if do_classifier_free_guidance: |
|
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) |
|
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
|
|
|
|
|
latents = multistep_pre(self, noise_pred, time_ls, latents) |
|
|
|
return latents.clone().detach() |
|
return backward_loop |
|
model.backward_loop = new_back(model) |
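# Usage sketch (assumption: `model` is a diffusers pipeline that exposes a
# backward_loop with this signature, e.g. TextToVideoZeroPipeline):
#
#   register_faster_forward(model.unet, mod='50ls')
#   register_t2v(model)
#   result = model(prompt="a panda dancing in the snow")  # runs the patched backward_loop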
|
|