Spaces:
Runtime error
Runtime error
aayushmnit
commited on
Commit
•
cb16212
1
Parent(s):
5233387
Uploading diffedit app
Browse files- .gitattributes +4 -0
- Gradio Demo.ipynb +0 -0
- app.py +179 -0
- fruitbowl.jpg +0 -0
- horse.jpg +3 -0
- packages.txt +1 -0
- requirements.txt +10 -0
.gitattributes
CHANGED
@@ -32,3 +32,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
32 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
34 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
35 |
+
horse.jpg filter=lfs diff=lfs merge=lfs -text
|
36 |
+
fruitbowl.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
Gradio Demo.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
app.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms as tfms
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import CLIPTextModel, CLIPTokenizer
|
7 |
+
from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler
|
8 |
+
from diffusers import StableDiffusionInpaintPipeline
|
9 |
+
import gradio as gr
|
10 |
+
|
11 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
12 |
+
def load_artifacts():
|
13 |
+
'''
|
14 |
+
A function to load all diffusion artifacts
|
15 |
+
'''
|
16 |
+
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", torch_dtype=torch.float16).to(device)
|
17 |
+
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to(device)
|
18 |
+
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
|
19 |
+
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to(device)
|
20 |
+
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
|
21 |
+
return vae, unet, tokenizer, text_encoder, scheduler
|
22 |
+
|
23 |
+
def load_image(p):
|
24 |
+
'''
|
25 |
+
Function to load images from a defined path
|
26 |
+
'''
|
27 |
+
return Image.open(p).convert('RGB').resize((512,512))
|
28 |
+
|
29 |
+
def pil_to_latents(image):
|
30 |
+
'''
|
31 |
+
Function to convert image to latents
|
32 |
+
'''
|
33 |
+
init_image = tfms.ToTensor()(image).unsqueeze(0) * 2.0 - 1.0
|
34 |
+
init_image = init_image.to(device=device, dtype=torch.float16)
|
35 |
+
init_latent_dist = vae.encode(init_image).latent_dist.sample() * 0.18215
|
36 |
+
return init_latent_dist
|
37 |
+
|
38 |
+
def latents_to_pil(latents):
|
39 |
+
'''
|
40 |
+
Function to convert latents to images
|
41 |
+
'''
|
42 |
+
latents = (1 / 0.18215) * latents
|
43 |
+
with torch.no_grad():
|
44 |
+
image = vae.decode(latents).sample
|
45 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
46 |
+
image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
|
47 |
+
images = (image * 255).round().astype("uint8")
|
48 |
+
pil_images = [Image.fromarray(image) for image in images]
|
49 |
+
return pil_images
|
50 |
+
|
51 |
+
def text_enc(prompts, maxlen=None):
|
52 |
+
'''
|
53 |
+
A function to take a texual promt and convert it into embeddings
|
54 |
+
'''
|
55 |
+
if maxlen is None: maxlen = tokenizer.model_max_length
|
56 |
+
inp = tokenizer(prompts, padding="max_length", max_length=maxlen, truncation=True, return_tensors="pt")
|
57 |
+
return text_encoder(inp.input_ids.to(device))[0].half()
|
58 |
+
|
59 |
+
def prompt_2_img_i2i_fast(prompts, init_img, g=7.5, seed=100, strength =0.5, steps=50, dim=512):
|
60 |
+
"""
|
61 |
+
Diffusion process to convert prompt to image
|
62 |
+
"""
|
63 |
+
# Converting textual prompts to embedding
|
64 |
+
text = text_enc(prompts)
|
65 |
+
|
66 |
+
# Adding an unconditional prompt , helps in the generation process
|
67 |
+
uncond = text_enc([""], text.shape[1])
|
68 |
+
emb = torch.cat([uncond, text])
|
69 |
+
|
70 |
+
# Setting the seed
|
71 |
+
if seed: torch.manual_seed(seed)
|
72 |
+
|
73 |
+
# Setting number of steps in scheduler
|
74 |
+
scheduler.set_timesteps(steps)
|
75 |
+
|
76 |
+
# Convert the seed image to latent
|
77 |
+
init_latents = pil_to_latents(init_img)
|
78 |
+
|
79 |
+
# Figuring initial time step based on strength
|
80 |
+
init_timestep = int(steps * strength)
|
81 |
+
timesteps = scheduler.timesteps[-init_timestep]
|
82 |
+
timesteps = torch.tensor([timesteps], device=device)
|
83 |
+
|
84 |
+
# Adding noise to the latents
|
85 |
+
noise = torch.randn(init_latents.shape, generator=None, device=device, dtype=init_latents.dtype)
|
86 |
+
init_latents = scheduler.add_noise(init_latents, noise, timesteps)
|
87 |
+
latents = init_latents
|
88 |
+
|
89 |
+
# We need to scale the i/p latents to match the variance
|
90 |
+
inp = scheduler.scale_model_input(torch.cat([latents] * 2), timesteps)
|
91 |
+
# Predicting noise residual using U-Net
|
92 |
+
with torch.no_grad(): u,t = unet(inp, timesteps, encoder_hidden_states=emb).sample.chunk(2)
|
93 |
+
|
94 |
+
# Performing Guidance
|
95 |
+
pred = u + g*(t-u)
|
96 |
+
|
97 |
+
# Zero shot prediction
|
98 |
+
latents = scheduler.step(pred, timesteps, latents).pred_original_sample
|
99 |
+
|
100 |
+
# Returning the latent representation to output an array of 4x64x64
|
101 |
+
return latents.detach().cpu()
|
102 |
+
|
103 |
+
def create_mask_fast(init_img, rp, qp, n=20, s=0.5):
|
104 |
+
## Initialize a dictionary to save n iterations
|
105 |
+
diff = {}
|
106 |
+
|
107 |
+
## Repeating the difference process n times
|
108 |
+
for idx in range(n):
|
109 |
+
## Creating denoised sample using reference / original text
|
110 |
+
orig_noise = prompt_2_img_i2i_fast(prompts=rp, init_img=init_img, strength=s, seed = 100*idx)[0]
|
111 |
+
## Creating denoised sample using query / target text
|
112 |
+
query_noise = prompt_2_img_i2i_fast(prompts=qp, init_img=init_img, strength=s, seed = 100*idx)[0]
|
113 |
+
## Taking the difference
|
114 |
+
diff[idx] = (np.array(orig_noise)-np.array(query_noise))
|
115 |
+
|
116 |
+
## Creating a mask placeholder
|
117 |
+
mask = np.zeros_like(diff[0])
|
118 |
+
|
119 |
+
## Taking an average of 10 iterations
|
120 |
+
for idx in range(n):
|
121 |
+
## Note np.abs is a key step
|
122 |
+
mask += np.abs(diff[idx])
|
123 |
+
|
124 |
+
## Averaging multiple channels
|
125 |
+
mask = mask.mean(0)
|
126 |
+
|
127 |
+
## Normalizing
|
128 |
+
mask = (mask - mask.mean()) / np.std(mask)
|
129 |
+
|
130 |
+
## Binarizing and returning the mask object
|
131 |
+
return (mask > 0).astype("uint8")
|
132 |
+
|
133 |
+
def improve_mask(mask):
|
134 |
+
mask = cv2.GaussianBlur(mask*255,(3,3),1) > 0
|
135 |
+
return mask.astype('uint8')
|
136 |
+
|
137 |
+
vae, unet, tokenizer, text_encoder, scheduler = load_artifacts()
|
138 |
+
pipe = StableDiffusionInpaintPipeline.from_pretrained(
|
139 |
+
"runwayml/stable-diffusion-inpainting",
|
140 |
+
revision="fp16",
|
141 |
+
torch_dtype=torch.float16,
|
142 |
+
).to(device)
|
143 |
+
|
144 |
+
def fastDiffEdit(init_img, reference_prompt , query_prompt, g=7.5, seed=100, strength =0.7, steps=20, dim=512):
|
145 |
+
|
146 |
+
## Step 1: Create mask
|
147 |
+
mask = create_mask_fast(init_img=init_img, rp=reference_prompt, qp=query_prompt, n=20)
|
148 |
+
|
149 |
+
## Improve masking using CV trick
|
150 |
+
mask = improve_mask(mask)
|
151 |
+
|
152 |
+
## Step 2 and 3: Diffusion process using mask
|
153 |
+
output = pipe(
|
154 |
+
prompt=query_prompt,
|
155 |
+
image=init_img,
|
156 |
+
mask_image=Image.fromarray(mask*255).resize((512,512)),
|
157 |
+
generator=torch.Generator(device).manual_seed(100),
|
158 |
+
num_inference_steps = steps
|
159 |
+
).images
|
160 |
+
return output[0]
|
161 |
+
|
162 |
+
|
163 |
+
|
164 |
+
demo = gr.Interface(
|
165 |
+
fn=fastDiffEdit,
|
166 |
+
inputs=[
|
167 |
+
gr.inputs.Image(shape=(512, 512), type="pil", label = "Upload your image photo"),
|
168 |
+
gr.Textbox(label="Describe your image. Ex: a horse image"),
|
169 |
+
gr.Textbox(label="Retype the description with target output. Ex: a zebra image")],
|
170 |
+
outputs="image",
|
171 |
+
title = "DiffEdit demo",
|
172 |
+
description = "DiffEdit paper demo. Upload an image, pass reference prompt describing the image, pass query prompt to replace the object with target object",
|
173 |
+
examples = [
|
174 |
+
["fruitbowl.jpg", "a bowl of fruit", "a bowl of grapes"],
|
175 |
+
["horse.jpg", "a horse image", "a zebra image"]],
|
176 |
+
enable_queue=True
|
177 |
+
)
|
178 |
+
|
179 |
+
demo.launch()
|
fruitbowl.jpg
ADDED
horse.jpg
ADDED
Git LFS Details
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-opencv
|
requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
|
5 |
+
Pillow
|
6 |
+
opencv-python
|
7 |
+
ftfy
|
8 |
+
transformers==4.23.1
|
9 |
+
diffusers==0.6.0
|
10 |
+
|