import os
import sys
import uuid
import gc
import warnings

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import gradio as gr
import tqdm
import numpy as np
from PIL import Image

sys.path.append(os.path.abspath(os.path.join("", "..")))
warnings.filterwarnings("ignore")

from editing import get_direction, debias
from sampling import sample_weights
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from huggingface_hub import snapshot_download
import spaces

models_path = snapshot_download(repo_id="Snapchat/w2w")

device = "cuda"
pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
revision = None
weight_dtype = torch.bfloat16

# Load scheduler, tokenizer and models. The full pipeline is instantiated
# only to borrow its scheduler; the individual components are loaded below.
pipe = StableDiffusionPipeline.from_pretrained(
    pretrained_model_name_or_path,
    torch_dtype=torch.float16,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)
noise_scheduler = pipe.scheduler
del pipe

tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
)
text_encoder = CLIPTextModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
)
vae = AutoencoderKL.from_pretrained(
    pretrained_model_name_or_path, subfolder="vae", revision=revision
)
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=revision
)

# Freeze all base-model weights and move them to the GPU.
unet.requires_grad_(False)
unet.to(device, dtype=weight_dtype)
vae.requires_grad_(False)
vae.to(device, dtype=weight_dtype)
text_encoder.requires_grad_(False)
text_encoder.to(device, dtype=weight_dtype)

# Precomputed w2w statistics: weight mean/std and PCA basis for the LoRA
# weight space, identity projections onto the first 1000 PCs, the labeled
# identity dataframe, and a pseudo-inverse used to derive editing directions.
mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device("cpu")).bfloat16().to(device)
std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device("cpu")).bfloat16().to(device)
v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device("cpu")).bfloat16().to(device)
proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device("cpu")).bfloat16().to(device)
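# The blocks below build linear editing directions in the w2w weight space:
# `get_direction` fits a direction for one attribute from the labeled identity
# dataframe, and repeated `debias` calls remove its correlation with other
# attributes so each slider changes only its target trait. A minimal sketch of
# the same recipe as a reusable helper (illustrative only; attribute names must
# be columns of `df`, and this helper is not used by the app code below):
def make_direction(attribute, debias_against):
    """Fit an attribute direction and debias it against correlated attributes."""
    direction = get_direction(df, attribute, pinverse, 1000, device)
    for other in debias_against:
        direction = debias(direction, other, df, pinverse, device)
    return direction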
# Precomputed, debiased editing directions for the four UI sliders.
young = get_direction(df, "Young", pinverse, 1000, device)
young = debias(young, "Male", df, pinverse, device)
young = debias(young, "Pointy_Nose", df, pinverse, device)
young = debias(young, "Wavy_Hair", df, pinverse, device)
young = debias(young, "Chubby", df, pinverse, device)
young = debias(young, "No_Beard", df, pinverse, device)
young = debias(young, "Mustache", df, pinverse, device)

pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
pointy = debias(pointy, "Young", df, pinverse, device)
pointy = debias(pointy, "Male", df, pinverse, device)
pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
pointy = debias(pointy, "Chubby", df, pinverse, device)
pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)

wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
wavy = debias(wavy, "Young", df, pinverse, device)
wavy = debias(wavy, "Male", df, pinverse, device)
wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
wavy = debias(wavy, "Chubby", df, pinverse, device)
wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)

thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
thick = debias(thick, "Male", df, pinverse, device)
thick = debias(thick, "Young", df, pinverse, device)
thick = debias(thick, "Pointy_Nose", df, pinverse, device)
thick = debias(thick, "Wavy_Hair", df, pinverse, device)
thick = debias(thick, "Mustache", df, pinverse, device)
thick = debias(thick, "No_Beard", df, pinverse, device)
thick = debias(thick, "Sideburns", df, pinverse, device)
thick = debias(thick, "Big_Nose", df, pinverse, device)
thick = debias(thick, "Big_Lips", df, pinverse, device)
thick = debias(thick, "Black_Hair", df, pinverse, device)
thick = debias(thick, "Brown_Hair", df, pinverse, device)
thick = debias(thick, "Pale_Skin", df, pinverse, device)
thick = debias(thick, "Heavy_Makeup", df, pinverse, device)


@torch.no_grad()
@spaces.GPU
def sample_then_run(net):
    device = "cuda"
    # Per-PC mean and standard deviation over the training identities.
    m = torch.mean(proj, 0)
    standev = torch.std(proj, 0)
    # Sample only the first 1000 PCs; the remaining coefficients stay zero.
    sample = torch.zeros([1, 10000]).to(device)
    for i in range(1000):
        sample[0, i] = torch.normal(m[i], standev[i], (1, 1))
    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(sample, net)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, net, image
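# Note on the sampling above: each PC coefficient is drawn from an independent
# Gaussian fit to the training identities' projections. A vectorized equivalent
# of the per-coefficient loop (a sketch for illustration, not used by the app):
#
#   sample = torch.zeros([1, 10000], device=device)
#   sample[0, :1000] = torch.normal(m[:1000], standev[:1000])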
@torch.no_grad()
@spaces.GPU()
def inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed):
    device = "cuda"
    mean.to(device)
    std.to(device)
    v.to(device)
    weights = torch.load(net).to(device)

    network = LoRAw2w(
        weights, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # `with network:` injects the identity's LoRA weights into the UNet
        # for this forward pass.
        with network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return image


@torch.no_grad()
@spaces.GPU()
def edit_inference(net, prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
    device = "cuda"
    mean.to(device)
    std.to(device)
    v.to(device)
    weights = torch.load(net).to(device)

    network = LoRAw2w(
        weights, mean, std, v[:, :10000],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    # Pad the 1000-PC edit directions to the same number of PCs as the weights.
    pcs_original = weights.shape[1]
    pcs_edits = young.shape[1]
    padding = torch.zeros((1, pcs_original - pcs_edits)).to(device)
    young_pad = torch.cat((young.to(device), padding), 1)
    pointy_pad = torch.cat((pointy.to(device), padding), 1)
    wavy_pad = torch.cat((wavy.to(device), padding), 1)
    thick_pad = torch.cat((thick.to(device), padding), 1)

    edited_weights = (
        weights
        + a1 * 1e6 * young_pad
        + a2 * 1e6 * pointy_pad
        + a3 * 1e6 * wavy_pad
        + a4 * 2e6 * thick_pad
    )

    generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).bfloat16()

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma

    for i, t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        # Inject the edited weights only once the timestep drops to the
        # injection step; earlier (noisier) steps keep the original identity
        # weights, which preserves overall image structure.
        if t <= start_noise:
            network.proj = torch.nn.Parameter(edited_weights)
            network.reset()

        with network:
            noise_pred = unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                timestep_cond=None,
            ).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
    image = Image.fromarray((image * 255).round().astype("uint8"))
    return net, image


class CustomImageDataset(Dataset):
    """Wraps a plain list of PIL images as a torch Dataset."""

    def __init__(self, images, transform=None):
        self.images = images
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform:
            image = self.transform(image)
        return image
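# Illustrative usage of CustomImageDataset (a minimal sketch; "face.png" is a
# hypothetical local file). A single reference photo becomes a one-element
# dataset that the inversion loop below iterates over:
#
#   ref = Image.open("face.png").convert("RGB")
#   ds = CustomImageDataset([ref], transform=transforms.ToTensor())
#   dl = DataLoader(ds, batch_size=1, shuffle=True)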
@spaces.GPU(duration=200)
def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
    device = "cuda"
    mean.to(device)
    std.to(device)
    v.to(device)

    weights = torch.zeros(1, pcs).bfloat16().to(device)
    network = LoRAw2w(
        weights, mean, std, v[:, :pcs],
        unet,
        rank=1,
        multiplier=1.0,
        alpha=27.0,
        train_method="xattn-strict",
    ).to(device, torch.bfloat16)

    ### load mask
    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)

    ### if no mask was actually drawn, fall back to an all-ones mask
    if torch.sum(mask) == 0:
        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()

    ### single-image dataset
    image_transforms = transforms.Compose([
        transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.RandomCrop(512),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])
    train_dataset = CustomImageDataset(image, transform=image_transforms)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)

    ### optimizer
    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)

    ### training loop
    unet.train()
    for epoch in tqdm.tqdm(range(epochs)):
        for batch in train_dataloader:
            ### prepare inputs
            batch = batch.to(device).bfloat16()
            latents = vae.encode(batch).latent_dist.sample()
            latents = latents * 0.18215
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
            )
            timesteps = timesteps.long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            text_input = tokenizer(
                "sks person",
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

            ### masked denoising loss + gradient step on the PC coefficients
            with network:
                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
                loss = torch.nn.functional.mse_loss(
                    mask * model_pred.float(), mask * noise.float(), reduction="mean"
                )
                optim.zero_grad()
                loss.backward()
                optim.step()

    ### pad the learned coefficients to 10000 PCs
    pcs_original = weights.shape[1]
    padding = torch.zeros((1, 10000 - pcs_original)).to(device)
    weights = network.proj.detach()
    weights = torch.cat((weights, padding), 1)

    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(weights, net)
    return net


@spaces.GPU(duration=200)
def run_inversion(net, dict, pcs, epochs, weight_decay, lr):
    init_image = dict["background"].convert("RGB").resize((512, 512))
    mask = dict["layers"][0].convert("RGB").resize((512, 512))
    net = invert([init_image], mask, pcs, epochs, weight_decay, lr)

    # sample an image of the inverted identity
    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, net, image


@spaces.GPU
def file_upload(file, net):
    device = "cuda"
    weights = torch.load(file.name).to(device)

    # pad to 10000 principal components to keep everything consistent
    pcs = weights.shape[1]
    padding = torch.zeros((1, 10000 - pcs)).to(device)
    weights = torch.cat((weights, padding), 1)

    net = "model_" + str(uuid.uuid4())[:4] + ".pt"
    torch.save(weights, net)

    prompt = "sks person"
    negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
    seed = 5
    cfg = 3.0
    steps = 25
    image = inference(net, prompt, negative_prompt, cfg, steps, seed)
    return net, image
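# End-to-end sketch of the helpers above (illustrative only; "model_ab12.pt"
# stands for a hypothetical file produced by sample_then_run, run_inversion,
# or file_upload):
#
#   net = "model_ab12.pt"
#   img = inference(net, "sks person", "low quality, blurry", 3.0, 25, 5)
#   # Re-generate with the identity shifted along the Young direction,
#   # injecting the edited weights for timesteps t <= 800:
#   net, edited = edit_inference(net, "sks person", "low quality, blurry",
#                                3.0, 25, 5, start_noise=800,
#                                a1=0.5, a2=0, a3=0, a4=0)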
intro = """
Project Page | Paper | Code
"""
""" with gr.Blocks(css="style.css") as demo: net = gr.State() gr.HTML(intro) gr.Markdown(""" **Getting Started:** Sample a random identity or invert to get an identity-encoding model 👩🏻🎨(or - Upload a previously downloaded model using the `Uploading a model` option in `Advanced Options`). **What You Can Do?** Generate new images & edit the encoded identity 👩🏻->👩🏻🦱. See further instructions and tips at the bottom of the page 🤗.""") with gr.Column(): with gr.Row(): with gr.Column(): gr.Markdown(""" ❶ sample a face -or- upload an image (optional - draw a mask over the head) and invert""") input_image = gr.ImageEditor(elem_id="image_upload", type='pil', label="Reference Identity", width=512, height=512) with gr.Row(): sample = gr.Button("🎲 Sample random face") invert_button = gr.Button("⬆️ Invert") with gr.Column(): gr.Markdown("""❷ Generate new images of the sampled/inverted identity & edit with the sliders""") gallery = gr.Image(label="Generated Image",height=512, width=512, interactive=False) submit = gr.Button("Generate") prompt = gr.Textbox(label="Prompt", info="Make sure to include 'sks person'" , placeholder="sks person", value="sks person") # Editing with gr.Column(): with gr.Row(): a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Row(): a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True) with gr.Accordion("Advanced Options", open=False): with gr.Tab("Sampling"): with gr.Row(): seed = gr.Number(value=5, label="Seed", precision=0, interactive=True) cfg= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True) steps = gr.Slider(label="Inference Steps", value=25, step=1, minimum=0, maximum=100, interactive=True) with gr.Row(): negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon") injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True) with gr.Tab("Inversion"): with gr.Row(): lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True) pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True) with gr.Row(): epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True) weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True) with gr.Tab("Uploading a model"): gr.Markdown("""