import os
import yaml
import argparse
import warnings
from PIL import Image
from tqdm import tqdm

import torch
import torch.nn as nn
import torchvision.transforms as T
from transformers import logging
from diffusers import DDIMScheduler, StableDiffusionPipeline

try:
    from .ditail_utils import *   # package-style import
except ImportError:
    from ditail_utils import *    # fallback when the file is run directly as a script

# suppress warnings
logging.set_verbosity_error()
warnings.filterwarnings("ignore", message=".*LoRA backend.*")


class DitailDemo(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        # accept either a dict or an argparse.Namespace
        if isinstance(self.args, dict):
            for k, v in args.items():
                setattr(self, k, v)
        else:
            for k, v in vars(args).items():
                setattr(self, k, v)

    def load_inv_model(self):
        self.scheduler = DDIMScheduler.from_pretrained(self.inv_model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.inv_steps, device=self.device)
        print(f'[INFO] Loading inversion model: {self.inv_model}')
        pipe = StableDiffusionPipeline.from_pretrained(
            self.inv_model, torch_dtype=torch.float16
        ).to(self.device)
        pipe.enable_xformers_memory_efficient_attention()
        self.text_encoder = pipe.text_encoder
        self.tokenizer = pipe.tokenizer
        self.unet = pipe.unet
        self.vae = pipe.vae
        self.tokenizer_kwargs = dict(
            truncation=True,
            return_tensors='pt',
            padding='max_length',
            max_length=self.tokenizer.model_max_length
        )

    def load_spl_model(self):
        self.scheduler = DDIMScheduler.from_pretrained(self.spl_model, subfolder='scheduler')
        self.scheduler.set_timesteps(self.spl_steps, device=self.device)
        print(f'[INFO] Loading sampling model: {self.spl_model}')
        # reload the pipeline only if the sampling model differs from the inversion
        # model or a LoRA must be fused; otherwise reuse the components that
        # load_inv_model() already put in place
        if (self.lora != 'none') or (self.inv_model != self.spl_model):
            pipe = StableDiffusionPipeline.from_pretrained(
                self.spl_model, torch_dtype=torch.float16
            ).to(self.device)
            if self.lora != 'none':
                pipe.load_lora_weights(self.lora_dir, weight_name=f'{self.lora}.safetensors')
                pipe.fuse_lora(lora_scale=self.lora_scale)
            pipe.enable_xformers_memory_efficient_attention()
            self.text_encoder = pipe.text_encoder
            self.tokenizer = pipe.tokenizer
            self.unet = pipe.unet
            self.vae = pipe.vae
            self.tokenizer_kwargs = dict(
                truncation=True,
                return_tensors='pt',
                padding='max_length',
                max_length=self.tokenizer.model_max_length
            )

    @torch.no_grad()
    def encode_image(self, image_pil):
        # resize the shorter side to 512, then map the image into VAE latent space
        image_pil = T.Resize(512)(image_pil)
        image = T.ToTensor()(image_pil).unsqueeze(0).to(self.device)
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            image = 2 * image - 1  # map [0, 1] -> [-1, 1]
            posterior = self.vae.encode(image).latent_dist
            latent = posterior.mean * 0.18215
        return latent
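    # A brief note on the inversion step below: DDIM inversion runs the
    # deterministic DDIM update in reverse. With mu_t = sqrt(alpha_bar_t) and
    # sigma_t = sqrt(1 - alpha_bar_t), the model's noise estimate eps first
    # recovers pred_x0 = (x_prev - sigma_prev * eps) / mu_prev from the less
    # noisy latent, which is then re-noised one step: x_t = mu_t * pred_x0 + sigma_t * eps.
    # Each intermediate latent is cached in self.latents so the sampling stage
    # can inject source features at the matching timestep.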
    @torch.no_grad()
    def invert_image(self, cond, latent):
        self.latents = {}
        timesteps = reversed(self.scheduler.timesteps)
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            for i, t in enumerate(tqdm(timesteps)):
                cond_batch = cond.repeat(latent.shape[0], 1, 1)

                alpha_prod_t = self.scheduler.alphas_cumprod[t]
                alpha_prod_t_prev = (
                    self.scheduler.alphas_cumprod[timesteps[i-1]]
                    if i > 0 else self.scheduler.final_alpha_cumprod
                )
                mu = alpha_prod_t ** 0.5
                mu_prev = alpha_prod_t_prev ** 0.5
                sigma = (1 - alpha_prod_t) ** 0.5
                sigma_prev = (1 - alpha_prod_t_prev) ** 0.5

                eps = self.unet(latent, t, encoder_hidden_states=cond_batch).sample
                # recover pred_x0 from the previous latent, then re-noise it to step t
                pred_x0 = (latent - sigma_prev * eps) / mu_prev
                latent = mu * pred_x0 + sigma * eps
                # cache the latent at every timestep for the sampling stage
                self.latents[t.item()] = latent
        self.noisy_latent = latent

    @torch.no_grad()
    def extract_latents(self):
        # get the embeddings for pos & neg prompts
        text_pos = self.tokenizer(self.pos_prompt, **self.tokenizer_kwargs)
        text_neg = self.tokenizer(self.neg_prompt, **self.tokenizer_kwargs)
        self.emb_pos = self.text_encoder(text_pos.input_ids.to(self.device))[0]
        self.emb_neg = self.text_encoder(text_neg.input_ids.to(self.device))[0]

        # apply condition scaling: cond = alpha * emb_pos - beta * emb_neg
        cond = self.alpha * self.emb_pos - self.beta * self.emb_neg

        # encode source image & apply DDIM inversion
        self.invert_image(cond, self.encode_image(self.img))

    @torch.no_grad()
    def latent_to_image(self, latent, save_path=None):
        with torch.autocast(device_type=self.device, dtype=torch.float32):
            latent = 1 / 0.18215 * latent
            image = self.vae.decode(latent).sample[0]
            image = (image / 2 + 0.5).clamp(0, 1)
        image = T.ToPILImage()(image)
        if save_path is not None:
            image.save(save_path)
        return image

    def init_injection(self, attn_ratio=0.5, conv_ratio=0.8):
        # inject attention features for the first attn_ratio of the sampling
        # steps and conv features for the first conv_ratio of the steps
        attn_thresh = int(attn_ratio * self.spl_steps)
        conv_thresh = int(conv_ratio * self.spl_steps)
        self.attn_inj_timesteps = self.scheduler.timesteps[:attn_thresh]
        self.conv_inj_timesteps = self.scheduler.timesteps[:conv_thresh]
        register_attn_inj(self, self.attn_inj_timesteps)
        register_conv_inj(self, self.conv_inj_timesteps)

    @torch.no_grad()
    def sampling_loop(self):
        # init text embeddings: [empty (source branch), positive, negative]
        text_ept = self.tokenizer('', **self.tokenizer_kwargs)
        self.emb_ept = self.text_encoder(text_ept.input_ids.to(self.device))[0]
        self.emb_spl = torch.cat([self.emb_ept, self.emb_pos, self.emb_neg], dim=0)

        with torch.autocast(device_type=self.device, dtype=torch.float16):
            # use noisy latent as starting point
            x = self.latents[self.scheduler.timesteps[0].item()]

            # sampling loop
            for t in tqdm(self.scheduler.timesteps):
                # concat the cached source latent with two copies of x
                # and register the current timestep for feature injection
                src_latent = self.latents[t.item()]
                latents = torch.cat([src_latent, x, x])
                register_time(self, t.item())

                # apply U-Net for denoising
                noise_pred = self.unet(latents, t, encoder_hidden_states=self.emb_spl).sample

                # classifier-free guidance (negative branch as unconditional)
                _, noise_pred_pos, noise_pred_neg = noise_pred.chunk(3)
                noise_pred = noise_pred_neg + self.omega * (noise_pred_pos - noise_pred_neg)

                # denoise step
                x = self.scheduler.step(noise_pred, t, x).prev_sample

        # save output latent
        self.output_latent = x

    def run_ditail(self):
        # init output dir & dump config
        os.makedirs(self.output_dir, exist_ok=True)
        with open(os.path.join(self.output_dir, 'config.yaml'), 'w') as f:
            if isinstance(self.args, dict):
                f.write(yaml.dump(self.args))
            else:
                f.write(yaml.dump(vars(self.args)))

        # step 1: inversion stage
        self.load_inv_model()
        self.extract_latents()

        # step 2: sampling stage
        self.load_spl_model()
        if not self.no_injection:
            self.init_injection()
        self.sampling_loop()

        return self.latent_to_image(latent=self.output_latent)


def main(args):
    seed_everything(args.seed)
    ditail = DitailDemo(args)
    # load the source image on the instance (not on args) so the config
    # dump in run_ditail() stays YAML-serializable
    ditail.img = Image.open(args.img_path).convert('RGB')
    image = ditail.run_ditail()
    image.save(os.path.join(args.output_dir, 'output.png'))
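# Programmatic usage, as a minimal sketch (values are illustrative; assumes
# ditail_utils provides seed_everything and the register_* helpers):
#
#   from types import SimpleNamespace
#   args = SimpleNamespace(
#       seed=42, device='cuda', output_dir='./output_demo',
#       inv_model='runwayml/stable-diffusion-v1-5',
#       spl_model='runwayml/stable-diffusion-v1-5',
#       inv_steps=50, spl_steps=50,
#       pos_prompt='a cat', neg_prompt='worst quality, blurry, low res, NSFW',
#       alpha=2.0, beta=1.0, omega=15,
#       mask='none', lora='none', lora_dir='./lora', lora_scale=0.7,
#       no_injection=False,
#   )
#   seed_everything(args.seed)
#   demo = DitailDemo(args)
#   demo.img = Image.open('source.png').convert('RGB')  # hypothetical path
#   out = demo.run_ditail()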
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--output_dir', type=str, default='./output_demo')
    parser.add_argument('--inv_model', type=str, default='runwayml/stable-diffusion-v1-5',
                        help='Pre-trained inversion model name or path (step 1)')
    parser.add_argument('--spl_model', type=str, default='runwayml/stable-diffusion-v1-5',
                        help='Pre-trained sampling model name or path (step 2)')
    parser.add_argument('--inv_steps', type=int, default=50,
                        help='Number of inversion steps (step 1)')
    parser.add_argument('--spl_steps', type=int, default=50,
                        help='Number of sampling steps (step 2)')
    parser.add_argument('--img_path', type=str, required=True,
                        help='Path to the source image')
    parser.add_argument('--pos_prompt', type=str, required=True,
                        help='Positive prompt for inversion')
    parser.add_argument('--neg_prompt', type=str, default='worst quality, blurry, low res, NSFW',
                        help='Negative prompt for inversion')
    parser.add_argument('--alpha', type=float, default=2.0,
                        help='Positive prompt scaling factor')
    parser.add_argument('--beta', type=float, default=1.0,
                        help='Negative prompt scaling factor')
    parser.add_argument('--omega', type=float, default=15,
                        help='Classifier-free guidance factor')
    parser.add_argument('--mask', type=str, default='none',
                        help='Optional mask for regional injection')
    parser.add_argument('--lora', type=str, default='none',
                        help='Optional LoRA for the sampling stage')
    parser.add_argument('--lora_dir', type=str, default='./lora',
                        help='Optional LoRA storing directory')
    parser.add_argument('--lora_scale', type=float, default=0.7,
                        help='Optional LoRA scaling weight')
    parser.add_argument('--no_injection', action="store_true",
                        help='Do not use PnP injection')

    args = parser.parse_args()
    main(args)
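# Example invocation (script name and paths are illustrative):
#   python ditail_demo.py \
#       --img_path ./examples/source.png \
#       --pos_prompt "a dog sitting on the grass, oil painting" \
#       --lora none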