Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,27 +1,26 @@
|
|
1 |
import os
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
import torch
|
6 |
import torchvision
|
7 |
import torchvision.transforms as transforms
|
8 |
from torch.utils.data import Dataset, DataLoader
|
9 |
import gradio as gr
|
10 |
import sys
|
11 |
-
import os
|
12 |
import tqdm
|
13 |
sys.path.append(os.path.abspath(os.path.join("", "..")))
|
14 |
-
import torch
|
15 |
import gc
|
16 |
import warnings
|
17 |
warnings.filterwarnings("ignore")
|
18 |
from PIL import Image
|
19 |
-
|
|
|
20 |
from editing import get_direction, debias
|
21 |
from sampling import sample_weights
|
22 |
from lora_w2w import LoRAw2w
|
23 |
from huggingface_hub import snapshot_download
|
24 |
-
import
|
|
|
25 |
|
26 |
|
27 |
global device
|
@@ -32,11 +31,9 @@ global text_encoder
|
|
32 |
global tokenizer
|
33 |
global noise_scheduler
|
34 |
global network
|
35 |
-
global original_image
|
36 |
device = "cuda:0"
|
37 |
generator = torch.Generator(device=device)
|
38 |
-
|
39 |
-
import spaces
|
40 |
|
41 |
|
42 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
@@ -125,12 +122,9 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
|
|
125 |
global pointy
|
126 |
global wavy
|
127 |
global large
|
128 |
-
global original_image
|
129 |
-
|
130 |
|
131 |
original_weights = network.proj.clone()
|
132 |
|
133 |
-
|
134 |
#pad to same number of PCs
|
135 |
pcs_original = original_weights.shape[1]
|
136 |
pcs_edits = young.shape[1]
|
@@ -141,7 +135,7 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
|
|
141 |
large_pad = torch.cat((large, padding), 1)
|
142 |
|
143 |
|
144 |
-
edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*
|
145 |
|
146 |
generator = generator.manual_seed(seed)
|
147 |
latents = torch.randn(
|
@@ -197,22 +191,19 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
|
|
197 |
#reset weights back to original
|
198 |
network.proj = torch.nn.Parameter(original_weights)
|
199 |
network.reset()
|
200 |
-
|
201 |
-
return
|
202 |
|
203 |
def sample_then_run():
|
204 |
-
global original_image
|
205 |
sample_model()
|
206 |
prompt = "sks person"
|
207 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
208 |
seed = 5
|
209 |
cfg = 3.0
|
210 |
steps = 50
|
211 |
-
|
212 |
torch.save(network.proj, "model.pt" )
|
213 |
-
|
214 |
-
|
215 |
-
return (original_image, original_image), "model.pt"
|
216 |
|
217 |
|
218 |
global young
|
@@ -275,14 +266,10 @@ class CustomImageDataset(Dataset):
|
|
275 |
image = self.transform(image)
|
276 |
return image
|
277 |
|
278 |
-
def invert(
|
279 |
global unet
|
280 |
del unet
|
281 |
global network
|
282 |
-
|
283 |
-
image = dict["background"].convert("RGB").resize((512, 512))
|
284 |
-
mask = dict["layers"][0].convert("RGB").resize((512, 512))
|
285 |
-
|
286 |
unet, _, _, _, _ = load_models(device)
|
287 |
|
288 |
proj = torch.zeros(1,pcs).bfloat16().to(device)
|
@@ -294,18 +281,13 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
|
294 |
train_method="xattn-strict"
|
295 |
).to(device, torch.bfloat16)
|
296 |
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
301 |
### load mask
|
302 |
mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
|
303 |
mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
|
304 |
### check if an actual mask was draw, otherwise mask is just all ones
|
305 |
if torch.sum(mask) == 0:
|
306 |
mask = torch.ones((1,1,64,64)).to(device).bfloat16()
|
307 |
-
|
308 |
-
|
309 |
### single image dataset
|
310 |
image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
|
311 |
transforms.RandomCrop(512),
|
@@ -313,11 +295,9 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
|
313 |
transforms.Normalize([0.5], [0.5])])
|
314 |
|
315 |
|
316 |
-
train_dataset = CustomImageDataset(
|
317 |
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
|
318 |
|
319 |
-
|
320 |
-
|
321 |
### optimizer
|
322 |
optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
|
323 |
|
@@ -347,40 +327,34 @@ def invert(dict, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
|
347 |
optim.step()
|
348 |
|
349 |
### return optimized network
|
350 |
-
|
351 |
return network
|
352 |
|
353 |
|
354 |
|
355 |
-
|
356 |
def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
357 |
global network
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
network = invert( dict, pcs, epochs, weight_decay,lr)
|
362 |
|
363 |
|
364 |
#sample an image
|
365 |
prompt = "sks person"
|
366 |
-
negative_prompt = "low quality, blurry, unfinished, nudity
|
367 |
seed = 5
|
368 |
cfg = 3.0
|
369 |
steps = 50
|
370 |
-
|
371 |
torch.save(network.proj, "model.pt" )
|
372 |
-
return
|
|
|
373 |
|
374 |
-
|
375 |
-
|
376 |
|
377 |
def file_upload(file):
|
378 |
global unet
|
379 |
del unet
|
380 |
global network
|
381 |
global device
|
382 |
-
global original_image
|
383 |
-
|
384 |
|
385 |
|
386 |
|
@@ -393,39 +367,38 @@ def file_upload(file):
|
|
393 |
|
394 |
unet, _, _, _, _ = load_models(device)
|
395 |
|
396 |
-
|
397 |
-
network = LoRAw2w( proj, mean, std, v[:, :
|
398 |
unet,
|
399 |
rank=1,
|
400 |
multiplier=1.0,
|
401 |
alpha=27.0,
|
402 |
train_method="xattn-strict"
|
403 |
).to(device, torch.bfloat16)
|
404 |
-
|
405 |
|
406 |
prompt = "sks person"
|
407 |
-
negative_prompt = "low quality, blurry, unfinished, nudity
|
408 |
seed = 5
|
409 |
cfg = 3.0
|
410 |
steps = 50
|
411 |
-
|
412 |
-
return
|
413 |
-
|
414 |
-
|
415 |
|
416 |
|
|
|
417 |
|
418 |
|
419 |
|
420 |
|
421 |
intro = """
|
422 |
<div style="display: flex;align-items: center;justify-content: center">
|
423 |
-
<
|
|
|
424 |
</div>
|
425 |
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
|
426 |
-
<a href="https://snap-research.github.io/weights2weights/" target="_blank">
|
427 |
|
|
428 |
-
<a href="https://github.com/snap-research/weights2weights" target="_blank">Code</a> |
|
429 |
<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
|
430 |
display: inline-block;
|
431 |
">
|
@@ -437,115 +410,86 @@ intro = """
|
|
437 |
|
438 |
with gr.Blocks(css="style.css") as demo:
|
439 |
gr.HTML(intro)
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
|
|
|
|
447 |
sample = gr.Button("🎲 Sample New Model")
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
with gr.Column():
|
453 |
-
image_slider1 = ImageSlider(position=0.5, type="pil", height=512, width=512, label= "Reference Identity | Generated Samples by User")
|
454 |
-
|
455 |
-
prompt1 = gr.Textbox(label="Prompt",
|
456 |
-
info="Make sure to include 'sks person'" ,
|
457 |
-
placeholder="sks person",
|
458 |
-
value="sks person")
|
459 |
-
seed1 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
|
460 |
-
|
461 |
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
|
463 |
-
|
464 |
-
with gr.Row():
|
465 |
-
a1_1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
466 |
-
a2_1 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
467 |
-
with gr.Row():
|
468 |
-
a3_1 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
469 |
-
a4_1 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
470 |
-
|
471 |
-
|
472 |
-
|
473 |
-
with gr.Accordion("Advanced Options", open=False):
|
474 |
-
cfg1= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
475 |
-
steps1 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
476 |
-
negative_prompt1 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
|
477 |
-
injection_step1 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
submit1 = gr.Button("Generate")
|
483 |
-
|
484 |
-
with gr.Tab("Inversion"):
|
485 |
-
gr.Markdown("""
|
486 |
-
Upload an image and optionally define a mask by drawing over the face. Then click `invert` to get started ✨
|
487 |
-
""")
|
488 |
-
with gr.Row():
|
489 |
with gr.Column():
|
490 |
-
|
491 |
-
|
492 |
-
|
493 |
-
|
494 |
-
|
495 |
-
|
496 |
-
file_output2 = gr.File(label="Download Inverted Model", container=True, interactive=False)
|
497 |
|
498 |
-
|
499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
|
501 |
-
|
502 |
-
|
503 |
-
placeholder="sks person",
|
504 |
-
value="sks person")
|
505 |
-
seed2 = gr.Number(value=5, label="Seed", precision=0, interactive=True)
|
506 |
|
507 |
-
|
508 |
-
|
509 |
|
510 |
-
|
511 |
-
|
512 |
-
a2_2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
513 |
-
with gr.Row():
|
514 |
-
a3_2 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
515 |
-
a4_2 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
516 |
|
517 |
-
|
518 |
-
|
519 |
-
with gr.Accordion("Advanced Options", open=False):
|
520 |
-
cfg2= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
521 |
-
steps2 = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
522 |
-
negative_prompt2 = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
|
523 |
-
injection_step2 = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
|
524 |
-
|
525 |
-
|
526 |
-
|
527 |
|
528 |
-
|
529 |
|
|
|
|
|
530 |
|
531 |
|
532 |
-
|
533 |
-
|
534 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
535 |
|
536 |
|
537 |
|
538 |
-
|
539 |
-
sample.click(fn=sample_then_run, outputs=[image_slider1, file_output1])
|
540 |
-
submit1.click(fn=edit_inference, inputs=[ prompt1, negative_prompt1, cfg1, steps1, seed1, injection_step1, a1_1, a2_1, a3_1, a4_1], outputs=image_slider1)
|
541 |
-
file_input.change(fn=file_upload, inputs=file_input, outputs = image_slider1)
|
542 |
|
543 |
|
544 |
-
invert_button.click(fn=run_inversion, inputs=[input_image, pcs, epochs, weight_decay,lr], outputs = [image_slider2, file_output2])
|
545 |
-
submit2.click(fn=edit_inference, inputs=[ prompt2, negative_prompt2, cfg2, steps2, seed2, injection_step2, a1_2, a2_2, a3_2, a4_2], outputs=image_slider2)
|
546 |
|
547 |
|
548 |
-
|
549 |
|
550 |
|
551 |
|
|
|
1 |
import os
|
2 |
+
os.system("pip uninstall -y gradio")
|
3 |
+
os.system('pip install gradio==3.43.1')
|
|
|
4 |
import torch
|
5 |
import torchvision
|
6 |
import torchvision.transforms as transforms
|
7 |
from torch.utils.data import Dataset, DataLoader
|
8 |
import gradio as gr
|
9 |
import sys
|
|
|
10 |
import tqdm
|
11 |
sys.path.append(os.path.abspath(os.path.join("", "..")))
|
|
|
12 |
import gc
|
13 |
import warnings
|
14 |
warnings.filterwarnings("ignore")
|
15 |
from PIL import Image
|
16 |
+
import numpy as np
|
17 |
+
from utils import load_models
|
18 |
from editing import get_direction, debias
|
19 |
from sampling import sample_weights
|
20 |
from lora_w2w import LoRAw2w
|
21 |
from huggingface_hub import snapshot_download
|
22 |
+
import spaces
|
23 |
+
|
24 |
|
25 |
|
26 |
global device
|
|
|
31 |
global tokenizer
|
32 |
global noise_scheduler
|
33 |
global network
|
|
|
34 |
device = "cuda:0"
|
35 |
generator = torch.Generator(device=device)
|
36 |
+
|
|
|
37 |
|
38 |
|
39 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
|
|
122 |
global pointy
|
123 |
global wavy
|
124 |
global large
|
|
|
|
|
125 |
|
126 |
original_weights = network.proj.clone()
|
127 |
|
|
|
128 |
#pad to same number of PCs
|
129 |
pcs_original = original_weights.shape[1]
|
130 |
pcs_edits = young.shape[1]
|
|
|
135 |
large_pad = torch.cat((large, padding), 1)
|
136 |
|
137 |
|
138 |
+
edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*large_pad
|
139 |
|
140 |
generator = generator.manual_seed(seed)
|
141 |
latents = torch.randn(
|
|
|
191 |
#reset weights back to original
|
192 |
network.proj = torch.nn.Parameter(original_weights)
|
193 |
network.reset()
|
194 |
+
|
195 |
+
return image
|
196 |
|
197 |
def sample_then_run():
|
|
|
198 |
sample_model()
|
199 |
prompt = "sks person"
|
200 |
negative_prompt = "low quality, blurry, unfinished, nudity, weapon"
|
201 |
seed = 5
|
202 |
cfg = 3.0
|
203 |
steps = 50
|
204 |
+
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
205 |
torch.save(network.proj, "model.pt" )
|
206 |
+
return image, "model.pt"
|
|
|
|
|
207 |
|
208 |
|
209 |
global young
|
|
|
266 |
image = self.transform(image)
|
267 |
return image
|
268 |
|
269 |
+
def invert(image, mask, pcs=10000, epochs=400, weight_decay = 1e-10, lr=1e-1):
|
270 |
global unet
|
271 |
del unet
|
272 |
global network
|
|
|
|
|
|
|
|
|
273 |
unet, _, _, _, _ = load_models(device)
|
274 |
|
275 |
proj = torch.zeros(1,pcs).bfloat16().to(device)
|
|
|
281 |
train_method="xattn-strict"
|
282 |
).to(device, torch.bfloat16)
|
283 |
|
|
|
|
|
|
|
|
|
284 |
### load mask
|
285 |
mask = transforms.Resize((64,64), interpolation=transforms.InterpolationMode.BILINEAR)(mask)
|
286 |
mask = torchvision.transforms.functional.pil_to_tensor(mask).unsqueeze(0).to(device).bfloat16()[:,0,:,:].unsqueeze(1)
|
287 |
### check if an actual mask was draw, otherwise mask is just all ones
|
288 |
if torch.sum(mask) == 0:
|
289 |
mask = torch.ones((1,1,64,64)).to(device).bfloat16()
|
290 |
+
|
|
|
291 |
### single image dataset
|
292 |
image_transforms = transforms.Compose([transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
|
293 |
transforms.RandomCrop(512),
|
|
|
295 |
transforms.Normalize([0.5], [0.5])])
|
296 |
|
297 |
|
298 |
+
train_dataset = CustomImageDataset(image, transform=image_transforms)
|
299 |
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=1, shuffle=True)
|
300 |
|
|
|
|
|
301 |
### optimizer
|
302 |
optim = torch.optim.Adam(network.parameters(), lr=lr, weight_decay=weight_decay)
|
303 |
|
|
|
327 |
optim.step()
|
328 |
|
329 |
### return optimized network
|
|
|
330 |
return network
|
331 |
|
332 |
|
333 |
|
|
|
334 |
def run_inversion(dict, pcs, epochs, weight_decay,lr):
|
335 |
global network
|
336 |
+
init_image = dict["image"].convert("RGB").resize((512, 512))
|
337 |
+
mask = dict["mask"].convert("RGB").resize((512, 512))
|
338 |
+
network = invert([init_image], mask, pcs, epochs, weight_decay,lr)
|
|
|
339 |
|
340 |
|
341 |
#sample an image
|
342 |
prompt = "sks person"
|
343 |
+
negative_prompt = "low quality, blurry, unfinished, nudity"
|
344 |
seed = 5
|
345 |
cfg = 3.0
|
346 |
steps = 50
|
347 |
+
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
348 |
torch.save(network.proj, "model.pt" )
|
349 |
+
return image, "model.pt"
|
350 |
+
|
351 |
|
|
|
|
|
352 |
|
353 |
def file_upload(file):
|
354 |
global unet
|
355 |
del unet
|
356 |
global network
|
357 |
global device
|
|
|
|
|
358 |
|
359 |
|
360 |
|
|
|
367 |
|
368 |
unet, _, _, _, _ = load_models(device)
|
369 |
|
370 |
+
|
371 |
+
network = LoRAw2w( proj, mean, std, v[:, :pcs],
|
372 |
unet,
|
373 |
rank=1,
|
374 |
multiplier=1.0,
|
375 |
alpha=27.0,
|
376 |
train_method="xattn-strict"
|
377 |
).to(device, torch.bfloat16)
|
378 |
+
|
379 |
|
380 |
prompt = "sks person"
|
381 |
+
negative_prompt = "low quality, blurry, unfinished, nudity"
|
382 |
seed = 5
|
383 |
cfg = 3.0
|
384 |
steps = 50
|
385 |
+
image = inference( prompt, negative_prompt, cfg, steps, seed)
|
386 |
+
return image
|
|
|
|
|
387 |
|
388 |
|
389 |
+
|
390 |
|
391 |
|
392 |
|
393 |
|
394 |
intro = """
|
395 |
<div style="display: flex;align-items: center;justify-content: center">
|
396 |
+
<h1 style="margin-left: 12px;text-align: center;margin-bottom: 7px;display: inline-block">weights2weights</h1>
|
397 |
+
<h3 style="display: inline-block;margin-left: 10px;margin-top: 6px;font-weight: 500">Interpreting the Weight Space of Customized Diffusion Models</h3>
|
398 |
</div>
|
399 |
<p style="font-size: 0.95rem;margin: 0rem;line-height: 1.2em;margin-top:1em;display: inline-block">
|
400 |
+
<a href="https://snap-research.github.io/weights2weights/" target="_blank">project page</a> | <a href="https://arxiv.org/abs/2406.09413" target="_blank">paper</a>
|
401 |
|
|
|
|
402 |
<a href="https://huggingface.co/spaces/Snapchat/w2w-demo?duplicate=true" target="_blank" style="
|
403 |
display: inline-block;
|
404 |
">
|
|
|
410 |
|
411 |
with gr.Blocks(css="style.css") as demo:
|
412 |
gr.HTML(intro)
|
413 |
+
|
414 |
+
gr.Markdown("""<div style="text-align: justify;"> Click below to sample an identity-encoding model, or upload an image below and click \"invert\". You can also optionally draw over the face to define a mask. To use model previously downloaded from this demo see \"Uplaoding a model\" in the Advanced options""")
|
415 |
+
with gr.Column():
|
416 |
+
with gr.Row():
|
417 |
+
with gr.Column():
|
418 |
+
input_image = gr.Image(source='upload', elem_id="image_upload", tool='sketch', type='pil', label="Upload image and draw to define mask",
|
419 |
+
height=512, width=512, brush_color='#00FFFF', mask_opacity=0.6)
|
420 |
+
|
421 |
+
with gr.Row():
|
422 |
sample = gr.Button("🎲 Sample New Model")
|
423 |
+
invert_button = gr.Button("⬆️ Invert")
|
424 |
+
with gr.Column():
|
425 |
+
gallery = gr.Image(label="Image",height=512, width=512, interactive=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
+
prompt = gr.Textbox(label="Prompt",
|
428 |
+
info="Make sure to include 'sks person'" ,
|
429 |
+
placeholder="sks person",
|
430 |
+
value="sks person")
|
431 |
+
|
432 |
+
seed = gr.Number(value=5, label="Seed", precision=0, interactive=True)
|
433 |
|
434 |
+
# Editing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
435 |
with gr.Column():
|
436 |
+
with gr.Row():
|
437 |
+
a1 = gr.Slider(label="- Young +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
438 |
+
a2 = gr.Slider(label="- Pointy Nose +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
439 |
+
with gr.Row():
|
440 |
+
a3 = gr.Slider(label="- Curly Hair +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
441 |
+
a4 = gr.Slider(label="- Thick Eyebrows +", value=0, step=0.001, minimum=-1, maximum=1, interactive=True)
|
|
|
442 |
|
443 |
+
|
444 |
+
with gr.Accordion("Advanced Options", open=False):
|
445 |
+
with gr.Tab("Inversion"):
|
446 |
+
with gr.Row():
|
447 |
+
lr = gr.Number(value=1e-1, label="Learning Rate", interactive=True)
|
448 |
+
pcs = gr.Slider(label="# Principal Components", value=10000, step=1, minimum=1, maximum=10000, interactive=True)
|
449 |
+
with gr.Row():
|
450 |
+
epochs = gr.Slider(label="Epochs", value=800, step=1, minimum=1, maximum=2000, interactive=True)
|
451 |
+
weight_decay = gr.Number(value=1e-10, label="Weight Decay", interactive=True)
|
452 |
+
with gr.Tab("Sampling"):
|
453 |
+
with gr.Row():
|
454 |
+
cfg= gr.Slider(label="CFG", value=3.0, step=0.1, minimum=0, maximum=10, interactive=True)
|
455 |
+
steps = gr.Slider(label="Inference Steps", value=50, step=1, minimum=0, maximum=100, interactive=True)
|
456 |
+
with gr.Row():
|
457 |
+
negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="low quality, blurry, unfinished, nudity, weapon", value="low quality, blurry, unfinished, nudity, weapon")
|
458 |
+
injection_step = gr.Slider(label="Injection Step", value=800, step=1, minimum=0, maximum=1000, interactive=True)
|
459 |
|
460 |
+
with gr.Tab("Uploading a model"):
|
461 |
+
gr.Markdown("""<div style="text-align: justify;">Upload a model below downloaded from this demo.""")
|
|
|
|
|
|
|
462 |
|
463 |
+
file_input = gr.File(label="Upload Model", container=True)
|
|
|
464 |
|
465 |
+
submit = gr.Button("Generate")
|
466 |
+
|
|
|
|
|
|
|
|
|
467 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
468 |
|
469 |
+
gr.Markdown("""<div style="text-align: justify;"> After sampling a new model or inverting, you can download the model below.""")
|
470 |
|
471 |
+
with gr.Row():
|
472 |
+
file_output = gr.File(label="Download Sampled Model", container=True, interactive=False)
|
473 |
|
474 |
|
475 |
+
|
476 |
+
|
477 |
+
|
478 |
+
invert_button.click(fn=run_inversion,
|
479 |
+
inputs=[input_image, pcs, epochs, weight_decay,lr],
|
480 |
+
outputs = [gallery, file_output])
|
481 |
+
sample.click(fn=sample_then_run, outputs=[gallery, file_output])
|
482 |
+
submit.click(
|
483 |
+
fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
|
484 |
+
)
|
485 |
+
file_input.change(fn=file_upload, inputs=file_input, outputs = input_image)
|
486 |
|
487 |
|
488 |
|
|
|
|
|
|
|
|
|
489 |
|
490 |
|
|
|
|
|
491 |
|
492 |
|
|
|
493 |
|
494 |
|
495 |
|