Update app.py
app.py
CHANGED
@@ -1,3 +1,7 @@
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import Dataset, DataLoader
 import gradio as gr
 import sys
 import os
@@ -9,8 +13,9 @@ import warnings
 warnings.filterwarnings("ignore")
 from PIL import Image
 from utils import load_models, save_model_w2w, save_model_for_diffusers
-from sampling import sample_weights
 from editing import get_direction, debias
+from sampling import sample_weights
+from lora_w2w import LoRAw2w
 from huggingface_hub import snapshot_download

 global device
@@ -20,11 +25,13 @@ global vae
 global text_encoder
 global tokenizer
 global noise_scheduler
-
+global network
 device = "cuda:0"
 generator = torch.Generator(device=device)


+
+
 models_path = snapshot_download(repo_id="Snapchat/w2w")

 mean = torch.load(f"{models_path}/files/mean.pt").bfloat16().to(device)
@@ -36,7 +43,7 @@ weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
 pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt").bfloat16().to(device)

 unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
-
+

 def sample_model():
     global unet
@@ -47,6 +54,9 @@ def sample_model():
     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)


+
+
+
 @torch.no_grad()
 def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
     global device
@@ -94,7 +104,7 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):

     image = Image.fromarray((image * 255).round().astype("uint8"))

-    return
+    return image



@@ -173,16 +183,13 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
     network.proj = torch.nn.Parameter(original_weights)
     network.reset()

-    return
+    return image




 def sample_then_run():
-
-
-    sample_model()
-
+    sample_model()
     prompt = "sks person"
     negative_prompt = "low quality, blurry, unfinished, cartoon"
     seed = 5
@@ -192,6 +199,8 @@ def sample_then_run():
     return image


+
+
 #directions
 global young
 global pointy
@@ -233,6 +242,115 @@ large = debias(large, "Wavy_Hair", df, pinverse, device)
 large_max = torch.max(proj@large[0]/(torch.norm(large))**2).item()
 large_min = torch.min(proj@large[0]/(torch.norm(large))**2).item()

+class CustomImageDataset(Dataset):
+    def __init__(self, images, transform=None):
+        self.images = images
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.images)
+
+    def __getitem__(self, idx):
+        image = self.images[idx]
+        if self.transform:
+            image = self.transform(image)
+        return image
+
+def invert(image, mask, pcs=10000, epochs=400, weight_decay=1e-10, lr=1e-1):
+    global unet
+    del unet
+    global network
+    unet, _, _, _, _ = load_models(device)
+
+    proj = torch.zeros(1, pcs).bfloat16().to(device)
+    network = LoRAw2w(proj, mean, std, v[:, :pcs],
+                      unet,
+                      rank=1,
+                      multiplier=1.0,
+                      alpha=27.0,
+                      train_method="xattn-strict"
+                      ).to(device, torch.bfloat16)
+
+    ### load mask
+    mask = transforms.Resize((64, 64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
+    mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:, 0, :, :].unsqueeze(1)
+    ### check if an actual mask was draw, otherwise mask is just all ones
+    if torch.sum(mask) == 0:
+        mask = torch.ones((1, 1, 64, 64)).to(device).bfloat16()
+
+    ### single image dataset
+    image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
+                                           transforms.RandomCrop(512),
+                                           transforms.ToTensor(),
+                                           transforms.Normalize([0.5], [0.5])])
+
+    train_dataset = CustomImageDataset(image, transform=image_transforms)
+    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
+
+    ### optimizer
+    optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
+
+    ### training loop
+    unet.train()
+    for epoch in tqdm.tqdm(range(epochs)):
+        for batch in train_dataloader:
+            ### prepare inputs
+            batch = batch.to(device).bfloat16()
+            latents = vae.encode(batch).latent_dist.sample()
+            latents = latents * 0.18215
+            noise = torch.randn_like(latents)
+            bsz = latents.shape[0]
+
+            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
+            timesteps = timesteps.long()
+            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+            text_input = tokenizer("sks person", padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
+
+            ### loss + sgd step
+            with network:
+                model_pred = unet(noisy_latents, timesteps, text_embeddings).sample
+            loss = torch.nn.functional.mse_loss(mask * model_pred.float(), mask * noise.float(), reduction="mean")
+            optim.zero_grad()
+            loss.backward()
+            optim.step()
+
+    ### return optimized network
+    return network
+
+
+def run_inversion(dict, pcs, epochs, weight_decay, lr):
+    global network
+    init_image = dict["image"].convert("RGB").resize((512, 512))
+    mask = dict["mask"].convert("RGB").resize((512, 512))
+    network = invert([init_image], mask, pcs, epochs, weight_decay, lr)
+
+    #sample an image
+    prompt = "sks person"
+    negative_prompt = "low quality, blurry, unfinished, cartoon"
+    seed = 5
+    cfg = 3.0
+    steps = 50
+    image = inference(prompt, negative_prompt, cfg, steps, seed)
+    torch.save(network.proj, "model.pt")
+    return image, "model.pt"
+
+
 intro = """
 <div style="display: flex;align-items: center;justify-content: center">
@@ -249,61 +367,97 @@ intro = """
 </p>
 """

-with gr.Blocks(css="style.css") as demo:
-    gr.HTML(intro)
-    with gr.Row():
-        with gr.Column():
-            gallery1 = gr.Gallery(label="Identity from Sampled Model")
-            sample = gr.Button("Sample New Model")
-            gallery2 = gr.Gallery(label="Identity from Edited Model")
-
-    with gr.Row():
-        with gr.Column():
-            prompt = gr.Textbox(label="Prompt",
-                                info="Make sure to include 'sks person'",
-                                placeholder="sks person",
-                                value="sks person")
-            negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
-            with gr.Row():
-                a1 = gr.Slider(label="+Young", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-                a2 = gr.Slider(label="+Pointy Nose", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-            with gr.Row():
-                a3 = gr.Slider(label="+Curly Hair", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-                a4 = gr.Slider(label="+Large", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
-
-            with gr.Accordion("Advanced Options", open=False):
-                with gr.Column():
-                    seed = gr.Number(value=5, label="Seed", interactive=True)
-                    cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
-                    steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
-                    injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
-
-            submit = gr.Button("Submit")

+with gr.Blocks(css="style.css") as demo:
+    gr.HTML(intro)
+    with gr.Tab("Sampling Models + Editing"):
+        with gr.Row():
+            with gr.Column():
+                gallery1 = gr.Image(label="Identity from Sampled Model")
+                sample = gr.Button("Sample New Model")
+                gallery2 = gr.Image(label="Identity from Edited Model")
+
+        with gr.Row():
+            with gr.Column():
+                prompt = gr.Textbox(label="Prompt",
+                                    info="Make sure to include 'sks person'",
+                                    placeholder="sks person",
+                                    value="sks person")
+                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
+                with gr.Row():
+                    a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+
+                    a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                with gr.Row():
+                    a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+                    a4 = gr.Slider(label="- placeholder for some fourth attribute +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
+
+                with gr.Accordion("Advanced Options", open=False):
+                    with gr.Column():
+                        seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
+                        cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+                        steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+                        injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
+
+                submit = gr.Button("Generate")
+
+        sample.click(fn=sample_then_run, outputs=gallery1)
+
+        submit.click(fn=edit_inference,
+                     inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4],
+                     outputs=gallery2)
+
+    with gr.Tab("Inversion"):
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
+                                       height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
+
+                lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
+                weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
+                pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
+                epochs = gr.Slider(label="Epochs", value=400, step=1, minimum=1, maximum=2000, interactive=True)
+
+                invert_button = gr.Button("Invert")
+
+            with gr.Column():
+                gallery = gr.Image(label="Sample from Inverted Model", height=512, width=512)
+                prompt = gr.Textbox(label="Prompt",
+                                    info="Make sure to include 'sks person'",
+                                    placeholder="sks person",
+                                    value="sks person")
+                negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, cartoon", value="low quality, blurry, unfinished, cartoon")
+                seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
+                cfg = gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
+                steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
+                submit = gr.Button("Generate")
+
+                file_output = gr.File(label="Download Model", container=False)
+
+        invert_button.click(fn=run_inversion,
+                            inputs=[input_image, pcs, epochs, weight_decay, lr],
+                            outputs=[gallery, file_output])
+
+        submit.click(fn=inference,
+                     inputs=[prompt, negative_prompt, cfg, steps, seed],
+                     outputs=gallery)
+
+demo.queue().launch(share=True)
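Note that run_inversion saves only the trained projection (network.proj) to model.pt, which the Inversion tab exposes through the "Download Model" file output. Below is a minimal sketch, not part of this commit, of how that file could be loaded back into a LoRAw2w adapter for later inference; it assumes the same globals app.py sets up above (device, mean, std, v, unet), and load_saved_network is a hypothetical helper name.

    import torch
    from lora_w2w import LoRAw2w

    def load_saved_network(path="model.pt", pcs=10000):
        # model.pt holds the 1 x pcs projection tensor saved by run_inversion;
        # pcs must match the value used during inversion (assumption).
        proj = torch.load(path).to(device).bfloat16()
        # Rebuild the adapter the same way invert() constructs it, but with the
        # loaded projection instead of a zero-initialized one.
        network = LoRAw2w(proj, mean, std, v[:, :pcs],
                          unet,
                          rank=1,
                          multiplier=1.0,
                          alpha=27.0,
                          train_method="xattn-strict").to(device, torch.bfloat16)
        return network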