import datetime
import math
import os
import pickle
import random
import sys
import time
from difflib import SequenceMatcher

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from omegaconf import OmegaConf
from PIL import Image
from torch import autocast
from torchvision import datasets
from tqdm.auto import tqdm
from transformers import CLIPModel, CLIPTextModel, CLIPTokenizer

# StableDiffusion P2P implementation originally from
# https://github.com/bloc97/CrossAttentionControl

use_half_prec = True
if use_half_prec:
    from my_half_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_half_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_half_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
else:
    from my_diffusers import AutoencoderKL, UNet2DConditionModel
    from my_diffusers.schedulers.scheduling_utils import SchedulerOutput
    from my_diffusers import LMSDiscreteScheduler, PNDMScheduler, DDPMScheduler, DDIMScheduler
torch_dtype = torch.float16 if use_half_prec else torch.float64
np_dtype = np.float16 if use_half_prec else np.float64

# Build our CLIP model
model_path_clip = "openai/clip-vit-large-patch14"
clip_tokenizer = CLIPTokenizer.from_pretrained(model_path_clip)
clip_model = CLIPModel.from_pretrained(model_path_clip, torch_dtype=torch_dtype)
clip = clip_model.text_model

# Getting our HF Auth token: environment variable first, then a local file
auth_token = os.environ.get('hf_auth')
if auth_token is None:
    with open('hf_auth', 'r') as f:
        auth_token = f.readlines()[0].strip()
model_path_diffusion = "CompVis/stable-diffusion-v1-4"

# Build our SD model
unet = UNet2DConditionModel.from_pretrained(model_path_diffusion,
                                            subfolder="unet",
                                            use_auth_token=auth_token,
                                            revision="fp16",
                                            torch_dtype=torch_dtype)
vae = AutoencoderKL.from_pretrained(model_path_diffusion,
                                    subfolder="vae",
                                    use_auth_token=auth_token,
                                    revision="fp16",
                                    torch_dtype=torch_dtype)

# Push to devices w/ double precision
device = 'cuda'
if use_half_prec:
    unet.to(device)
    vae.to(device)
    clip.to(device)
else:
    unet.double().to(device)
    vae.double().to(device)
    clip.double().to(device)
print("Loaded all models")


def EDICT_editing(im_path,
                  base_prompt,
                  edit_prompt,
                  use_p2p=False,
                  steps=50,
                  mix_weight=0.93,
                  init_image_strength=0.8,
                  guidance_scale=3,
                  run_baseline=False,
                  width=512, height=512):
    """
    Main call of our research, performs editing with either EDICT or DDIM

    Args:
        im_path: path to the image to run on, or an already-loaded PIL image
            (anything else is passed through unchanged)
        base_prompt: conditional prompt to deterministically noise with
        edit_prompt: desired text conditioning
        use_p2p: if True, denoise with `base_prompt` as the prompt and
            `edit_prompt` as the prompt-to-prompt edit
        steps: ddim steps
        mix_weight: Weight of mixing layers.
            Higher means more consistent generations but divergence in inversion
            Lower means opposite
            This is fairly tuned and can get good results
        init_image_strength: Editing strength. Higher = more dramatic edit.
            Typically [0.6, 0.9] is good range.
            Definitely tunable per-image/maybe best results are at a different value
        guidance_scale: classifier-free guidance scale
            3 I've found is the best for both our method and basic DDIM inversion
            Higher can result in more distorted results
        run_baseline:
            VERY IMPORTANT
            True is the single-sequence DDIM baseline, False is EDICT
        width, height: working resolution; overwritten from the image size
            when a PIL image is passed

    Output:
        PAIR of Images (list)
        If run_baseline=True then [0] will be edit and [1] will be original
        If run_baseline=False then they will be two nearly identical edited versions
    """
    # Resize/center crop to 512x512 (Can do higher res. if desired)
    if isinstance(im_path, str):
        orig_im = load_im_into_format_from_path(im_path)
    elif isinstance(im_path, Image.Image):
        # isinstance replaces the deprecated Image.isImageType helper
        width, height = im_path.size

        # add max dim for sake of memory
        max_dim = max(width, height)
        if max_dim > 1024:
            factor = 1024 / max_dim
            width *= factor
            height *= factor
            width = int(width)
            height = int(height)
            im_path = im_path.resize((width, height))

        min_dim = min(width, height)
        if min_dim < 512:
            factor = 512 / min_dim
            width *= factor
            height *= factor
            width = int(width)
            height = int(height)
            im_path = im_path.resize((width, height))

        # round down to a multiple of 64, as the samplers do internally
        width = width - (width % 64)
        height = height - (height % 64)

        orig_im = im_path
    else:
        orig_im = im_path

    # compute latent pair (second one will be original latent if run_baseline=True)
    latents = coupled_stablediffusion(base_prompt,
                                      reverse=True,
                                      init_image=orig_im,
                                      init_image_strength=init_image_strength,
                                      steps=steps,
                                      mix_weight=mix_weight,
                                      guidance_scale=guidance_scale,
                                      run_baseline=run_baseline,
                                      width=width, height=height)
    # Denoise intermediate state with new conditioning
    gen = coupled_stablediffusion(edit_prompt if (not use_p2p) else base_prompt,
                                  None if (not use_p2p) else edit_prompt,
                                  fixed_starting_latent=latents,
                                  init_image_strength=init_image_strength,
                                  steps=steps,
                                  mix_weight=mix_weight,
                                  guidance_scale=guidance_scale,
                                  run_baseline=run_baseline,
                                  width=width, height=height)

    return gen


def img2img_editing(im_path,
                    edit_prompt,
                    steps=50,
                    init_image_strength=0.7,
                    guidance_scale=3):
    """
    Basic SDEdit/img2img, given an image add some noise and denoise with prompt

    The image at `im_path` is center-cropped and resized to 512x512 before
    being handed to `baseline_stablediffusion`.
    """
    orig_im = load_im_into_format_from_path(im_path)

    return baseline_stablediffusion(edit_prompt,
                                    init_image_strength=init_image_strength,
                                    steps=steps,
                                    init_image=orig_im,
                                    guidance_scale=guidance_scale)


def center_crop(im):
    """Return the largest centered square crop of PIL image `im`."""
    width, height = im.size   # Get dimensions
    min_dim = min(width, height)
    left = (width - min_dim) / 2
    top = (height - min_dim) / 2
    right = (width + min_dim) / 2
    bottom = (height + min_dim) / 2

    # Crop the center of the image
    im = im.crop((left, top, right, bottom))
    return im


def general_crop(im, target_w, target_h):
    """
    Return the centered `target_w` x `target_h` crop of PIL image `im`.

    The previous implementation used `target/2` and `size - target/2` as the
    box edges, which yields a (size - target)-sized region rather than a
    target-sized one.
    """
    width, height = im.size   # Get dimensions
    left = (width - target_w) / 2
    top = (height - target_h) / 2
    right = (width + target_w) / 2
    bottom = (height + target_h) / 2

    # Crop the center of the image
    im = im.crop((left, top, right, bottom))
    return im


def load_im_into_format_from_path(im_path):
    """Open the image at `im_path`, center-crop it to a square and resize to 512x512."""
    return center_crop(Image.open(im_path)).resize((512, 512))


#### P2P STUFF ####
def init_attention_weights(weight_tuples):
    """
    Install per-token attention weights on the UNet's "attn2" CrossAttention modules.

    Args:
        weight_tuples: iterable of (token_index, weight) pairs; indices outside
            [0, clip_tokenizer.model_max_length) are ignored and every other
            token keeps weight 1. "attn1" modules get their weights reset to None.
    """
    tokens_length = clip_tokenizer.model_max_length
    weights = torch.ones(tokens_length)

    for i, w in weight_tuples:
        if i < tokens_length and i >= 0:
            weights[i] = w

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.last_attn_slice_weights = weights.to(device)
        if module_name == "CrossAttention" and "attn1" in name:
            module.last_attn_slice_weights = None


def init_attention_edit(tokens, tokens_edit):
    """
    Build the token-alignment mask/indices between two tokenized prompts and
    store them on the UNet's "attn2" CrossAttention modules.

    Positions of `tokens_edit` that match `tokens` (opcode "equal", or a
    "replace" of equal length) get mask 1 and the index of the corresponding
    original token. "attn1" modules get mask and indices reset to None.
    """
    tokens_length = clip_tokenizer.model_max_length
    mask = torch.zeros(tokens_length)
    indices_target = torch.arange(tokens_length, dtype=torch.long)
    indices = torch.zeros(tokens_length, dtype=torch.long)

    tokens = tokens.input_ids.numpy()[0]
    tokens_edit = tokens_edit.input_ids.numpy()[0]

    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
        if b0 < tokens_length:
            if name == "equal" or (name == "replace" and a1 - a0 == b1 - b0):
                mask[b0:b1] = 1
                indices[b0:b1] = indices_target[a0:a1]

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention" and "attn2" in name:
            module.last_attn_slice_mask = mask.to(device)
            module.last_attn_slice_indices = indices.to(device)
        if module_name == "CrossAttention" and "attn1" in name:
            module.last_attn_slice_mask = None
            module.last_attn_slice_indices = None


def init_attention_func():
    """
    Replace `_attention` on every CrossAttention module of the UNet with a
    version that can save, reuse and re-weight the attention map, and reset
    the per-module control flags.
    """
    def new_attention(self, query, key, value, sequence_length, dim):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx]) * self.scale
            )
            attn_slice = attn_slice.softmax(dim=-1)

            # Each flag is one-shot: it is cleared as soon as it has been used
            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    attn_slice = attn_slice * (1 - self.last_attn_slice_mask) + new_attn_slice * self.last_attn_slice_mask
                else:
                    attn_slice = self.last_attn_slice

                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
        return hidden_states

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == "CrossAttention":
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._attention = new_attention.__get__(module, type(module))


def _set_cross_attention_flag(attn_tag, flag_name, value):
    """Set `flag_name` to `value` on every CrossAttention module whose name contains `attn_tag`."""
    for name, module in unet.named_modules():
        if type(module).__name__ == "CrossAttention" and attn_tag in name:
            setattr(module, flag_name, value)


def use_last_tokens_attention(use=True):
    """Make "attn2" modules reuse their saved attention map on the next call."""
    _set_cross_attention_flag("attn2", "use_last_attn_slice", use)


def use_last_tokens_attention_weights(use=True):
    """Make "attn2" modules apply their per-token weights on the next call."""
    _set_cross_attention_flag("attn2", "use_last_attn_weights", use)


def use_last_self_attention(use=True):
    """Make "attn1" modules reuse their saved attention map on the next call."""
    _set_cross_attention_flag("attn1", "use_last_attn_slice", use)


def save_last_tokens_attention(save=True):
    """Make "attn2" modules save their attention map on the next call."""
    _set_cross_attention_flag("attn2", "save_last_attn_slice", save)


def save_last_self_attention(save=True):
    """Make "attn1" modules save their attention map on the next call."""
    _set_cross_attention_flag("attn1", "save_last_attn_slice", save)
####################################


##### BASELINE ALGORITHM, ONLY USED NOW FOR SDEDIT ####
@torch.no_grad()
def baseline_stablediffusion(prompt="",
                             prompt_edit=None,
                             null_prompt='',
                             prompt_edit_token_weights=None,
                             prompt_edit_tokens_start=0.0,
                             prompt_edit_tokens_end=1.0,
                             prompt_edit_spatial_start=0.0,
                             prompt_edit_spatial_end=1.0,
                             clip_start=0.0,
                             clip_end=1.0,
                             guidance_scale=7,
                             steps=50,
                             seed=1,
                             width=512, height=512,
                             init_image=None, init_image_strength=0.5,
                             fixed_starting_latent=None,
                             prev_image=None,
                             grid=None,
                             clip_guidance=None,
                             clip_guidance_scale=1,
                             num_cutouts=4,
                             cut_power=1,
                             scheduler_str='lms',
                             return_latent=False,
                             one_pass=False,
                             normalize_noise_pred=False):
    """
    Classifier-free-guided Stable Diffusion sampling with optional img2img
    initialization and prompt-to-prompt attention editing.

    Args:
        prompt / prompt_edit / null_prompt: conditional, edited and
            unconditional prompts. `prompt_edit` and `prev_image` are
            mutually exclusive.
        prompt_edit_token_weights: iterable of (token_index, weight) pairs
            forwarded to `init_attention_weights`; None means no re-weighting.
        prompt_edit_tokens_start/end, prompt_edit_spatial_start/end: range of
            `t / num_train_timesteps` in which the saved cross-/self-attention
            maps are reused.
        guidance_scale: classifier-free guidance scale (int or float).
        steps: number of scheduler steps.
        seed: RNG seed; None draws a random one.
        width, height: rounded down to a multiple of 64.
        init_image, init_image_strength: PIL image to start from and the
            fraction of the schedule that is run on it.
        fixed_starting_latent: latent to start from instead of fresh noise.
        prev_image: unfinished feature; the sampling loop raises
            NotImplementedError when it is given.
        clip_guidance, clip_start, clip_end, clip_guidance_scale, num_cutouts,
            cut_power: forwarded to `new_cond_fn`.
            NOTE(review): `new_cond_fn` is not defined in this file as seen
            here - confirm it exists before enabling clip guidance.
        scheduler_str: one of 'ddim', 'lms', 'pndm', 'ddpm'.
        return_latent: return the final latent instead of a decoded image.
        one_pass: stop after a single step.
        normalize_noise_pred: rescale the guided prediction to the norm of
            the unconditional one.
        grid: unused.

    Returns:
        A PIL image, or the latent tensor when `return_latent` is True.
    """
    # None instead of a shared mutable default list
    if prompt_edit_token_weights is None:
        prompt_edit_token_weights = []

    width = width - width % 64
    height = height - height % 64

    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None:
        seed = random.randrange(2**32 - 1)
    # NOTE(review): torch.cuda.manual_seed seeds the current device's default
    # RNG; confirm whether its return value is a usable generator object.
    generator = torch.cuda.manual_seed(seed)

    #Set inference timesteps to scheduler
    scheduler_dict = {'ddim': DDIMScheduler,
                      'lms': LMSDiscreteScheduler,
                      'pndm': PNDMScheduler,
                      'ddpm': DDPMScheduler}
    scheduler_call = scheduler_dict[scheduler_str]
    if scheduler_str == 'ddim':
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule="scaled_linear",
                                  clip_sample=False, set_alpha_to_one=False)
    else:
        scheduler = scheduler_call(beta_schedule="scaled_linear",
                                   num_train_timesteps=1000)

    scheduler.set_timesteps(steps)
    if prev_image is not None:
        prev_scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                              beta_schedule="scaled_linear",
                                              num_train_timesteps=1000)
        prev_scheduler.set_timesteps(steps)

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        init_image = init_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        init_image = np.array(init_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        init_image = torch.from_numpy(init_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if init_image.shape[1] > 3:
            init_image = init_image[:, :3] * init_image[:, 3:] + (1 - init_image[:, 3:])

        #Move image to GPU
        init_image = init_image.to(device)

        #Encode image
        with autocast(device):
            init_latent = vae.encode(init_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)
    else:
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_start = 0

    #Generate random normal noise
    if fixed_starting_latent is None:
        noise = torch.randn(init_latent.shape, generator=generator, device=device, dtype=unet.dtype)
        if scheduler_str == 'ddim':
            if init_image is not None:
                # was `raise notImplementedError`, which is a NameError
                raise NotImplementedError
                # unreachable until the branch above is implemented; kept for reference
                latent = scheduler.add_noise(init_latent, noise,
                                             1000 - int(1000 * init_image_strength)).to(device)
            else:
                latent = noise
        else:
            latent = scheduler.add_noise(init_latent, noise,
                                         t_start).to(device)
    else:
        latent = fixed_starting_latent
        t_start = steps - int(steps * init_image_strength)

    if prev_image is not None:
        #Resize and prev_image for numpy b h w c -> torch b c h w
        prev_image = prev_image.resize((width, height), resample=Image.Resampling.LANCZOS)
        prev_image = np.array(prev_image).astype(np_dtype) / 255.0 * 2.0 - 1.0
        prev_image = torch.from_numpy(prev_image[np.newaxis, ...].transpose(0, 3, 1, 2))

        #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
        if prev_image.shape[1] > 3:
            prev_image = prev_image[:, :3] * prev_image[:, 3:] + (1 - prev_image[:, 3:])

        #Move image to GPU
        prev_image = prev_image.to(device)

        #Encode image
        with autocast(device):
            prev_init_latent = vae.encode(prev_image).latent_dist.sample(generator=generator) * 0.18215

        t_start = steps - int(steps * init_image_strength)

        prev_latent = prev_scheduler.add_noise(prev_init_latent, noise, t_start).to(device)
    else:
        prev_latent = None

    #Process clip
    # NOTE(review): the original indentation was lost; the autocast block is
    # assumed to span everything through the VAE decode - confirm.
    with autocast(device):
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors="pt",
                                              return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length",
                                            max_length=clip_tokenizer.model_max_length,
                                            truncation=True, return_tensors="pt",
                                            return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing
        assert not ((prompt_edit is not None) and (prev_image is not None))
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
                                                     max_length=clip_tokenizer.model_max_length,
                                                     truncation=True, return_tensors="pt",
                                                     return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state
            init_attention_edit(tokens_conditional, tokens_conditional_edit)
        elif prev_image is not None:
            init_attention_edit(tokens_conditional, tokens_conditional)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = scheduler.timesteps[t_start:]
        # Float scales are valid too; nothing below needs an integer
        assert isinstance(guidance_scale, (int, float))

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_index = t_start + i

            latent_model_input = latent
            if scheduler_str == 'lms':
                sigma = scheduler.sigmas[t_index]  # last is first and first is last
                latent_model_input = (latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
            else:
                assert scheduler_str in ['ddim', 'pndm', 'ddpm']

            #Predict the unconditional noise residual
            if len(t.shape) == 0:
                t = t[None].to(unet.device)
            noise_pred_uncond = unet(latent_model_input, t,
                                     encoder_hidden_states=embedding_unconditional,
                                     ).sample

            if prev_latent is not None:
                prev_latent_model_input = prev_latent
                prev_latent_model_input = (prev_latent_model_input / ((sigma**2 + 1) ** 0.5)).to(unet.dtype)
                prev_noise_pred_uncond = unet(prev_latent_model_input, t,
                                              encoder_hidden_states=embedding_unconditional,
                                              ).sample

            #Prepare the Cross-Attention layers
            if prompt_edit is not None or prev_latent is not None:
                save_last_tokens_attention()
                save_last_self_attention()
            else:
                #Use weights on non-edited prompt when edit is None
                use_last_tokens_attention_weights()

            #Predict the conditional noise residual and save the cross-attention layer activations
            if prev_latent is not None:
                raise NotImplementedError  # I totally lost track of what this is
                # unreachable until the branch above is implemented; kept for reference
                prev_noise_pred_cond = unet(prev_latent_model_input, t,
                                            encoder_hidden_states=embedding_conditional,
                                            ).sample
            else:
                noise_pred_cond = unet(latent_model_input, t,
                                       encoder_hidden_states=embedding_conditional,
                                       ).sample

            #Edit the Cross-Attention layer activations
            t_scale = t / scheduler.num_train_timesteps
            if prompt_edit is not None or prev_latent is not None:
                if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                    use_last_tokens_attention()
                if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                    use_last_self_attention()

                #Use weights on edited prompt
                use_last_tokens_attention_weights()

                #Predict the edited conditional noise residual using the cross-attention masks
                if prompt_edit is not None:
                    noise_pred_cond = unet(latent_model_input, t,
                                           encoder_hidden_states=embedding_conditional_edit).sample

            #Perform guidance
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
            if clip_guidance is not None and t_scale >= clip_start and t_scale <= clip_end:
                noise_pred, latent = new_cond_fn(latent, t, t_index,
                                                 embedding_conditional, noise_pred, clip_guidance,
                                                 clip_guidance_scale,
                                                 num_cutouts,
                                                 scheduler, unet, use_cutouts=True,
                                                 cut_power=cut_power)

            if normalize_noise_pred:
                noise_pred = noise_pred * noise_pred_uncond.norm() / noise_pred.norm()

            if scheduler_str == 'ddim':
                # forward_step returns the new sample tensor directly, so
                # there is no `.prev_sample` attribute to read here
                latent = forward_step(scheduler, noise_pred, t, latent)
            else:
                latent = scheduler.step(noise_pred, t_index, latent).prev_sample

            if prev_latent is not None:
                prev_noise_pred = prev_noise_pred_uncond + guidance_scale * (prev_noise_pred_cond - prev_noise_pred_uncond)
                prev_latent = prev_scheduler.step(prev_noise_pred, t_index, prev_latent).prev_sample
            if one_pass:
                break

        #scale and decode the image latents with vae
        if return_latent:
            return latent
        latent = latent / 0.18215
        image = vae.decode(latent.to(vae.dtype)).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype("uint8")
    return Image.fromarray(image)
####################################


#### HELPER FUNCTIONS FOR OUR METHOD #####
def get_alpha_and_beta(t, scheduler):
    """
    Return (alpha_cumprod, 1 - alpha_cumprod) for timestep `t`.

    want to run this for both current and previous timestep

    Integer (torch.long) timesteps index `scheduler.alphas_cumprod` directly,
    negative timesteps use `scheduler.final_alpha_cumprod`, and fractional
    timesteps are linearly interpolated between the neighbouring entries.
    """
    if t.dtype == torch.long:
        alpha = scheduler.alphas_cumprod[t]
        return alpha, 1 - alpha

    if t < 0:
        return scheduler.final_alpha_cumprod, 1 - scheduler.final_alpha_cumprod

    low = t.floor().long()
    high = t.ceil().long()
    rem = t - low

    low_alpha = scheduler.alphas_cumprod[low]
    high_alpha = scheduler.alphas_cumprod[high]
    # rem is the distance from `low`, so it weights `high` (the weights were
    # previously swapped; results only differ for non-integer timesteps)
    interpolated_alpha = low_alpha * (1 - rem) + high_alpha * rem
    interpolated_beta = 1 - interpolated_alpha
    return interpolated_alpha, interpolated_beta


# A DDIM forward step function
def forward_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """
    Take one deterministic DDIM step from `timestep` to the previous timestep.

    `self` is the scheduler. With q = sqrt(alpha_t / alpha_prev) the result is
    sample / q - sqrt(1 - alpha_t) / q * model_output
    + sqrt(1 - alpha_prev) * model_output, returned as a plain tensor.
    `eta`, `use_clipped_model_output`, `generator`, `return_dict` and
    `use_double` are accepted but unused.

    Raises:
        ValueError: if `set_timesteps` has not been called on the scheduler.
        NotImplementedError: if `timestep` exceeds the scheduler's largest timestep.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError("Need to double check what the overflow is")

    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)
    first_term = (1. / alpha_quotient) * sample
    second_term = (1. / alpha_quotient) * (beta_prod_t ** 0.5) * model_output
    third_term = ((1 - alpha_prod_t_prev)**0.5) * model_output
    return first_term - second_term + third_term


# A DDIM reverse step function, the inverse of above
def reverse_step(
    self,
    model_output,
    timestep: int,
    sample,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    return_dict: bool = True,
    use_double=False,
):
    """
    Invert `forward_step`: map a sample at the previous timestep back to `timestep`.

    `self` is the scheduler. With q = sqrt(alpha_t / alpha_prev) the result is
    q * sample + sqrt(1 - alpha_t) * model_output
    - q * sqrt(1 - alpha_prev) * model_output, returned as a plain tensor.
    The remaining keyword arguments are accepted but unused.

    Raises:
        ValueError: if `set_timesteps` has not been called on the scheduler.
        NotImplementedError: if `timestep` exceeds the scheduler's largest timestep.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    prev_timestep = timestep - self.config.num_train_timesteps / self.num_inference_steps

    if timestep > self.timesteps.max():
        raise NotImplementedError

    alpha_prod_t, beta_prod_t = get_alpha_and_beta(timestep, self)
    alpha_prod_t_prev, _ = get_alpha_and_beta(prev_timestep, self)

    alpha_quotient = ((alpha_prod_t / alpha_prod_t_prev)**0.5)

    first_term = alpha_quotient * sample
    second_term = ((beta_prod_t)**0.5) * model_output
    third_term = alpha_quotient * ((1 - alpha_prod_t_prev)**0.5) * model_output
    return first_term + second_term - third_term


@torch.no_grad()
def latent_to_image(latent):
    """Decode a scaled latent with the VAE and return it as a PIL image."""
    image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
    image = prep_image_for_return(image)
    return image


def prep_image_for_return(image):
    """Convert a decoded image batch in [-1, 1] to a uint8 PIL image of its first element."""
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    image = (image[0] * 255).round().astype("uint8")
    image = Image.fromarray(image)
    return image


#############################
##### MAIN EDICT FUNCTION #######
# Use EDICT_editing to perform calls
@torch.no_grad()
def coupled_stablediffusion(prompt="",
                            prompt_edit=None,
                            null_prompt='',
                            prompt_edit_token_weights=None,
                            prompt_edit_tokens_start=0.0,
                            prompt_edit_tokens_end=1.0,
                            prompt_edit_spatial_start=0.0,
                            prompt_edit_spatial_end=1.0,
                            guidance_scale=7.0, steps=50,
                            seed=1, width=512, height=512,
                            init_image=None, init_image_strength=1.0,
                            run_baseline=False,
                            use_lms=False,
                            leapfrog_steps=True,
                            reverse=False,
                            return_latents=False,
                            fixed_starting_latent=None,
                            beta_schedule='scaled_linear',
                            mix_weight=0.93):
    """
    Run the coupled (EDICT) DDIM process in either direction.

    Args:
        prompt / prompt_edit / null_prompt: conditional, prompt-to-prompt edit
            and unconditional prompts.
        prompt_edit_token_weights: iterable of (token_index, weight) pairs
            forwarded to `init_attention_weights`; None means no re-weighting.
        prompt_edit_tokens_start/end, prompt_edit_spatial_start/end: range of
            `t / num_train_timesteps` in which saved cross-/self-attention
            maps are reused.
        guidance_scale: classifier-free guidance scale.
        steps: number of DDIM steps; 0 short-circuits to encode/decode only.
        seed: RNG seed; None draws a random one.
        width, height: rounded down to a multiple of 64.
        init_image: PIL image, latent tensor, or list of either to invert;
            requires `reverse=True`.
        init_image_strength: fraction of the schedule that is traversed.
        run_baseline: update a single sequence (plain DDIM) with no mixing
            layers and no leapfrog ordering.
        use_lms: must be False.
        leapfrog_steps: alternate which sequence is updated first each step.
        reverse: run the deterministic noising (inversion) direction.
        return_latents: return the latent pair instead of decoded images.
        fixed_starting_latent: latent, or pair of latents, to start
            generation from.
        beta_schedule: passed to `DDIMScheduler`.
        mix_weight: weight of the mixing layer between the two sequences.

    Returns:
        The latent pair (list of two tensors) when `reverse` or
        `return_latents` is set, otherwise a list of two PIL images. With
        `steps == 0`: the encoded latent if `init_image` is given, else a
        single decoded PIL image.
    """
    # None instead of a shared mutable default list
    if prompt_edit_token_weights is None:
        prompt_edit_token_weights = []

    #If seed is None, randomly select seed from 0 to 2^32-1
    if seed is None:
        seed = random.randrange(2**32 - 1)
    # NOTE(review): torch.cuda.manual_seed seeds the current device's default
    # RNG; confirm whether its return value is a usable generator object.
    generator = torch.cuda.manual_seed(seed)

    def image_to_latent(im):
        if isinstance(im, torch.Tensor):
            # assume it's the latent
            # used to avoid clipping new generation before inversion
            init_latent = im.to(device)
        else:
            #Resize and transpose for numpy b h w c -> torch b c h w
            im = im.resize((width, height), resample=Image.Resampling.LANCZOS)
            im = np.array(im).astype(np_dtype) / 255.0 * 2.0 - 1.0
            # check if black and white
            if len(im.shape) < 3:
                im = np.stack([im for _ in range(3)], axis=2)  # putting at end b/c channels

            im = torch.from_numpy(im[np.newaxis, ...].transpose(0, 3, 1, 2))

            #If there is alpha channel, composite alpha for white, as the diffusion model does not support alpha channel
            if im.shape[1] > 3:
                im = im[:, :3] * im[:, 3:] + (1 - im[:, 3:])

            #Move image to GPU
            im = im.to(device)
            #Encode image
            if use_half_prec:
                init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
            else:
                with autocast(device):
                    init_latent = vae.encode(im).latent_dist.sample(generator=generator) * 0.18215
        return init_latent

    assert not use_lms, "Can't invert LMS the same as DDIM"
    if run_baseline:
        leapfrog_steps = False
    #Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64

    #Preprocess image if it exists (img2img)
    if init_image is not None:
        assert reverse  # want to be performing deterministic noising
        # can take either pair (output of generative process) or single image
        if isinstance(init_image, list):
            if isinstance(init_image[0], torch.Tensor):
                init_latent = [t.clone() for t in init_image]
            else:
                init_latent = [image_to_latent(im) for im in init_image]
        else:
            init_latent = image_to_latent(init_image)
        # this is t_start for forward, t_end for reverse
        t_limit = steps - int(steps * init_image_strength)
    else:
        assert not reverse, 'Need image to reverse from'
        init_latent = torch.zeros((1, unet.in_channels, height // 8, width // 8), device=device)
        t_limit = 0

    if reverse:
        latent = init_latent
    else:
        #Generate random normal noise
        noise = torch.randn(init_latent.shape,
                            generator=generator,
                            device=device,
                            dtype=torch_dtype)
        if fixed_starting_latent is None:
            latent = noise
        else:
            if isinstance(fixed_starting_latent, list):
                latent = [l.clone() for l in fixed_starting_latent]
            else:
                latent = fixed_starting_latent.clone()
            t_limit = steps - int(steps * init_image_strength)
    if isinstance(latent, list):  # initializing from pair of images
        latent_pair = latent
    else:  # initializing from noise
        latent_pair = [latent.clone(), latent.clone()]

    if steps == 0:
        if init_image is not None:
            return image_to_latent(init_image)
        else:
            image = vae.decode(latent.to(vae.dtype) / 0.18215).sample
            return prep_image_for_return(image)

    #Set inference timesteps to scheduler (one scheduler per sequence)
    schedulers = []
    for i in range(2):
        scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                  beta_schedule=beta_schedule,
                                  num_train_timesteps=1000,
                                  clip_sample=False,
                                  set_alpha_to_one=False)
        scheduler.set_timesteps(steps)
        schedulers.append(scheduler)

    # NOTE(review): the original indentation was lost; the autocast block is
    # assumed to span the sampling loop and the final decode - confirm.
    with autocast(device):
        # CLIP Text Embeddings
        tokens_unconditional = clip_tokenizer(null_prompt, padding="max_length",
                                              max_length=clip_tokenizer.model_max_length,
                                              truncation=True, return_tensors="pt",
                                              return_overflowing_tokens=True)
        embedding_unconditional = clip(tokens_unconditional.input_ids.to(device)).last_hidden_state

        tokens_conditional = clip_tokenizer(prompt, padding="max_length",
                                            max_length=clip_tokenizer.model_max_length,
                                            truncation=True, return_tensors="pt",
                                            return_overflowing_tokens=True)
        embedding_conditional = clip(tokens_conditional.input_ids.to(device)).last_hidden_state

        #Process prompt editing (if running Prompt-to-Prompt)
        if prompt_edit is not None:
            tokens_conditional_edit = clip_tokenizer(prompt_edit, padding="max_length",
                                                     max_length=clip_tokenizer.model_max_length,
                                                     truncation=True, return_tensors="pt",
                                                     return_overflowing_tokens=True)
            embedding_conditional_edit = clip(tokens_conditional_edit.input_ids.to(device)).last_hidden_state

            init_attention_edit(tokens_conditional, tokens_conditional_edit)

        init_attention_func()
        init_attention_weights(prompt_edit_token_weights)

        timesteps = schedulers[0].timesteps[t_limit:]
        if reverse:
            timesteps = timesteps.flip(0)

        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            t_scale = t / schedulers[0].num_train_timesteps

            if (reverse) and (not run_baseline):
                # Reverse mixing layer: undo the sequential mix applied during
                # generation, hence latent 1 is recovered before latent 0
                new_latents = [l.clone() for l in latent_pair]
                new_latents[1] = (new_latents[1].clone() - (1 - mix_weight) * new_latents[0].clone()) / mix_weight
                new_latents[0] = (new_latents[0].clone() - (1 - mix_weight) * new_latents[1].clone()) / mix_weight
                latent_pair = new_latents

            # alternate EDICT steps
            for latent_i in range(2):
                if run_baseline and latent_i == 1:
                    continue  # just have one sequence for baseline
                # this modifies latent_pair[i] while using
                # latent_pair[(i+1)%2]
                if reverse and (not run_baseline):
                    if leapfrog_steps:
                        # what i would be from going other way
                        orig_i = len(timesteps) - (i + 1)
                        offset = (orig_i + 1) % 2
                        latent_i = (latent_i + offset) % 2
                    else:
                        # Do 1 then 0
                        latent_i = (latent_i + 1) % 2
                else:
                    if leapfrog_steps:
                        offset = i % 2
                        latent_i = (latent_i + offset) % 2

                latent_j = ((latent_i + 1) % 2) if not run_baseline else latent_i

                latent_model_input = latent_pair[latent_j]
                latent_base = latent_pair[latent_i]

                #Predict the unconditional noise residual
                noise_pred_uncond = unet(latent_model_input, t,
                                         encoder_hidden_states=embedding_unconditional).sample

                #Prepare the Cross-Attention layers
                if prompt_edit is not None:
                    save_last_tokens_attention()
                    save_last_self_attention()
                else:
                    #Use weights on non-edited prompt when edit is None
                    use_last_tokens_attention_weights()

                #Predict the conditional noise residual and save the cross-attention layer activations
                noise_pred_cond = unet(latent_model_input, t,
                                       encoder_hidden_states=embedding_conditional).sample

                #Edit the Cross-Attention layer activations
                if prompt_edit is not None:
                    if t_scale >= prompt_edit_tokens_start and t_scale <= prompt_edit_tokens_end:
                        use_last_tokens_attention()
                    if t_scale >= prompt_edit_spatial_start and t_scale <= prompt_edit_spatial_end:
                        use_last_self_attention()

                    #Use weights on edited prompt
                    use_last_tokens_attention_weights()

                    #Predict the edited conditional noise residual using the cross-attention masks
                    noise_pred_cond = unet(latent_model_input,
                                           t,
                                           encoder_hidden_states=embedding_conditional_edit).sample

                #Perform guidance
                grad = (noise_pred_cond - noise_pred_uncond)
                noise_pred = noise_pred_uncond + guidance_scale * grad

                step_call = reverse_step if reverse else forward_step
                new_latent = step_call(schedulers[latent_i],
                                       noise_pred,
                                       t,
                                       latent_base)
                new_latent = new_latent.to(latent_base.dtype)

                latent_pair[latent_i] = new_latent

            if (not reverse) and (not run_baseline):
                # Mixing layer (contraction) during generative process; the
                # second line deliberately uses the already-mixed latent 0
                new_latents = [l.clone() for l in latent_pair]
                new_latents[0] = (mix_weight * new_latents[0] + (1 - mix_weight) * new_latents[1]).clone()
                new_latents[1] = ((1 - mix_weight) * new_latents[0] + (mix_weight) * new_latents[1]).clone()
                latent_pair = new_latents

        #scale and decode the image latents with vae, can return latents instead of images
        if reverse or return_latents:
            return latent_pair

        # decode latents to images
        images = []
        for latent_i in range(2):
            latent = latent_pair[latent_i] / 0.18215
            image = vae.decode(latent.to(vae.dtype)).sample
            images.append(image)

    # Return images
    return [prep_image_for_return(image) for image in images]